<div align="center">

# `nano-dlm` 🧬

<a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-green.svg?style=for-the-badge" alt="License"></a>

> A clean & extensible **_JAX_** implementation of **diffusion language models**:
> [nanoGPT](https://github.com/karpathy/nanoGPT) for the diffusion era.

</div>

Implements the **masked/absorbing diffusion** (MDLM) process, in which each training step randomly masks tokens according to a noise schedule. A bidirectional Transformer learns to predict the original tokens; in other words, the model unmasks tokens at each timestep. Uniform diffusion is coming soon!

## 💡 Why JAX

1. JAX's pure functions compose naturally with `q_sample`, the forward noising step in diffusion training.
2. `jax.jit` and `nnx.jit` compile the whole traced function into a single XLA program, which enables aggressive kernel fusion and more predictable performance. `torch.compile`, by contrast, captures graphs with TorchDynamo and falls back to regular Python at graph breaks, so a program may end up compiled as several partial graphs.
3. Explicit PRNG key splitting makes randomness easy to manage and runs easy to reproduce.

We stay within the JAX ecosystem as much as possible, including data loading, which uses the new [`grain`](https://google-grain.readthedocs.io/en/latest/) package.

## 🛠️ Installation

This repository needs the following packages, which we recommend installing in a dedicated `conda` environment.

```bash
conda create -n nano-dlm python=3.12
conda activate nano-dlm

# installs the latest package versions available for your system
pip install jax jaxlib flax optax tyro tiktoken datasets orbax-checkpoint grain

# or install the specific versions listed in requirements.txt
pip install -r requirements.txt
```

Alternatively, you can use `uv`:

```bash
uv init .
uv add jax jaxlib flax optax tyro tiktoken datasets orbax-checkpoint grain
uv sync
source .venv/bin/activate
```

## 📀 Data

We train the diffusion language model on a pre-tokenized subset of `OpenWebText`, conveniently provided by Neel Nanda on Hugging Face.
You can download and load it with:

```python
from datasets import load_dataset, load_from_disk

dataset = load_dataset("NeelNanda/openwebtext-tokenized-9b", split="train")
dataset.save_to_disk("your/save/path")  # optional: save a copy to a specific location

# later, load it back from the saved path
dataset = load_from_disk("your/save/path")
```

Details on how the dataset is used are in `src/data.py` and `src/config.py`. Remember to split the dataset into `train` and `val` splits; in our experiments, we hold out 1M tokens for validation.

## ⚡️ Quick Start: Training

Single-GPU and multi-GPU training use the same script. JAX auto-discovers all visible devices and shards the batch across them, so no launcher or code changes are needed.

```bash
# See available devices
python -c "import jax; print(jax.devices())"

# Single-GPU or multi-GPU: same command
python train.py

# Control every setting via hierarchical CLI args
python train.py \
  --model.init_seed 123 \
  --model.n_layers 6 \
  --model.d_model 512 \
  --model.n_heads 8 \
  --data.shuffle_seed 123 \
  --train.seed 123 \
  --train.lr 1e-3 \
  --train.weight_decay 0.1 \
  --train.max_steps 10000 \
  --train.batch_size 32 \
  --train.grad_acc_steps 8 \
  --schedule.kind cosine \
  --exp.run_name "dlm_run" \
  --exp.use_wandb True \
  --exp.project_name "nano-dlm"

# See every available flag
python train.py --help
```

On multi-GPU setups, the batch is sharded along the data axis of a device mesh built with `jax.make_mesh` (pure data parallelism). On a single GPU, the mesh contains just one device, so sharding adds no meaningful overhead.

## 🔍 Architecture

Timestep conditioning is available as an option but is switched off in the default configuration, following recent implementations of diffusion language models.
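
The README does not say how timestep conditioning is wired in, so the following is only a minimal sketch under stated assumptions: a sinusoidal embedding of `t` is linearly projected and added to every token embedding, and a boolean toggle skips it entirely. `sinusoidal_time_embedding`, `condition_on_time`, `w_time` and `use_time_cond` are hypothetical names, not the repo's API, and adaptive LayerNorm modulation is an equally common alternative. One reason the toggle can default to off for masked diffusion: once `xₜ` is observed, the distribution of the masked tokens depends only on the visible tokens, not on `t` itself.

```python
import jax
import jax.numpy as jnp


def sinusoidal_time_embedding(t: jax.Array, dim: int, max_period: float = 10_000.0) -> jax.Array:
    """Map timesteps of shape (B,) to sin/cos features of shape (B, dim)."""
    assert dim % 2 == 0, "dim must be even"
    half = dim // 2
    # Geometrically spaced frequencies, as in Transformer positional encodings.
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half) / half)
    angles = t.astype(jnp.float32)[:, None] * freqs[None, :]  # (B, half)
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


def condition_on_time(token_emb: jax.Array, t: jax.Array, w_time: jax.Array,
                      use_time_cond: bool) -> jax.Array:
    """Optionally add a projected time embedding to every token position.

    token_emb: (B, L, D) token embeddings
    t:         (B,) timesteps
    w_time:    (D, D) hypothetical learned projection
    """
    if not use_time_cond:  # conditioning switched off: model sees only x_t
        return token_emb
    t_emb = sinusoidal_time_embedding(t, token_emb.shape[-1]) @ w_time  # (B, D)
    return token_emb + t_emb[:, None, :]  # broadcast over the sequence axis


# Usage: two sequences of 8 tokens with 16-dim embeddings at different timesteps.
k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (2, 8, 16))
w = 0.02 * jax.random.normal(k_w, (16, 16))
t = jnp.array([10, 500])
print(condition_on_time(x, t, w, use_time_cond=True).shape)  # (2, 8, 16)
```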

A brief overview of the architecture and diffusion process follows.

**Parameterisation:** the model predicts **x₀ directly** (not the noise).
The loss is a weighted cross-entropy over masked positions only.

1. **Forward process** `q(xₜ | x₀)`: each token is replaced by `[MASK]` independently
   with probability `1 - ᾱₜ`, where `ᾱₜ` follows the chosen schedule.
2. **Training**: given `(xₜ, t)`, the model predicts logits for the original tokens.
   The loss is an MDLM-weighted cross-entropy over masked positions, with a schedule-dependent weight `λₜ`:
   `L = -E[λₜ · Σᵢ 1[xₜᵢ = [M]] · log p_θ(x₀ᵢ | xₜ, t)]`
3. **Sampling**: start from a fully masked sequence (`t = T`) and iteratively unmask it
   with DDIM-style ancestral steps, using the predicted `x̂₀` at each step.

## ⏳ Noise Schedules

| Flag value | `ᾱₜ` (probability a token stays unmasked) | Notes                                                |
| ---------- | ----------------------------------------- | ---------------------------------------------------- |
| `cosine`   | `cos²((t/T + 0.008) / 1.008 · π/2)`       | Smooth, well-tested (Nichol & Dhariwal 2021)         |
| `linear`   | `1 − t/T`                                 | Simplest baseline                                    |
| `sqrt`     | `1 − √(t/T)`                              | Square-root schedule from Diffusion-LM (Li et al. 2022) |

```bash
python train.py --schedule.kind sqrt --schedule.T 1000
```

## 🔄 Checkpointing

Every `--exp.save_every` steps, we use `orbax` to checkpoint the model and optimizer states. Alongside them, the logs up to that step and the full config are saved to `logs.json` and `config.json`, respectively.
To resume from, say, step 100 of the example checkpoint below, set `--exp.resume=True` and pass the folder path `nano-dlm-checkpoints/step_100` to `--exp.resume_path`.

```plaintext
nano-dlm-checkpoints/
└── step_100/
    ├── model_state/
    ├── optimizer_state/
    ├── logs.json
    └── config.json
```

## 🎓 Citation

If you found this work useful, please cite it as follows.

```bibtex
@software{singh2026nanodlm,
  author       = {Singh, Jaisidh},
  title        = {nano-dlm: A Minimal JAX Implementation of Diffusion Language Models},
  year         = {2026},
  publisher    = {GitHub},
  url          = {https://github.com/jaisidhsingh/nano-dlm}
}
```