# One-step Diffusion with Distribution Matching Distillation
![Methodology](assets/methodology.png)
A PyTorch implementation of the paper [One-step Diffusion with Distribution Matching Distillation](https://arxiv.org/abs/2311.18828). The codebase is largely based on [EDM from NVLabs](https://github.com/NVlabs/edm) and is built on top of it with the necessary modifications.

Note that this is an unofficial reimplementation study of the paper. In this codebase we focus on experiments with the CIFAR-10 dataset to reproduce the results; however, the technique should be applicable to other datasets with minor adjustments.

## Pretrained Models
You can access the experiment/training logs on [neptune](https://app.neptune.ai/o/metu-ms/org/DMD).
| Dataset             | Model                                                                                                                                   |
|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------|
| CIFAR-10-32x32-cond | [https://huggingface.co/Devrim/dmd-cifar-10-cond](https://huggingface.co/Devrim/dmd-cifar-10-cond/resolve/main/model.pt?download=true)   |
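If you prefer to fetch the checkpoint programmatically rather than via the direct link above, a minimal sketch using the `huggingface_hub` client (an extra dependency, not necessarily part of this project's environment) could look like this:

```python
# Sketch: download the released CIFAR-10 conditional checkpoint from the Hub.
# Repo id and filename are taken from the link in the table above.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="Devrim/dmd-cifar-10-cond", filename="model.pt")
print("Checkpoint downloaded to:", ckpt_path)
```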
## Setup

Create a conda environment from the configuration file, and activate it when needed.
```shell
conda env create -f environment.yml
```

You can access the CLI by running
```shell
python -m dmd --help
```

### Models
The DMD method is an application of distillation and thus requires a teacher model. The teacher diffusion models used in the paper are the [EDM models](https://github.com/NVlabs/edm); specifically, for CIFAR-10 we focus on a class-conditional model. You can find the pretrained EDM models [here](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/).
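As a rough illustration, the teacher pickle can be loaded following the upstream EDM convention. This is only a sketch; it assumes `dnnlib` from the NVlabs/edm codebase is importable, and the loading utilities in this repository may differ.

```python
# Sketch: load the pretrained EDM teacher following the NVlabs/edm convention.
import pickle

import dnnlib  # shipped with the EDM codebase this project builds on

url = "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl"
with dnnlib.util.open_url(url) as f:
    teacher = pickle.load(f)["ema"]  # EMA weights of the pretrained diffusion model
teacher.eval()
```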
### Dataset

Download the distillation dataset by running
```shell
bash scripts/download_data.sh
```

For reference, sample images both from the CIFAR-10 dataset and from the base conditional pretrained EDM model (CIFAR-10) are shown below:
![Dataset samples](assets/dataset_samples.png)
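If you want to peek at the downloaded file, a quick, generic way to inspect its contents without assuming a particular layout is to walk the HDF5 hierarchy (a sketch assuming `h5py` is available; the path below is a placeholder matching the `--data-path` argument used for training):

```python
# Sketch: list every dataset stored in the downloaded HDF5 file with its shape.
import h5py

with h5py.File("/path/to/hdf5_data", "r") as f:
    f.visititems(lambda name, obj: print(name, getattr(obj, "shape", "")))
```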
## Training
Start training by running
```shell
python -m dmd train --model-path https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl --data-path /path/to/hdf5_data --output-dir /path/to/output-dir --epochs 2 --batch-size 32
```

To see all training arguments, run
```shell
python -m dmd train --help
```

### Output Images
The `output_dir` is used for saving model checkpoints as well as images sampled from the model during training. The saved images form a grid of size `(5, batch_size)` whose rows correspond to `(x, x_real, x_fake, x_ref, y_ref)` respectively. We follow the paper's naming convention for these (see the sketch after the list below):
- `z`: random sample from `N(0, 1)`
- `x`: `G(z)` (row 1)
- `s`: `edm_sigma(t)` where `t ~ U(0, 1)`
- `noisy_x`: `forward_diffusion(x, s)`
- `x_real`: `mu_real(noisy_x, t)` (row 2); `pred_real_image` in the paper
- `x_fake`: `mu_fake(noisy_x, t)` (row 3); `pred_fake_image` in the paper
- `z_ref`: paired latent from the synthetic dataset
- `x_ref`: `G(z_ref)` (row 4)
- `y_ref`: paired image sampled from the base pretrained model (row 5)
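As a rough sketch of how such a grid can be assembled (illustrative only, using random stand-in tensors rather than the project's actual logging code):

```python
# Sketch: build a (5, batch_size) image grid with one row per tensor listed above.
import torch
from torchvision.utils import make_grid, save_image

batch_size = 8
# Stand-ins for the five row tensors, each (batch_size, 3, 32, 32) in [0, 1].
x, x_real, x_fake, x_ref, y_ref = (torch.rand(batch_size, 3, 32, 32) for _ in range(5))

rows = torch.cat([x, x_real, x_fake, x_ref, y_ref], dim=0)
grid = make_grid(rows, nrow=batch_size)  # batch_size images per row -> 5 rows
save_image(grid, "training_samples.png")
```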
### Logging to Neptune
Create a `neptune.cfg` file in the project root. The file content should look like this:
```ini
[credentials]
project=
token=
```

Then, you can use the `--log-neptune` flag to automatically log metrics to your neptune project.
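For illustration, the credentials in such a file can be read with the standard library and used to start a run with the neptune client (a sketch assuming a recent `neptune` package; the project's own `--log-neptune` wiring may differ):

```python
# Sketch: read neptune.cfg and start a logging run with the parsed credentials.
import configparser

import neptune

cfg = configparser.ConfigParser()
cfg.read("neptune.cfg")

run = neptune.init_run(
    project=cfg["credentials"]["project"],
    api_token=cfg["credentials"]["token"],
)
run["params/batch_size"] = 32  # example of logging a value
run.stop()
```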
## Generation
Once you have trained the one-step generator, you can generate samples from it as follows:

```python
from dmd.generate import DMDGenerator
from dmd.utils.common import display_samples

gen = DMDGenerator(network_path="/path/to/model_checkpoint.pt")
samples = gen.generate_batch(seeds=list(range(25)), class_ids=0)
display_samples(samples).show()
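# Assuming display_samples returns a PIL.Image (as .show() above suggests),
# the sample grid can also be written to disk; the filename is arbitrary.
display_samples(samples).save("generated_samples.png")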
```

## Assumptions
- Hyperparameters (e.g. the optimizer and learning rate) are stated explicitly in the paper, but it is not specified which model they apply to. Where no additional information is given, we assume they apply to both models.
## Development
For convenience, add the project root to `PYTHONPATH`. Earlier conda versions support this via the `develop` command:
```shell
conda develop /path/to/project_root
```

However, `conda develop` is deprecated in recent versions; you can instead add the project root to `PYTHONPATH` manually:
```shell
export PYTHONPATH="${PYTHONPATH}:/path/to/project_root"
```

### Code Formatting
To format the codebase, run
```shell
python -m scripts.run_code_style format
```

To check whether the codebase is well-formatted, run
```shell
python -m scripts.run_code_style check
```

## License
Copyright © 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\
Copyright © 2024, Devrim Cavusoglu & Ahmet Burak Yıldırım

This work contains an implementation of the methodology and study presented in the *One-step Diffusion with Distribution Matching Distillation* paper. As the building block of the codebase, [NVLabs/edm](https://github.com/NVlabs/edm) is used, modified, and adapted where necessary. As the ShareAlike clause of the underlying framework's (EDM) license dictates, this derived work and all of its sources are licensed under the same [Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en) license.