# TorchEBM
⚡ Energy-Based Modeling library for PyTorch, offering tools for 🔬 sampling, 🧠 inference, and 📊 learning in complex distributions.

## What is TorchEBM?
**Energy-Based Models (EBMs)** offer a powerful and flexible framework for generative modeling by assigning a scalar "energy" to each data point, where lower energy corresponds to higher (unnormalized) probability: \( p(x) \propto e^{-E(x)} \).
**TorchEBM** simplifies working with EBMs in [PyTorch](https://pytorch.org/). It provides a suite of tools designed for researchers and practitioners, enabling efficient implementation and exploration of:
* **Defining complex energy functions:** Easily create custom energy landscapes using PyTorch modules (see the sketch after this list).
* **Training:** Loss functions and procedures suitable for EBM parameter estimation, including score matching and contrastive divergence variants.
* **Sampling:** Algorithms to draw samples from the learned distribution \( p(x) \).
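To make the first point concrete, here is a minimal sketch of a custom energy function written as a plain `torch.nn.Module`. The class name and quartic form are illustrative only; TorchEBM ships its own base class in `core/energy_function.py` (see the library structure below).

```python
import torch
import torch.nn as nn

# Illustrative only: a quartic "double well"-style energy as a plain PyTorch
# module. TorchEBM provides its own base class in core/energy_function.py.
class QuarticEnergy(nn.Module):
    def __init__(self, a: float = 1.0):
        super().__init__()
        self.a = a

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n_samples, dim) -> one scalar energy per sample
        return (self.a * (x**2 - 1.0) ** 2).sum(dim=-1)

energy = QuarticEnergy()
print(energy(torch.randn(8, 2)).shape)  # torch.Size([8])
```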
## Features

- **Core Components**:
  - Energy functions: Standard energy landscapes (Gaussian, Double Well, Rosenbrock, etc.)
  - Datasets: Data generators for training and evaluation
  - Loss functions: Contrastive Divergence, Score Matching, and more
  - Sampling algorithms: Langevin Dynamics, Hamiltonian Monte Carlo (HMC), and more
  - Evaluation metrics: Diagnostics for sampling and training
- **Performance Optimizations**:
  - CUDA-accelerated implementations
  - Parallel sampling capabilities
  - Extensive diagnostics
*Energy landscape visualizations: Gaussian, Double Well, Rastrigin, and Rosenbrock functions.*
## Installation
```bash
pip install torchebm
```
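If you want the latest development version instead, pip can install directly from the GitHub repository (standard pip behavior; assumes the default branch builds as a package):

```bash
pip install git+https://github.com/soran-ghaderi/torchebm.git
```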
## Usage Examples

### Common Setup
```python
import torch
from torchebm.core import GaussianEnergy, DoubleWellEnergy

# Set device for computation
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define dimensions
dim = 10
n_samples = 250
n_steps = 500
```

### Energy Function Examples
```python
# Create a multivariate Gaussian energy function
gaussian_energy = GaussianEnergy(
    mean=torch.zeros(dim, device=device),  # Center at origin
    cov=torch.eye(dim, device=device)      # Identity covariance (standard normal)
)

# Create a double well potential
double_well_energy = DoubleWellEnergy(barrier_height=2.0)
```
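Energy objects are evaluated like any PyTorch callable. As a quick sanity check (the batched-call convention and return shape here are assumptions, matching the custom-module sketch earlier):

```python
# Evaluate the Gaussian energy on a random batch of points.
# Assumes energy objects are callable on (n_samples, dim) tensors and
# return one scalar energy per row.
x = torch.randn(n_samples, dim, device=device)
energies = gaussian_energy(x)
print(energies.shape)  # expected: torch.Size([250])
```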
### 1. Training a Simple EBM on a Gaussian Mixture with Langevin Dynamics

```python
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import GaussianMixtureDataset
from torchebm.samplers.langevin_dynamics import LangevinDynamics

# Stand-in definition: this snippet uses MLPEnergy without showing it, so a
# minimal MLP energy (points -> scalar energies) is sketched here.
class MLPEnergy(nn.Module):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )
    def forward(self, x):
        return self.net(x).squeeze(-1)

energy_fn = MLPEnergy(input_dim=2).to(device)
sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01, device=device)

cd_loss_fn = ContrastiveDivergence(
    energy_function=energy_fn,
    sampler=sampler,
    n_steps=10,  # MCMC steps for generating negative samples
)
optimizer = optim.Adam(energy_fn.parameters(), lr=0.001)

mixture_dataset = GaussianMixtureDataset(n_samples=500, n_components=4, std=0.1, seed=123).get_data()
dataloader = DataLoader(mixture_dataset, batch_size=32, shuffle=True)

# Training loop
n_epochs = 10
for epoch in range(n_epochs):
    epoch_loss = 0.0
    for batch_data in dataloader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        loss, neg_samples = cd_loss_fn(batch_data)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{n_epochs}, Loss: {avg_loss:.6f}")
```
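After training, the fitted energy can be sampled with the same sampler. A minimal sketch, assuming `LangevinDynamics.sample` accepts the same `dim`/`n_samples`/`n_steps` arguments as the HMC `sample` call shown in the next example:

```python
# Draw samples from the trained 2D energy using the sampler defined above.
# Assumes LangevinDynamics.sample shares the signature shown for HMC below.
model_samples = sampler.sample(dim=2, n_samples=500, n_steps=200)
print(model_samples.shape)  # expected: (500, 2)
```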
```python
from torchebm.samplers.hmc import HamiltonianMonteCarlo# Define a 10D Gaussian energy function
energy_fn = GaussianEnergy(mean=torch.zeros(10), cov=torch.eye(10))# Initialize HMC sampler
hmc_sampler = HamiltonianMonteCarlo(
energy_function=energy_fn, step_size=0.1, n_leapfrog_steps=10, device=device
)# Sample 10,000 points in 10 dimensions
final_samples = hmc_sampler.sample(
dim=10, n_steps=500, n_samples=10000, return_trajectory=False
)
print(final_samples.shape) # Result shape: (10000, 10) - (n_samples, dim)# Sample with diagnostics and trajectory
final_samples, diagnostics = hmc_sampler.sample(
n_samples=n_samples,
n_steps=n_steps,
dim=dim,
return_trajectory=True,
return_diagnostics=True,
)print(final_samples.shape) # Trajectory shape: (250, 500, 10) - (n_samples, n_steps, dim)
print(diagnostics.shape) # Diagnostics shape: (500, 4, 250, 10) - (n_steps, 4, n_samples, dim)
# The diagnostics contain: Mean (dim=0), Variance (dim=1), Energy (dim=2), Acceptance rates (dim=3)# Sample from a custom initialization
x_init = torch.randn(n_samples, dim, dtype=torch.float32, device=device)
samples = hmc_sampler.sample(x=x_init, n_steps=100)
print(samples.shape) # Result shape: (250, 10) -> (n_samples, dim)
```## Library Structure
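Given the `(n_steps, 4, n_samples, dim)` layout described above, individual diagnostic trajectories can be sliced out along the second axis:

```python
# Slice the diagnostics tensor from the call above; axis 1 indexes:
# 0 = mean, 1 = variance, 2 = energy, 3 = acceptance rates
mean_traj = diagnostics[:, 0]    # shape: (500, 250, 10)
var_traj = diagnostics[:, 1]
energy_traj = diagnostics[:, 2]
accept_traj = diagnostics[:, 3]
print(mean_traj.shape)  # torch.Size([500, 250, 10])
```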
## Library Structure

```
torchebm/
├── core/                      # Core functionality
│   ├── energy_function.py     # Energy function definitions
│   ├── basesampler.py         # Base sampler class
│   └── ...
├── samplers/                  # Sampling algorithms
│   ├── langevin_dynamics.py   # Langevin dynamics implementation
│   ├── mcmc.py                # HMC implementation
│   └── ...
├── models/                    # Neural network models
├── evaluation/                # Evaluation metrics and utilities
├── datasets/
│   └── generators.py          # Data generators for training
├── losses/                    # BaseLoss functions for training
├── utils/                     # Utility functions
└── cuda/                      # CUDA optimizations
```

## Visualization Examples
*Visualizations: Langevin dynamics sampling, a single Langevin dynamics trajectory, and parallel Langevin dynamics sampling.*
Check out the `examples/` directory for sample scripts:
- `langevin_dynamics_sampling.py`: Demonstrates Langevin dynamics sampling
- `hmc_examples.py`: Demonstrates Hamiltonian Monte Carlo sampling
- `energy_fn_visualization.py`: Visualizes various energy functions (a minimal sketch of the idea follows)
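For a flavor of what `energy_fn_visualization.py` does, here is a minimal, self-contained sketch. It assumes only that energy objects are callable on batches of points, as in the examples above; matplotlib is used for plotting and is not a TorchEBM dependency:

```python
import torch
import matplotlib.pyplot as plt
from torchebm.core import DoubleWellEnergy

energy = DoubleWellEnergy(barrier_height=2.0)

# Evaluate the energy on a 2D grid and draw the landscape
xs = torch.linspace(-2.0, 2.0, 200)
grid_x, grid_y = torch.meshgrid(xs, xs, indexing="ij")
points = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=-1)  # (200*200, 2)
with torch.no_grad():
    z = energy(points).reshape(200, 200)  # assumes (N, 2) -> (N,) energies

plt.contourf(grid_x.numpy(), grid_y.numpy(), z.numpy(), levels=50, cmap="viridis")
plt.colorbar(label="energy")
plt.title("Double Well energy landscape")
plt.show()
```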
## Contributing

Contributions are welcome! Please check the issues page for current tasks or create a new issue to discuss proposed changes.
## License
This project is licensed under the MIT License - see the LICENSE file for details.