Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/vislearn/FFF
Free-form flows are a generative model that trains a pair of neural networks via maximum likelihood.
- Host: GitHub
- URL: https://github.com/vislearn/FFF
- Owner: vislearn
- License: MIT
- Created: 2023-10-25T13:30:13.000Z (about 1 year ago)
- Default Branch: main
- Last Pushed: 2024-06-12T12:41:28.000Z (5 months ago)
- Last Synced: 2024-08-01T16:45:43.449Z (3 months ago)
- Language: Jupyter Notebook
- Homepage:
- Size: 3.97 MB
- Stars: 35
- Watchers: 5
- Forks: 3
- Open Issues: 2
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Free-form flows
This is the official `PyTorch` implementation of our papers:
1. [Free-form flows: Make Any Architecture a Normalizing Flow](http://arxiv.org/abs/2310.16624) on full-dimensional normalizing flows:
```bibtex
@inproceedings{draxler2024freeform,
title = {{Free-form flows: Make Any Architecture a Normalizing Flow}},
author = {Draxler, Felix and Sorrenson, Peter and Zimmermann, Lea and Rousselot, Armand and Köthe, Ullrich},
booktitle = {International Conference on Artificial Intelligence and Statistics},
year = {2024}
}
```
2. [Lifting Architectural Constraints of Injective Flows](http://arxiv.org/abs/2306.01843) on learning a manifold and the distribution on it jointly:
```bibtex
@inproceedings{sorrenson2024lifting,
title = {{Lifting Architectural Constraints of Injective Flows}},
booktitle = {International {{Conference}} on {{Learning Representations}}},
author = {Sorrenson, Peter and Draxler, Felix and Rousselot, Armand and Hummerich, Sander and Zimmermann, Lea and Köthe, Ullrich},
year = {2024}
}
```
3. [Learning Distributions on Manifolds with Free-form Flows](https://arxiv.org/abs/2312.09852) on learning distributions on a known manifold:
```bibtex
@article{sorrenson2023learning,
title = {Learning Distributions on Manifolds with Free-form Flows},
author = {Sorrenson, Peter and Draxler, Felix and Rousselot, Armand and Hummerich, Sander and Köthe, Ullrich},
journal = {arXiv preprint arXiv:2312.09852},
year = {2023}
}
```

## Installation
The following will install our package along with all of its dependencies:
```bash
git clone https://github.com/vislearn/FFF.git
cd FFF
pip install -r requirements.txt
pip install .
```

In the last line, use `pip install -e .` if you want to edit the code.
Then you can import the package via
```python
import fff
```

## Basic usage
### Train your architecture
```python
import torch
import fff.loss

class FreeFormFlow(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(...)
        self.decoder = torch.nn.Sequential(...)

model = FreeFormFlow()
optim = ...
data_loader = ...
n_epochs = ...
beta = ...

for epoch in range(n_epochs):
    for batch in data_loader:
        optim.zero_grad()
        # fff_loss trains encoder and decoder jointly via maximum likelihood;
        # beta trades off reconstruction against the negative log-likelihood.
        loss = fff.loss.fff_loss(batch, model.encoder, model.decoder, beta)
        loss.backward()
        optim.step()
```
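
To make the placeholders above concrete, here is a minimal self-contained sketch on toy Gaussian data. The network sizes, optimizer, learning rate, and `beta` value are illustrative assumptions rather than settings from our papers; only the `fff.loss.fff_loss(x, encoder, decoder, beta)` call mirrors the usage above.

```python
import torch
import fff.loss

dim = 2  # illustrative data dimension

# Small MLPs standing in for "any architecture"; the sizes are arbitrary.
encoder = torch.nn.Sequential(
    torch.nn.Linear(dim, 64), torch.nn.ReLU(), torch.nn.Linear(64, dim)
)
decoder = torch.nn.Sequential(
    torch.nn.Linear(dim, 64), torch.nn.ReLU(), torch.nn.Linear(64, dim)
)

optim = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3
)

# Toy Gaussian samples as a stand-in for a real data set.
data = torch.randn(4096, dim)
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(data), batch_size=256, shuffle=True
)
beta = 10.0  # reconstruction weight; an illustrative value, tune per data set

for epoch in range(10):
    for (batch,) in loader:
        optim.zero_grad()
        # .mean() reduces to a scalar in case the loss is returned per sample
        loss = fff.loss.fff_loss(batch, encoder, decoder, beta).mean()
        loss.backward()
        optim.step()
```

### Reproduce our experiments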
All training configurations from our papers can be found in the `configs/(fff|fif)` directories.
Our training framework is built on [lightning-trainable](https://github.com/LarsKue/lightning-trainable), a configuration wrapper around [PyTorch Lightning](https://lightning.ai/pytorch-lightning). There is no `main.py`, but you can train all our models via the `lightning_trainable.launcher.fit` module.
For example, to train the Boltzmann generator on DW4:
```bash
python -m lightning_trainable.launcher.fit configs/fff/dw4.yaml --name '{data_set[name]}'
```

This will create a new directory `lightning_logs/dw4/`. You can monitor the run via `tensorboard`:
```bash
tensorboard --logdir lightning_logs
```

When training has finished, you can load the model via
```python
import fff

model = fff.FreeFormFlow.load_from_checkpoint(
    'lightning_logs/dw4/version_0/checkpoints/last.ckpt'
)
```

If you want to override the default parameters, you can add `key=value` pairs after the config file:
```bash
python -m lightning_trainable.launcher.fit configs/fff/dw4.yaml batch_size=128 loss_weights.noisy_reconstruction=20 --name '{data_set[name]}'
```

#### Known issues
Training with $E(n)$-GNNs is sometimes unstable. This is usually caught by an assertion in a later step, and training is stopped.
In almost all cases, training can be stably resumed from the last epoch checkpoint by passing the `--continue-from [CHECKPOINT]` flag to the training command, such as:
```bash
python -m lightning_trainable.launcher.fit configs/fff/dw4.yaml --name '{data_set[name]}' --continue-from lightning_logs/dw4/version_0/checkpoints/last.ckpt
```
This reloads the entire training state (model state, optimizer state, epoch, etc.) from the checkpoint and continues training from there.

### Set up your own training
Start with the config file in `configs/(fff|fif)` that fits your needs best and modify it.
For custom data sets, add the data set to `fff.data`; a rough sketch of what this might look like follows below.
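
The registration mechanism inside `fff.data` is not documented here, so the following is only a sketch under stated assumptions: the function name, signature, and return convention are hypothetical, and you should mirror one of the existing data sets in `fff.data` for the real interface.

```python
import torch

# Hypothetical loader for a custom 2D point-cloud data set. Everything
# here (name, signature, return convention) is illustrative; mirror an
# existing data set in fff.data for the actual interface.
def load_my_points(n_train=10_000, n_val=1_000):
    # Replace the random tensor with your real data loading logic.
    data = torch.randn(n_train + n_val, 2)
    return (
        torch.utils.data.TensorDataset(data[:n_train]),
        torch.utils.data.TensorDataset(data[n_train:]),
    )
```

Once the loader is wired into `fff.data`, it should be selectable from a config file via `data_set.name`, matching the `{data_set[name]}` pattern used in the launcher commands above.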