Simplified Masked Diffusion Language Model
- Host: GitHub
- URL: https://github.com/kuleshov-group/mdlm
- Owner: kuleshov-group
- License: apache-2.0
- Created: 2024-06-07T15:01:36.000Z (7 months ago)
- Default Branch: master
- Last Pushed: 2024-08-25T22:58:50.000Z (5 months ago)
- Last Synced: 2024-08-26T00:14:02.609Z (5 months ago)
- Topics: diffusion-models, language-model, text
- Language: Python
- Homepage: https://s-sahoo.com/mdlm/
- Size: 163 KB
- Stars: 145
- Watchers: 8
- Forks: 10
- Open Issues: 2
Metadata Files:
- Readme: README.md
- License: LICENSE
- Citation: CITATION.cff
Awesome Lists containing this project
- awesome-adaptive-computation - pytorch code
README
# [Simple and Effective Masked Diffusion Language Models](http://arxiv.org/abs/2406.07524)
By [Subham Sekhar Sahoo](https://s-sahoo.github.io), [Marianne Arriola](https://mariannearriola.github.io), [Yair Schiff](https://yair-schiff.github.io), [Aaron Gokaslan](https://skylion007.github.io), [Edgar Marroquin](https://emarro.github.io),
[Justin T Chiu](https://justinchiu.netlify.app), [Alexander Rush](https://rush-nlp.com), [Volodymyr Kuleshov](https://www.cs.cornell.edu/~kuleshov/)

[![arXiv](https://img.shields.io/badge/arXiv-2406.07524-red.svg)](https://arxiv.org/abs/2406.07524)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/18nC6q7dWq154fI1BXPLwmtnS7Zvbrv6p?usp=sharing/)
[![YouTube](https://img.shields.io/badge/YouTube-%23FF0000.svg?logo=YouTube&logoColor=white)](https://youtu.be/WjAUX23vgfg?si=lI-qiDFqh25qtnQ8)
[![deploy](https://img.shields.io/badge/Blog%20%20-8A2BE2)](https://s-sahoo.com/mdlm/)
[![deploy](https://img.shields.io/badge/Huggingface%20-MDLM%20-blue)](https://huggingface.co/collections/kuleshov-group/mdlm-6671bee1cc71f0dce4f2d00a)

![graphical_abstract_updated_2](https://github.com/s-sahoo/mdlm/assets/16799748/b0cab23a-d966-45fa-a3ad-be972b23a98a)
We introduce *MDLM*, a **M**asked discrete **D**iffusion **L**anguage **M**odel that features a novel (SUBS)titution-based parameterization, which simplifies the absorbing-state diffusion loss to a mixture of classical masked language modeling losses (a rough sketch of this objective follows the list below). In doing so, we achieve SOTA perplexity among diffusion models on LM1B and OpenWebText, while achieving zero-shot perplexity competitive with SOTA AR models on numerous datasets. We provide a demo in this [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/18nC6q7dWq154fI1BXPLwmtnS7Zvbrv6p?usp=sharing/) notebook and a [video tutorial](https://youtu.be/WjAUX23vgfg?si=lI-qiDFqh25qtnQ8).

In this repo, we release:
* **The MDLM framework.**
1. SUBStitution based parameterization
2. Simplified loss calculation for masked diffusion processes
* **Baseline implementations** [[Examples]](#baselines):
1. Autoregressive model that matches the SOTA AR performance on LM1B.
2. Score Entropy Based Discrete Diffusion [SEDD](https://arxiv.org/abs/2310.16834).
3. An efficient implementation of the absorbing state [D3PM](https://arxiv.org/abs/2107.03006) that beats the previous SOTA text diffusion model SEDD on LM1B.
* **Samplers**
1. Ancestral sampling as proposed in D3PM.
2. Analytic sampler as proposed in SEDD.
3. Our proposed efficient sampler that
- makes MDLM **~3-4x** faster than the existing diffusion models. [[Example]](#sample-gen)
- supports semi-autoregressive (SAR) generation. [[Example]](#semi-ar-gen)
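
To make the SUBS parameterization and the simplified loss above concrete, here is a minimal, hedged sketch of the objective (illustrative only; function and variable names are placeholders, not the repo's API): each token is independently absorbed into a `[MASK]` state with probability $1-\alpha_t$, the denoiser predicts the clean tokens, and the diffusion loss collapses to a weighted cross-entropy over the masked positions.

```python
import torch
import torch.nn.functional as F

MASK_ID = 50257  # placeholder id for the absorbing [MASK] token


def subs_diffusion_loss(denoiser, x0, t, alpha, alpha_prime):
    """Illustrative SUBS-style loss: a weighted masked-LM cross-entropy.

    x0:          (B, L) clean token ids
    t:           (B,) diffusion times in (0, 1]
    alpha(t):    probability that a token is still unmasked at time t
    alpha_prime: derivative of alpha w.r.t. t (negative, since alpha decreases)
    """
    a_t = alpha(t).unsqueeze(-1)                                        # (B, 1)
    # Forward process: each token is independently absorbed into [MASK].
    is_masked = torch.rand(x0.shape, device=x0.device) > a_t
    z_t = torch.where(is_masked, torch.full_like(x0, MASK_ID), x0)

    # SUBS parameterization: the denoiser predicts the clean tokens; unmasked
    # positions are carried over exactly, so only masked positions contribute.
    logits = denoiser(z_t)                                              # (B, L, V)
    ce = F.cross_entropy(logits.transpose(1, 2), x0, reduction="none")  # (B, L)

    # Continuous-time weight -alpha'(t) / (1 - alpha(t)); e.g. for a
    # log-linear schedule alpha_t = 1 - t this is simply 1 / t.
    weight = (-alpha_prime(t) / (1.0 - alpha(t))).unsqueeze(-1)         # (B, 1)
    return (weight * is_masked.float() * ce).sum(dim=-1).mean()
```

In MDLM the denoiser is additionally constrained to never predict the mask token and to copy unmasked tokens exactly, which is what makes this reduction to masked language modeling losses exact; see the paper for details.
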
## Code Organization
1. ```main.py```: Routines for training and evaluation
2. ```noise_schedule.py```: Noise schedules
3. ```diffusion.py```: Forward/reverse diffusion
4. ```dataloader.py```: Dataloaders
5. ```utils.py```: LR scheduler, logging, `fsspec` handling
6. ```models/```: Denoising network architectures. Supports [DiT](https://arxiv.org/abs/2212.09748), AR transformer, and [Mamba](https://arxiv.org/abs/2312.00752)
7. ```configs/```: Config files for datasets/denoising networks/noise schedules/LR schedules
8. ```scripts/```: Shell scripts for training/evaluation

## Getting started in this repository
To get started, create a conda environment containing the required dependencies.
```bash
conda env create -f requirements.yaml
conda activate mdlm
```

Create the following directories to store saved models and Slurm logs:
```bash
mkdir outputs
mkdir watch_folder
```
and run the training as a batch job:
```bash
sbatch scripts/train_owt_mdlm.sh
```

### Checkpoints
We have uploaded an MDLM model trained on OpenWebText for 1M training steps to the Huggingface hub 🤗:
[kuleshov-group/mdlm-owt](https://huggingface.co/kuleshov-group/mdlm-owt)
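
If you only need the pre-trained backbone outside of this repo's `main.py` pipeline, it should be loadable directly with `transformers`; the snippet below is a hedged sketch that assumes the Hub checkpoint's custom modeling code and the GPT-2 tokenizer used for OpenWebText.

```python
# Sketch only: assumes the checkpoint exposes a masked-LM interface via its
# custom modeling code on the Hub (hence trust_remote_code=True) and that
# the GPT-2 tokenizer matches the OpenWebText preprocessing.
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForMaskedLM.from_pretrained(
    "kuleshov-group/mdlm-owt", trust_remote_code=True
)
```
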
Furthermore, we have released the checkpoints for the AR and SEDD baselines trained on OpenWebText in this [Google Drive folder](https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing).

## Reproducing Experiments
Below, we describe the steps required for reproducing the experiments in the paper.
Throughout, the main entry point for running experiments is the [`main.py`](./main.py) script.
We also provide sample `slurm` scripts for launching pre-training and downstream fine-tuning experiments in the [`scripts/`](./scripts) directory.

### Generate Samples
The argument `sampling.predictor` specifies the sampler and takes one of the following values:
* `ddpm_cache`: our proposed sampler, **~3-4x** faster than the samplers proposed in D3PM and SEDD (see the sketch after this list).
* `ddpm`: Ancestral sampling proposed in D3PM.
* `analytic`: Analytic sampler proposed in SEDD.
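
The speed-up of `ddpm_cache` comes from a property of the SUBS parameterization: the denoiser is not conditioned on time, and its output only changes when a token actually gets unmasked, so the logits can be cached across reverse steps in which the sequence is unchanged. A rough sketch of the idea (illustrative only; the actual sampler lives in `diffusion.py`):

```python
import torch


def cached_ancestral_sample(denoiser, z, mask_id, alphas):
    """Sketch of a ddpm_cache-style reverse pass (names/API are illustrative).

    z:      (B, L) current sequence, initially all mask_id
    alphas: increasing alpha values along the reverse trajectory
            (near 0 at the first step, near 1 at the last)
    """
    cached_logits = None
    for i in range(len(alphas) - 1):
        a_t, a_s = alphas[i], alphas[i + 1]          # current / less-noisy level
        if cached_logits is None:                    # recompute only if z changed
            cached_logits = denoiser(z)              # (B, L, V)
        probs = torch.softmax(cached_logits, dim=-1)

        # A still-masked token is unmasked with probability (a_s - a_t) / (1 - a_t).
        still_masked = z == mask_id
        unmask = still_masked & (
            torch.rand(z.shape, device=z.device) < (a_s - a_t) / (1 - a_t)
        )
        if unmask.any():
            sampled = torch.distributions.Categorical(probs=probs).sample()
            z = torch.where(unmask, sampled, z)
            cached_logits = None                     # sequence changed; invalidate cache
    return z
```

When no token is unmasked at a given step, the network is not called at all, which is where the wall-clock savings in the table below come from.
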
In the following table we report wall-clock time to generate 64 samples on a single A5000 GPU with `batch_size=1`. $T$ denotes the time discretization of the reverse process.

|                         | $T=5k$ ($\downarrow$) | $T=10k$ ($\downarrow$) |
|-------------------------|---------------------|----------------------|
| **SEDD** | 127.1 | 229.3 |
| **MDLM** + `ddpm` | 113.8 | 206.6 |
| **MDLM** + `ddpm_cache` | **40.1**            | **60.4**             |

To generate samples from a pre-trained model, use one of the following commands:
#### Huggingface model
```bash
python main.py \
mode=sample_eval \
eval.checkpoint_path=kuleshov-group/mdlm-owt \
data=openwebtext-split \
model.length=1024 \
sampling.predictor=ddpm_cache \
sampling.steps=1000 \
loader.eval_batch_size=1 \
sampling.num_sample_batches=10 \
backbone=hf_dit
```
#### Local checkpoint
```bash
python main.py \
mode=sample_eval \
eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
data=openwebtext-split \
model.length=1024 \
sampling.predictor=ddpm_cache \
sampling.steps=10000 \
loader.eval_batch_size=1 \
sampling.num_sample_batches=1 \
backbone=dit
```

### Semi-AR sample generation
MDLM can also generate samples of arbitrary length in a semi-autoregressive (SAR) manner.
We generate 200 sequences of length 2048 tokens on a single `3090` GPU and evaluate generative perplexity under a pre-trained GPT-2 model. In the table below we find that, in addition to achieving better generative perplexity, MDLM enables **25-30x** faster SAR decoding relative to [SSD-LM](https://arxiv.org/abs/2210.17432) (a rough sketch of the decoding loop follows the table).

|                        | Gen. PPL ($\downarrow$) | Sec/Seq ($\downarrow$) |
|---------------------|-------------------------|------------------------|
| **SSD-LM** | 35.43 | 2473.9 |
| **MDLM** + `ddpm_cache` | **27.18**              | **89.3**               |

*Gen. PPL: generative perplexity; Sec/Seq: seconds per sequence.*
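
As a rough mental model of how `sampling.stride_length` and `sampling.num_strides` drive SAR decoding (a hedged sketch, not the repo's implementation), the sequence is grown block by block: each stride appends a block of masked tokens and denoises it conditioned on the previously generated prefix.

```python
import torch


def semi_ar_generate(denoise_block, batch_size, stride_length, num_strides, mask_id):
    """Illustrative semi-autoregressive loop; `denoise_block` is a placeholder
    that runs the reverse diffusion over the masked block given the prefix."""
    context = torch.empty(batch_size, 0, dtype=torch.long)
    for _ in range(num_strides):
        block = torch.full((batch_size, stride_length), mask_id, dtype=torch.long)
        block = denoise_block(context, block)        # fill the block given the prefix
        context = torch.cat([context, block], dim=1)
    return context
```

The command below generates SAR samples with the pre-trained OWT checkpoint:
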
```bash
python main.py \
mode=sample_eval \
eval.checkpoint_path=kuleshov-group/mdlm-owt \
data=openwebtext-split \
parameterization=subs \
model.length=1024 \
sampling.predictor=ddpm_cache \
sampling.steps=1000 \
loader.eval_batch_size=1 \
sampling.num_sample_batches=2 \
sampling.semi_ar=True \
sampling.stride_length=512 \
sampling.num_strides=2 \
backbone=hf_dit
```
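
The generative perplexity numbers above are obtained by scoring the generated text with a pre-trained GPT-2 model. A minimal sketch of such scoring with standard Hugging Face APIs (the exact GPT-2 variant used for the reported numbers is an assumption here):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# "gpt2-large" is an assumed choice of scorer; any GPT-2 variant works the same way.
tok = GPT2TokenizerFast.from_pretrained("gpt2-large")
scorer = GPT2LMHeadModel.from_pretrained("gpt2-large").eval()


@torch.no_grad()
def generative_perplexity(text: str) -> float:
    ids = tok(text, return_tensors="pt").input_ids
    # Passing labels=input_ids makes the model return the mean next-token NLL.
    nll = scorer(ids, labels=ids).loss
    return torch.exp(nll).item()
```
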
### Train

To train MDLM from scratch on OpenWebText use the following command:
```bash
python main.py \
model=small \
data=openwebtext-split \
wandb.name=mdlm-owt \
parameterization=subs \
model.length=1024 \
eval.compute_generative_perplexity=True \
sampling.steps=1000
```
The arguments `loader.batch_size` and `loader.eval_batch_size` allow you to control the global batch size and the batch size per GPU. If `loader.batch_size * num_gpus` is less than the global batch size, PyTorch Lightning will resort to gradient accumulation (for example, with a global batch size of 512, 8 GPUs, and `loader.batch_size=32`, gradients are accumulated over 2 batches). You can also launch a training job on Slurm using the command `sbatch scripts/train_owt_mdlm.sh`. The Slurm scripts for training the autoregressive and SEDD baselines are [`scripts/train_lm1b_ar.sh`](scripts/train_lm1b_ar.sh) and [`scripts/train_owt_sedd.sh`](scripts/train_owt_sedd.sh), respectively.

### Eval
To compute test perplexity, use `mode=ppl_eval`. Example scripts are provided in `scripts/`. An example command for perplexity evaluation on OpenWebText is:
```bash
python main.py \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
data=openwebtext-split \
model=small \
parameterization=subs \
backbone=dit \
model.length=1024 \
eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
+wandb.offline=true
```

### Baseline evaluation
We release checkpoints for the SEDD and AR baselines trained on OpenWebText in this [Google Drive folder](https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing). Download the checkpoints `ar.ckpt` and `sedd.ckpt`, and use the following commands to compute test perplexity:
#### AR
```bash
python main.py \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
data=openwebtext-split \
model=small-ar \
parameterization=ar \
backbone=ar \
model.length=1024 \
eval.checkpoint_path=/path/to/checkpoint/ar.ckpt \
+wandb.offline=true
```
#### SEDD
```bash
python main.py \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
data=openwebtext-split \
model=small \
parameterization=sedd \
backbone=dit \
model.length=1024 \
eval.checkpoint_path=/path/to/checkpoint/sedd.ckpt \
time_conditioning=True \
sampling.predictor=analytic \
+wandb.offline=true
```

### Acknowledgements
This repository was built off of [SEDD](https://github.com/louaaron/Score-Entropy-Discrete-Diffusion).

## Citation
```bibtex
@misc{sahoo2024simple,
title={Simple and Effective Masked Diffusion Language Models},
author={Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
year={2024},
eprint={2406.07524},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```