https://github.com/soran-ghaderi/torchebm
🍓 Build and train energy-based and diffusion models in PyTorch ⚡.
contrastive-divergence cuda diffusion-models energy-based-model generative-ai hamiltonian hamiltonian-monte-carlo langevin-dynamics noise-contrastive-estimation probabilistic-machine-learning reasoning sampling-methods score-matching variational-inference
- Host: GitHub
- URL: https://github.com/soran-ghaderi/torchebm
- Owner: soran-ghaderi
- License: mit
- Created: 2024-10-05T12:41:38.000Z (over 1 year ago)
- Default Branch: master
- Last Pushed: 2026-02-11T08:02:25.000Z (16 days ago)
- Last Synced: 2026-02-11T13:42:41.995Z (16 days ago)
- Topics: contrastive-divergence, cuda, diffusion-models, energy-based-model, generative-ai, hamiltonian, hamiltonian-monte-carlo, langevin-dynamics, noise-contrastive-estimation, probabilistic-machine-learning, reasoning, sampling-methods, score-matching, variational-inference
- Language: Python
- Homepage: https://soran-ghaderi.github.io/torchebm/
- Size: 69.1 MB
- Stars: 76
- Watchers: 4
- Forks: 7
- Open Issues: 14
Metadata Files:
- Readme: README.md
- Changelog: CHANGELOG.md
- License: LICENSE
- Code of conduct: CODE_OF_CONDUCT.md
⚡ Energy-Based Modeling library for PyTorch, offering tools for 🔬 sampling, 🧠 inference, and 📊 learning in complex distributions.

## What is ∇ TorchEBM 🍓?
**Energy-Based Models (EBMs)** offer a powerful and flexible framework for generative modeling by assigning a scalar energy \( E(x) \) to each data point, which defines an unnormalized probability \( p(x) \propto \exp(-E(x)) \). Lower energy corresponds to higher probability.
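
As a small illustration of how energy relates to probability, the sketch below evaluates a hypothetical one-dimensional double-well energy on a grid and normalizes \( \exp(-E(x)) \) numerically. It uses plain PyTorch rather than TorchEBM, and the energy form, the function name `double_well`, and the grid are assumptions chosen for illustration.

```python
import torch

def double_well(x: torch.Tensor, barrier_height: float = 2.0) -> torch.Tensor:
    """Hypothetical 1-D double-well energy with minima at x = -1 and x = +1."""
    return barrier_height * (x ** 2 - 1.0) ** 2

# Evaluate the energy on a uniform grid.
xs = torch.linspace(-2.0, 2.0, steps=401)
energies = double_well(xs)

# softmax(-E) equals exp(-E) divided by its sum over the grid,
# computed in a numerically stable way.
probs = torch.softmax(-energies, dim=0)

lowest_energy_x = xs[energies.argmin()].item()
most_probable_x = xs[probs.argmax()].item()
print(f"lowest energy near x = {lowest_energy_x:.2f}")
print(f"highest probability near x = {most_probable_x:.2f}")
```

Both printed locations sit at one of the two wells, while the barrier at `x = 0` (energy `barrier_height`) receives much less probability mass.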
**TorchEBM** simplifies working with EBMs in [PyTorch](https://pytorch.org/). It provides a suite of tools designed for researchers and practitioners, enabling efficient implementation and exploration of:
* **Defining complex energy functions:** Easily create custom energy landscapes using PyTorch modules.
* **Training:** Loss functions and procedures for EBM parameter estimation, including score matching and contrastive divergence variants.
* **Sampling:** Algorithms to draw samples from the learned distribution \( p(x) \).
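
To make the sampling item concrete, here is a minimal sketch of unadjusted Langevin dynamics written in plain PyTorch. It is not TorchEBM's implementation: the helper name `langevin_sample`, the step-size convention, and the quadratic test energy are all assumptions for illustration.

```python
import torch

def quadratic_energy(x: torch.Tensor) -> torch.Tensor:
    """Energy of a standard normal, up to a constant: E(x) = 0.5 * ||x||^2."""
    return 0.5 * (x ** 2).sum(dim=-1)

def langevin_sample(energy_fn, x_init, n_steps=500, step_size=0.01):
    """Hypothetical helper implementing the update
    x <- x - step_size * grad E(x) + sqrt(2 * step_size) * noise."""
    x = x_init.clone()
    for _ in range(n_steps):
        x.requires_grad_(True)
        grad = torch.autograd.grad(energy_fn(x).sum(), x)[0]
        with torch.no_grad():
            noise = torch.randn_like(x)
            x = x - step_size * grad + (2.0 * step_size) ** 0.5 * noise
    return x.detach()

torch.manual_seed(0)
# Start 2,000 parallel chains far from the mode and let them relax.
samples = langevin_sample(quadratic_energy, torch.full((2000, 2), 3.0))
print(samples.mean(dim=0), samples.std(dim=0))
```

With this energy the chains should settle around zero mean and roughly unit standard deviation; a smaller `step_size` reduces discretization bias at the cost of more steps.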
## Documentation
For detailed documentation, including installation instructions, usage examples, and API references, please visit
the 📚 [TorchEBM Website](https://soran-ghaderi.github.io/torchebm/).
## Features
- **Core Components**:
  - Energy functions: Standard energy landscapes (Gaussian, Double Well, Rosenbrock, etc.)
  - Datasets: Data generators for training and evaluation
  - Loss functions: Contrastive Divergence, Score Matching, and more
  - Sampling algorithms: Langevin Dynamics, Hamiltonian Monte Carlo (HMC), and more
  - Evaluation metrics: Diagnostics for sampling and training
- **Performance Optimizations**:
  - CUDA-accelerated implementations
  - Parallel sampling capabilities
  - Extensive diagnostics
Figures: energy landscapes of the Gaussian, Double Well, Rastrigin, and Rosenbrock functions.
## Installation
```bash
pip install torchebm
```
#### Dependencies
- [PyTorch](https://pytorch.org/) (with CUDA support for optimal performance)
- Other dependencies are listed in [requirements.txt](requirements.txt)
## Usage Examples
### Common Setup
```python
import torch
from torchebm.core import GaussianEnergy, DoubleWellEnergy
# Set device for computation
device = "cuda" if torch.cuda.is_available() else "cpu"
# Define dimensions
dim = 10
n_samples = 250
n_steps = 500
```
### Energy Function Examples
```python
# Create a multivariate Gaussian energy function
gaussian_energy = GaussianEnergy(
    mean=torch.zeros(dim, device=device),  # Center at origin
    cov=torch.eye(dim, device=device)  # Identity covariance (standard normal)
)
# Create a double well potential
double_well_energy = DoubleWellEnergy(barrier_height=2.0)
```
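
The built-in energies above follow a pattern that any differentiable PyTorch module can reproduce: map a batch of points to one scalar energy per point. The sketch below shows that pattern with a hypothetical `RingEnergy` defined directly on `torch.nn.Module` (not on TorchEBM's base class), together with an `energy_gradient` helper, also hypothetical, that computes the input gradient used by gradient-based samplers.

```python
import torch
import torch.nn as nn

class RingEnergy(nn.Module):
    """Hypothetical energy whose low-energy region is a ring of a given radius."""

    def __init__(self, radius: float = 1.0, width: float = 0.5):
        super().__init__()
        self.radius = radius
        self.width = width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One scalar energy per row of x: (batch, dim) -> (batch,)
        distance = x.norm(dim=-1)
        return (distance - self.radius) ** 2 / (2.0 * self.width ** 2)

def energy_gradient(energy_fn: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Gradient of the energy with respect to the inputs, via autograd."""
    x = x.detach().requires_grad_(True)
    return torch.autograd.grad(energy_fn(x).sum(), x)[0]

ring = RingEnergy(radius=1.0, width=0.5)
points = torch.tensor([[1.0, 0.0], [2.0, 0.0]])
print(ring(points))                   # energy per point
print(energy_gradient(ring, points))  # gradient per point
```

The first point lies on the ring, so its energy and gradient are zero; the second lies outside it, so its gradient points away from the ring and a sampler stepping against the gradient moves back toward it.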
### 1. Training a Simple EBM over a Gaussian Mixture Using a Langevin Dynamics Sampler
```python
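import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# BaseEnergyFunction is assumed to be exported from torchebm.core, like GaussianEnergy above
from torchebm.core import BaseEnergyFunction
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import GaussianMixtureDataset
from torchebm.samplers import LangevinDynamics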
# Define a neural-network energy model
class MLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)  # one scalar energy per sample
energy_fn = MLPEnergy(input_dim=2).to(device)
sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01, device=device)
cd_loss_fn = ContrastiveDivergence(
    energy_function=energy_fn,
    sampler=sampler,
    k_steps=10  # MCMC steps for generating negative samples
)
optimizer = optim.Adam(energy_fn.parameters(), lr=0.001)
mixture_dataset = GaussianMixtureDataset(n_samples=500, n_components=4, std=0.1, seed=123).get_data()
dataloader = DataLoader(mixture_dataset, batch_size=32, shuffle=True)
# Training loop
n_epochs = 10
for epoch in range(n_epochs):
    epoch_loss = 0.0
    for batch_data in dataloader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        # The loss callable returns the loss and the negative samples it generated
        loss, neg_samples = cd_loss_fn(batch_data)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.6f}")
```
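
For intuition about what a contrastive divergence objective optimizes, the following sketch computes a CD-style loss by hand in plain PyTorch: the mean energy of data samples minus the mean energy of negative samples. It is a simplified illustration, not the internals of the `ContrastiveDivergence` class; `ToyEnergy`, `cd_style_loss`, and the random stand-in for sampler output are hypothetical.

```python
import torch
import torch.nn as nn

class ToyEnergy(nn.Module):
    """Hypothetical energy with a learnable center: E(x) = 0.5 * ||x - center||^2."""

    def __init__(self, dim: int):
        super().__init__()
        self.center = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * ((x - self.center) ** 2).sum(dim=-1)

def cd_style_loss(energy_fn, data, negatives):
    """Lower the energy of data and raise it on (detached) negative samples."""
    return energy_fn(data).mean() - energy_fn(negatives.detach()).mean()

torch.manual_seed(0)
toy_energy = ToyEnergy(dim=2)
data = torch.randn(256, 2) + torch.tensor([2.0, -1.0])  # "real" samples
negatives = torch.randn(256, 2)  # stand-in for MCMC samples from the current model

loss = cd_style_loss(toy_energy, data, negatives)
loss.backward()
# A gradient-descent step moves `center` along -grad, i.e. toward the data mean.
print(loss.item(), toy_energy.center.grad)
```

For this toy energy the gradient equals the mean of the negatives minus the mean of the data, so training stops changing `center` once the model's samples match the data on average.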
### 2. Hamiltonian Monte Carlo (HMC)
```python
from torchebm.samplers import HamiltonianMonteCarlo
# Define a 10-D Gaussian energy function on the same device as the sampler
energy_fn = GaussianEnergy(
    mean=torch.zeros(10, device=device), cov=torch.eye(10, device=device)
)
# Initialize HMC sampler
hmc_sampler = HamiltonianMonteCarlo(
    energy_function=energy_fn, step_size=0.1, n_leapfrog_steps=10, device=device
)
# Sample 10,000 points in 10 dimensions
final_samples = hmc_sampler.sample(
    dim=10, n_steps=500, n_samples=10000, return_trajectory=False
)
print(final_samples.shape)  # Result shape: (10000, 10) - (n_samples, dim)
# Sample with diagnostics and trajectory
final_samples, diagnostics = hmc_sampler.sample(
    n_samples=n_samples,
    n_steps=n_steps,
    dim=dim,
    return_trajectory=True,
    return_diagnostics=True,
)
print(final_samples.shape)  # Trajectory shape: (250, 500, 10) - (n_samples, n_steps, dim)
print(diagnostics.shape)  # Diagnostics shape: (500, 4, 250, 10) - (n_steps, 4, n_samples, dim)
# The size-4 axis holds, in order: mean, variance, energy, acceptance rate
# Sample from a custom initialization
x_init = torch.randn(n_samples, dim, dtype=torch.float32, device=device)
samples = hmc_sampler.sample(x=x_init, n_steps=100)
print(samples.shape)  # Result shape: (250, 10) - (n_samples, dim)
```
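
The `step_size` and `n_leapfrog_steps` arguments above configure the leapfrog integrator that HMC uses to propose new states. As a rough sketch of that integrator (plain PyTorch, assuming unit mass and a hypothetical `leapfrog` helper rather than TorchEBM's internals), the code below simulates one trajectory per chain and checks how well the Hamiltonian is conserved.

```python
import torch

def energy(x: torch.Tensor) -> torch.Tensor:
    """Standard Gaussian energy: E(x) = 0.5 * ||x||^2."""
    return 0.5 * (x ** 2).sum(dim=-1)

def grad_energy(x: torch.Tensor) -> torch.Tensor:
    x = x.detach().requires_grad_(True)
    return torch.autograd.grad(energy(x).sum(), x)[0]

def leapfrog(x, p, step_size=0.1, n_leapfrog_steps=10):
    """Hypothetical helper: half momentum step, alternating full position and
    momentum steps, then a final half momentum step."""
    p = p - 0.5 * step_size * grad_energy(x)
    for step in range(n_leapfrog_steps):
        x = x + step_size * p
        if step < n_leapfrog_steps - 1:
            p = p - step_size * grad_energy(x)
    p = p - 0.5 * step_size * grad_energy(x)
    return x, p

def hamiltonian(x, p):
    # Potential energy plus kinetic energy (unit mass assumed).
    return energy(x) + 0.5 * (p ** 2).sum(dim=-1)

torch.manual_seed(0)
x0 = torch.randn(250, 10)
p0 = torch.randn(250, 10)
x1, p1 = leapfrog(x0, p0)

# In full HMC, min(1, exp(-delta_h)) is the Metropolis acceptance probability.
delta_h = hamiltonian(x1, p1) - hamiltonian(x0, p0)
accept_prob = torch.exp(-delta_h).clamp(max=1.0)
print(delta_h.abs().max().item(), accept_prob.mean().item())
```

Because the integration error in the Hamiltonian stays small at this step size, nearly all proposals would be accepted; larger step sizes trade acceptance rate for longer moves per step.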
## Library Structure
```
torchebm/
├── core/ # Core functionality
│ ├── energy_function.py # Energy function definitions
│ ├── basesampler.py # Base sampler class
│ └── ...
├── samplers/ # Sampling algorithms
│ ├── langevin_dynamics.py # Langevin dynamics implementation
│ ├── mcmc.py # HMC implementation
│ └── ...
├── models/ # Neural network models
├── evaluation/ # Evaluation metrics and utilities
├── datasets/
│ └── generators.py # Data generators for training
├── losses/ # Loss functions for training
├── utils/ # Utility functions
└── cuda/ # CUDA optimizations
```
## Visualization Examples
Figures: Langevin Dynamics Sampling, Single Langevin Dynamics Trajectory, Parallel Langevin Dynamics Sampling.
Check out the `examples/` directory for sample scripts:
- `samplers/`: Demonstrates different sampling algorithms
- `datasets/`: Depicts data generation using built-in datasets
- `training_models/`: Shows how to train energy-based models using TorchEBM
- `visualization/`: Visualizes sampling results and trajectories
- and more!
## Contributing
Contributions are welcome! Step-by-step instructions for contributing are available in the [contributing guide](docs/developer_guide/contributing.md), which is also published on the website.
Please check the issues page for current tasks or create a new issue to discuss proposed changes.
## Show your Support for ∇ TorchEBM 🍓
Please ⭐️ this repository if ∇ TorchEBM helped you and spread the word.
Thank you! 🚀
## Citation
If you use ∇ TorchEBM in your research, please cite it using the following BibTeX entry:
```bibtex
@misc{torchebm_library_2025,
author = {Ghaderi, Soran and Contributors},
title = {{TorchEBM}: A PyTorch Library for Training Energy-Based Models},
year = {2025},
url = {https://github.com/soran-ghaderi/torchebm},
}
```
## Changelog
For a detailed list of changes between versions, please see our [CHANGELOG](CHANGELOG.md).
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## Research Collaboration
If you are interested in collaborating on research projects (**diffusion**-/**flow**-/**energy-based** models) or have
any questions about the library, please feel free to reach out. I am open to discussions and collaborations that can
enhance the capabilities of **∇ TorchEBM** 🍓 and contribute to the field of generative modeling.