{"id":15036431,"url":"https://github.com/apple/ml-mdm","last_synced_at":"2025-10-08T17:15:29.776Z","repository":{"id":251439778,"uuid":"837411728","full_name":"apple/ml-mdm","owner":"apple","description":"Train high-quality text-to-image diffusion models in a data \u0026 compute efficient manner","archived":false,"fork":false,"pushed_at":"2025-03-27T22:19:14.000Z","size":2997,"stargazers_count":491,"open_issues_count":29,"forks_count":36,"subscribers_count":14,"default_branch":"main","last_synced_at":"2025-05-03T20:02:43.824Z","etag":null,"topics":["deep-learning","diffusion-models","large-scale-vision-models","machine-learning","pytorch"],"latest_commit_sha":null,"homepage":"https://machinelearning.apple.com/research/matryoshka-diffusion-models","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/apple.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":"CONTRIBUTING.md","funding":null,"license":"LICENSE","code_of_conduct":"CODE_OF_CONDUCT.md","threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":"security-pre-commit.sh","support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2024-08-02T23:20:57.000Z","updated_at":"2025-04-19T12:01:39.000Z","dependencies_parsed_at":"2024-12-09T21:29:15.292Z","dependency_job_id":"dd90aacd-078c-411b-b09b-99b5b669d733","html_url":"https://github.com/apple/ml-mdm","commit_stats":null,"previous_names":["apple/ml-mdm"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/apple%2Fml-mdm","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/apple%2Fml-mdm/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/apple%2Fml-mdm/releas
es","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/apple%2Fml-mdm/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/apple","download_url":"https://codeload.github.com/apple/ml-mdm/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":254276447,"owners_count":22043867,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","diffusion-models","large-scale-vision-models","machine-learning","pytorch"],"created_at":"2024-09-24T20:31:08.807Z","updated_at":"2025-10-08T17:15:24.743Z","avatar_url":"https://github.com/apple.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# ml_mdm - Matryoshka Diffusion Models\n\n`ml_mdm` is a python package for efficiently training high quality text-to-image diffusion models — brought to the public by [Luke Carlson](https://github.com/luke-carlson), [Jiatao Gu](https://github.com/MultiPath), [Shuangfei Zhai](https://github.com/Shuangfei), and [Navdeep Jaitly](https://github.com/ndjaitly).\n\n\n---\n\n\n\u003cdiv align=\"center\"\u003e\n\n\nThis software project accompanies the research paper, [*Matryoshka Diffusion Models*](https://arxiv.org/abs/2310.15111).\n\n\n*Jiatao Gu, Shuangfei Zhai, Yizhe Zhang, Josh Susskind, Navdeep Jaitly*\n\n[[`Paper`](https://arxiv.org/abs/2310.15111)]  [[`BibTex`](#citation)]\n\n\n\n\n![mdm text to image outputs](https://mlr.cdn-apple.com/media/MDM_text_to_image_390ce54fde.png)\n\n\u003c/div\u003e\n\n\n## Table of Contents\n\n| Section | Description |\n| - | 
- |\n| [Introduction](#introduction) | A brief overview of Matryoshka Diffusion Models |\n| [Installation](#installation) | Start training models and generating samples with `ml_mdm` |\n| [Pretrained Models](#pretrained-models) | Links to download our pretrained models (64, 256, 1024) |\n| [Web Demo](#web-demo) | Generate images with our web UI |\n| [Codebase Structure](#codebase) | An overview of the Python module |\n| [Concepts](#concepts) | Core concepts and design principles |\n| [Tutorial](#tutorials) | Step-by-step training of an MDM model on CC12M |\n\n\n\n## Introduction\n\nDiffusion models are the de facto approach for generating high-quality images and videos, but learning high-dimensional models remains a formidable task due to computational and optimization challenges.\n\n`ml_mdm` is an end-to-end framework for high-resolution image and video synthesis, named after our technique: *Matryoshka Diffusion Models*.\n\nRemarkably, we can train a single pixel-space model at resolutions of up to 1024x1024 pixels, demonstrating strong zero-shot generalization using the CC12M dataset, which contains only 12 million images.\n\n![mdm multi scale pipeline](https://mlr.cdn-apple.com/media/MDM_architecture_a813a1ab24.png)\n\n\n\n## Installation\nThe default installation dependencies, as defined in `pyproject.toml`, are selected so that you can install this library even on a CPU-only machine.\n\n\u003e Users have run this codebase with Python 3.9 and 3.10, and with CUDA 12 and CUDA 11.8.\n\n```\n\u003e pip install -e .\n```\n\nDevelopers should also set up `pre-commit` with `pre-commit install`.\n\n### Running Test Cases\n\n```\n\u003e pytest          # run the test cases that work on a CPU-only machine\n\u003e pytest -m ''    # run all test cases, including ones that require a GPU\n\u003e pytest -m gpu   # run only the GPU test cases\n```\n\n\n# Pretrained Models\nWe've uploaded model checkpoints to:\n- 
https://docs-assets.developer.apple.com/ml-research/models/mdm/flickr64/vis_model.pth\n- https://docs-assets.developer.apple.com/ml-research/models/mdm/flickr256/vis_model.pth\n- https://docs-assets.developer.apple.com/ml-research/models/mdm/flickr1024/vis_model.pth\n\n\u003e Note: We are releasing models that were trained on 50M text-image pairs collected from Flickr. In this repo, we provide scripts for downloading [CC12M](https://github.com/google-research-datasets/conceptual-12m) and configs for training equivalent models on CC12M data.\n\nFeel free to download the models or skip further down to train your own. Once a pretrained model is downloaded locally, you can use it in our web demo, pass it as an argument to training, sampling, and more.\n\n```console\nexport ASSET_PATH=https://docs-assets.developer.apple.com/ml-research/models/mdm\n\ncurl $ASSET_PATH/flickr64/vis_model.pth --output vis_model_64x64.pth\ncurl $ASSET_PATH/flickr256/vis_model.pth --output vis_model_256x256.pth\ncurl $ASSET_PATH/flickr1024/vis_model.pth --output vis_model_1024x1024.pth\n```\n\n\n### Web Demo\nYou can run your own instance of the web demo (after downloading the checkpoints) with this command:\n\n```console\ntorchrun --standalone --nproc_per_node=1  ml_mdm/clis/generate_sample.py --port $YOUR_PORT\n```\n\n![image](docs/web_demo.png)\n\n## Codebase\n\n\n### 1. /configs\n\n| module | description |\n| - | - |\n| `configs.dataset_creation` | Configuration file for dataset splitting into train-eval-val pipeline |\n| `configs.datasets` | Datasets for training and evaluation phases of the model |\n| `configs.models` | Configuration files for different resolution models |\n\n\n### 2. 
/data\n\n| module | description |\n| - | - |\n| `data` | \u003cul\u003e\u003cli\u003e\u003cb\u003ebert.vocab:\u003c/b\u003e BERT-trained dictionary containing tokens and their associated vector representations\u003c/li\u003e\u003cli\u003e\u003cb\u003ec4_wpm.vocab:\u003c/b\u003e C4-trained dictionary containing tokens and their associated vector representations\u003c/li\u003e\u003cli\u003e\u003cb\u003ecifar10.vocab:\u003c/b\u003e CIFAR10-trained dictionary containing tokens and their associated vector representations\u003c/li\u003e\u003cli\u003e\u003cb\u003eimagenet.vocab:\u003c/b\u003e Prompts associated with Imagenet dataset\u003c/li\u003e\u003cli\u003e\u003cb\u003eprompts_cc12m-64x64.tsv:\u003c/b\u003e Prompts associated with cc12m dataset for the 64x64 res. model\u003c/li\u003e\u003cli\u003e\u003cb\u003eprompts_cc12m-256x256.tsv:\u003c/b\u003e Prompts associated with cc12m dataset for the 256x256 res. model\u003c/li\u003e\u003cli\u003e\u003cb\u003eprompts_cifar10-32x32.tsv:\u003c/b\u003e Prompts associated with cifar10 dataset for the 32x32 res. model \u003c/li\u003e\u003cli\u003e\u003cb\u003eprompts_cifar10-64x64.tsv:\u003c/b\u003e Prompts associated with cifar10 dataset for the 64x64 res. model \u003c/li\u003e\u003cli\u003e\u003cb\u003eprompts_demo.tsv:\u003c/b\u003e Extra demo prompts \u003c/li\u003e\u003cli\u003e\u003cb\u003eprompts_imagenet-64px.tsv:\u003c/b\u003e Prompts associated with imagenet dataset for the 64x64 res. model \u003c/li\u003e\u003cli\u003e\u003cb\u003eprompts_WebImage-ALIGN-64px.tsv:\u003c/b\u003e Prompts associated with WebImage-ALIGN dataset for the 64x64 res. model \u003c/li\u003e\u003cli\u003e\u003cb\u003et5.vocab:\u003c/b\u003e t5-trained dictionary containing tokens and their associated vector representations \u003c/li\u003e\u003cli\u003e\u003cb\u003etokenizer_spm_32000_50m.vocab:\u003c/b\u003e SPM-trained dictionary containing tokens and their associated vector representations \u003c/li\u003e\u003c/ul\u003e |\n\n### 3. 
/docs\n\n| module | description |\n| - | - |\n| `docs` | \u003cul\u003e\u003cli\u003e\u003cb\u003eweb_demo.png:\u003c/b\u003e Screenshot of the web demo of the model\u003c/li\u003e\u003c/ul\u003e |\n\n### 4. /ml_mdm \n\n| module | description |\n| - | - |\n| `ml_mdm.models` | The core model implementations |\n| `ml_mdm.diffusion` | Model pipelines, for example DDPM |\n| `ml_mdm.config` | Connects configuration dataclasses with associated models, pipelines, and clis using [simple parsing](https://github.com/lebrice/SimpleParsing/blob/master/README.md) |\n| `ml_mdm.clis` | All command line tools in the project, the most relevant being `train_parallel.py` |\n| `tests/` | Unit tests and sample training files |\n\n### 5. /tests\n\n| module | description |\n| - | - |\n| `tests.test_files` | Sample files for testing |\n\n# Concepts\n\n\n### ml_mdm.models\nIn the `ml_mdm.models` submodule, we've open sourced our implementations of:\n- U-Nets\n- Nested U-Nets\n\n\n### ml_mdm.config\n`ml_mdm.config` contains the core configuration and cli logic. Many models, clis, and functions in this codebase are configured by passing in a `dataclass` object. 
We use [SimpleParsing](https://github.com/lebrice/SimpleParsing) to dynamically create CLIs and to allow passing in YAML `config` representations with `--config_path`.\n\n\n\u003e In essence, `simple_parsing` will convert all passed CLI arguments and YAML files into clean configuration classes such as `ml_mdm.reader.ReaderConfig` and `ml_mdm.diffusion.DiffusionConfig`.\n\n\n`ml_mdm.config` stores a global mapping of names to classes in `MODEL_REGISTRY`, `MODEL_CONFIG_REGISTRY`, `PIPELINE_REGISTRY`, and `PIPELINE_CONFIG_REGISTRY`.\n\n`MODEL_REGISTRY` and `PIPELINE_REGISTRY` store information as shown in the following example:\n\n\u003e *_CONFIG_REGISTRY[architecture name][\"model\"] = model name\n\n\u003e *_CONFIG_REGISTRY[architecture name][\"config\"] = configuration class\n\n`MODEL_CONFIG_REGISTRY` and `PIPELINE_CONFIG_REGISTRY` store information as shown in the following example:\n\n\u003e *_CONFIG_REGISTRY[architecture name][\"model\"] = model name\n\n\u003e *_CONFIG_REGISTRY[architecture name][\"config\"] = configuration class\n\n\nThe architecture name and model name are passed into `ml_mdm.config` through the function parameter `*names`, 
where `*names` points to \"architecture name\" and \"model name\".\n\n\n\n# Tutorials\n\n## Generate Your Own Images With Pretrained Checkpoints\n\nOnce you've installed `ml_mdm`, download these checkpoints into the repo's directory.\n\n```console\ncurl https://docs-assets.developer.apple.com/ml-research/models/mdm/flickr64/vis_model.pth --output vis_model_64x64.pth\ncurl https://docs-assets.developer.apple.com/ml-research/models/mdm/flickr256/vis_model.pth --output vis_model_256x256.pth\n```\n\nThe web demo will load each model with a corresponding configuration:\n- `vis_model_64x64.pth` will be loaded with the settings from `configs/models/cc12m_64x64.yaml`\n- `vis_model_256x256.pth` will be loaded with the settings from `configs/models/cc12m_256x256.yaml`\n- `vis_model_1024x1024.pth` will be loaded with the settings from `configs/models/cc12m_1024x1024.yaml`\n\nIn the demo, you can change a variety of settings and peek into the internals of the model. Replace `$YOUR_PORT` with the port you'd like to use, then run:\n\n```console\ntorchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port $YOUR_PORT\n```\n\n## Training on Dummy Data\nIf you just want to step through the process of training a model and running a pipeline without downloading a large dataset, we've put together a minimal example for you. It uses the dummy data from `tests/test_files/`.\n\n\u003e Feel free to try changing a variety of `--args`, either directly in the CLI or by editing the config YAML file.\n\n```console\ntorchrun --standalone --nproc_per_node=1 ml_mdm/clis/train_parallel.py \\\n  --file-list=tests/test_files/sample_training_0.tsv \\\n  --multinode=0 \\\n  --output-dir=outputs --config_path configs/models/cc12m_64x64.yaml \\\n  -num_diffusion_steps=10 \\\n  --num-training-steps=10\n```\n\nYou should see an `outputs/vis_model_000100.pth` file. Now let's do something a bit more meaningful:\n\n\n## Let's train an MDM model on CC12M\n\n### 1. 
Data Prep:\n\n**(OPTIONAL) Download the first 1K files of CC12m with this sample argument**\n\n\u003e The script is based on [img2dataset's CC12M script](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md).\n\n```console\ncurl https://storage.googleapis.com/conceptual_12m/cc12m.tsv | head -n 1000 \u003e cc12m_index.tsv\n\n# Add headers to the file\nsed -i '1s/^/url\\tcaption\\n/'  cc12m_index.tsv\n```\n\u003e  Note: if you want all of cc12m, remove `| head -n 1000` from the call\n\nThen prepare and split into train/validation\n\n\u003e This script requires `img2dataset`, either run `pip install '.[data_prep]'` or just `pip install img2dataset`\n\n```console\npython3 -m ml_mdm.clis.scrape_cc12m \\\n  --cc12m_index cc12m_index.tsv \\\n  --cc12m_local_dir cc12m_download\n```\nAfter running this command you will see the following files:\n```console\ntraining.0.tsv # train index file\nvalidation.tsv # validation index file\ncc12m_download/\n   00000.parquet  00000.tar  00000.tsv  00000_stats.json  validation.tsv\n   00001.parquet ....\n```\n\n### 2. 
Train\nNow that we have our training file, we can select a model config and pass any additional training arguments:\n\n```console\n# Modify torchrun arguments to fit your GPU setup\ntorchrun --standalone --nproc_per_node=8 ml_mdm/clis/train_parallel.py \\\n  --file-list=training_0.tsv \\\n  --multinode=0 --output-dir=/mnt/data/outputs \\\n  --config_path configs/models/cc12m_64x64.yaml \\\n  --num-training-steps=100   --warmup-steps 10\n```\n\u003e Note: `configs/models/cc12m_64x64.yaml` contains many more arguments; see the file for details.\n\n\u003e If you've downloaded a pretrained model, you can set the `--pretrained-vision-file` argument to point to its location on disk.\n\nOnce training completes, you'll find the model in the folder defined by the `--output-dir` argument:\n```console\n2024-07-22:17:58:46,649 INFO     [model_ema.py:33] Saving EMA model file: /mnt/data/outputs/vis_model_000100.pth\n2024-07-22:17:58:47,448 INFO     [unet.py:794] Saving model file: /mnt/data/outputs/vis_model_noema_000100.pth\n```\n\n\n### 3. 
Sample from the model\nNow that we have a trained model, we can generate samples from the diffusion model:\n```console\ntorchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_batch.py \\\n  --config_path configs/models/cc12m_64x64.yaml \\\n  --min-examples 3 --test-file-list validation.tsv \\\n  --sample-image-size 64 --model-file /mnt/data/outputs/vis_model_000100.pth\n```\n\n\u003e If you want to skip the training step, you can update `--model-file` to point to one of our pretrained models\n\n\n\n### Dataset Storage\n\nFor long term storage, you can optionally upload your data to `s3://{your_bucket}/datasets/{datasetname}/*.[tar,tsv]`.\n\nThen update `configs/datasets/cc12m.yaml` to point to your s3 paths.\n\n```yaml\n# configs/datasets/cc12m.yaml\ntrain:\n  files:\n    - s3://mlx/datasets/cc12m-64x64/images_00.*.tsv\neval:\n  files:\n    - s3://mlx/datasets/cc12m-64x64/validation.tsv\n```\n\n```yaml\n# configs/datasets/reader_config.yaml\nreader_config:\n  append_eos: true\n  bucket: ${your_bucket} # add your s3 bucket\n  endpoint_url: None # boto will automatically infer the endpoint\n```\n\n\nThen you can use our dataset download helper:\n```console\npython -m ml_mdm.clis.download_tar_from_index \\\n  --dataset_config_file configs/datasets/cc12m.yaml \\\n  --subset train --download_tar\n\npython -m ml_mdm.clis.download_tar_from_index \\\n  --dataset_config_file configs/datasets/cc12m.yaml \\\n  --subset eval --download_tar\n```\n\n### S3 Dataset Selection\n\nTake a look at `configs/datasets/cc12m.yaml`.\n\nThe code allows for multiple regular expressions to be provided. 
Keep in mind that the\nregular expressions are not globs -- they are regular expressions from the python re library.\nSo if you wanted to use only 100 of the 1000 tar files in WebImage for training you can\ndo the following:\n\n```yaml\ntrain:\n  files:\n    - s3://mlx/datasets/example-dataset-100M_64px/example-dataset-100M-00[0-1]..-[0-9]*-of-01000.tsv\neval:\n  files:\n    - s3://mlx/datasets/example-dataset-100M_64px/validation.tsv\n```\n\nYou can also mix and match the files. So if you wanted to merge CC12m and imagenet you could\ncreate a new yaml file with the following contents:\n\n```yaml\ntrain:\n  files:\n    - s3://mlx/datasets/imagenet-64px/imagenet-train-000??-of-00100.tsv\n    - s3://mlx/datasets/cc12m-64x64/images_00.*.tsv\neval:\n  files:\n    - s3://mlx/datasets/cc12m-64x64/validation.tsv\n```\n\n### Dataset Structure\nThe S3 Bucket contains a series of files in this format, take a look at `ml_mdm/clis/scrape_cc12m.py` to generate your own.\n```console\n2023-04-01 01:31:30   36147200 images_00000.tar\n2023-05-10 11:34:49    1108424 images_00000.tsv\n2023-04-01 01:31:26   36454400 images_00001.tar\n2023-05-10 11:34:49    1109588 images_00001.tsv\n2023-04-01 01:31:53   36116480 images_00002.tar\n...\n```\n\nMinimal representations of these files can be found at `tests/test_files/`.\n\n\n## Citation\nIf you find our work useful, please consider citing us as:\n```\n@misc{gu2023matryoshkadiffusionmodels,\n      title={Matryoshka Diffusion Models},\n      author={Jiatao Gu and Shuangfei Zhai and Yizhe Zhang and Josh Susskind and Navdeep Jaitly},\n      year={2023},\n      eprint={2310.15111},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV},\n      
url={https://arxiv.org/abs/2310.15111},\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fapple%2Fml-mdm","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fapple%2Fml-mdm","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fapple%2Fml-mdm/lists"}