{"id":18317329,"url":"https://github.com/compvis/zigma","last_synced_at":"2025-04-12T16:41:40.056Z","repository":{"id":228941260,"uuid":"775300536","full_name":"CompVis/zigma","owner":"CompVis","description":"A PyTorch implementation of the paper  \"ZigMa: A DiT-Style Mamba-based Diffusion Model\" (ECCV 2024)","archived":false,"fork":false,"pushed_at":"2025-03-17T15:00:52.000Z","size":32307,"stargazers_count":303,"open_issues_count":6,"forks_count":21,"subscribers_count":11,"default_branch":"main","last_synced_at":"2025-04-12T01:58:12.557Z","etag":null,"topics":["diffusion-models","flow-matching","mamba","state-space-model","stochastic-interpolant","zigma"],"latest_commit_sha":null,"homepage":"https://taohu.me/zigma","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/CompVis.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE.txt","code_of_conduct":null,"threat_model":null,"audit":null,"citation":"CITATION.cff","codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2024-03-21T06:01:01.000Z","updated_at":"2025-04-09T08:26:13.000Z","dependencies_parsed_at":"2024-03-21T09:24:36.568Z","dependency_job_id":"e5a4c32b-b875-4cbd-97f6-9c4392ff8d51","html_url":"https://github.com/CompVis/zigma","commit_stats":null,"previous_names":["compvis/zigma"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/CompVis%2Fzigma","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/CompVis%2Fzigma/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/CompVis%2Fzigma/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/CompVis%2Fzigma/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/CompVis","download_url":"https://codeload.github.com/CompVis/zigma/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248598979,"owners_count":21131193,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["diffusion-models","flow-matching","mamba","state-space-model","stochastic-interpolant","zigma"],"created_at":"2024-11-05T18:05:46.418Z","updated_at":"2025-04-12T16:41:40.025Z","avatar_url":"https://github.com/CompVis.png","language":"Python","readme":"#  ZigMa: A DiT-style Zigzag Mamba Diffusion Model  (ECCV 2024)\n\n**ECCV 2024**\n\n**[Oral Talk in ICML 2024 Workshop on Long Context Foundation Models (LCFM)](https://icml.cc/virtual/2024/39058)**\n\n\n\nThis repository represents the official implementation of the paper titled \"ZigMa: A DiT-style Zigzag Mamba Diffusion Model (ECCV 
2024)\".\n\n[![Website](doc/badges/badge-website.svg)](https://taohu.me/zigma)\n[![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/abs/2403.13802)\n[![Hugging Face Model](https://img.shields.io/badge/🤗%20Hugging%20Face-Model-green)](https://huggingface.co/taohu/zigma)\n[![GitHub](https://img.shields.io/github/stars/CompVis/zigma?style=social)](https://github.com/CompVis/zigma)\n[![GitHub closed issues](https://img.shields.io/github/issues-closed/CompVis/zigma?color=success\u0026label=Issues)](https://github.com/CompVis/zigma/issues?q=is%3Aissue+is%3Aclosed) \n[![Twitter](https://img.shields.io/badge/Twitter-🔥%2020k%2B120k%20views-b31b1b.svg?style=social\u0026logo=twitter)](https://twitter.com/_akhaliq/status/1770668624392421512)\n[![License](https://img.shields.io/badge/License-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0)\n![visitors](https://visitor-badge.laobi.icu/badge?page_id=CompVis/zigma)\n\n[Vincent Tao Hu](http://taohu.me),\n[Stefan Andreas Baumann](https://scholar.google.de/citations?user=egzbdnoAAAAJ\u0026hl=en),\n[Ming Gui](https://www.linkedin.com/in/ming-gui-87b76a16b/?originalSubdomain=de),\n[Olga Grebenkova](https://www.linkedin.com/in/grebenkovao/),\n[Pingchuan Ma](https://www.linkedin.com/in/pingchuan-ma-492543156/),\n[Johannes Schusterbauer](https://www.linkedin.com/in/js-fischer/ ),\n[Björn Ommer](https://ommer-lab.com/people/ommer/ )\n\nWe present ZigMa, a scanning scheme that follows a zigzag pattern, considering both spatial continuity and parameter efficiency. We further adapt this scheme to video, separating the reasoning between spatial and temporal dimensions, thus achieving efficient parameter utilization. Our design allows for greater incorporation of inductive bias for non-1D data and improves parameter efficiency in diffusion models.\n\n\n## 🎓 Citation\n\nPlease cite our paper:\n\n```bibtex\n@InProceedings{hu2024zigma,\n      title={ZigMa: A DiT-style Zigzag Mamba Diffusion Model},\n      author={Vincent Tao Hu and Stefan Andreas Baumann and Ming Gui and Olga Grebenkova and Pingchuan Ma and Johannes Schusterbauer and Björn Ommer},\n      booktitle = {ECCV},\n      year={2024}\n}\n```\n\n## :white_check_mark: Updates\n* **` May. 24th, 2024`**:  🚀🚀🚀 New checkpoints for FacesHQ1024, landscape1024, Churches256 datasets.\n* **` April. 6th, 2024`**: Support for FP16 training, and checkpoint function, and torch.compile to achieve better memory utilization and speed boosting.\n* **` April. 
## 🎓 Citation

Please cite our paper:

```bibtex
@InProceedings{hu2024zigma,
      title={ZigMa: A DiT-style Zigzag Mamba Diffusion Model},
      author={Vincent Tao Hu and Stefan Andreas Baumann and Ming Gui and Olga Grebenkova and Pingchuan Ma and Johannes Schusterbauer and Björn Ommer},
      booktitle = {ECCV},
      year={2024}
}
```

## :white_check_mark: Updates
* **May 24th, 2024**: 🚀🚀🚀 New checkpoints for the FacesHQ1024, Landscape1024, and Churches256 datasets.
* **April 6th, 2024**: Support for FP16 training, gradient checkpointing, and `torch.compile`, for better memory utilization and speed.
* **April 2nd, 2024**: Main code released.

![landscape](doc/landscape_1.png)
![faceshq](doc/faceshq_0.png)
![teaser](doc/teaser_3col.png)

## Quick Demo

```python
import torch

from model_zigma import ZigMa

img_dim = 32
in_channels = 3

model = ZigMa(
    in_channels=in_channels,
    embed_dim=640,
    depth=18,
    img_dim=img_dim,
    patch_size=1,
    has_text=True,
    d_context=768,
    n_context_token=77,
    device="cuda",
    scan_type="zigzagN8",
    use_pe=2,
)

x = torch.rand(10, in_channels, img_dim, img_dim).to("cuda")
t = torch.rand(10).to("cuda")
_context = torch.rand(10, 77, 768).to("cuda")
o = model(x, t, y=_context)
print(o.shape)
```

### Improved Training Performance
Compared to the original implementation, we add several training-speed and memory-saving features, including gradient checkpointing:

| torch.compile | gradient checkpointing | training speed | memory |
| :-----------: | :--------------------: | :------------: | :----: |
|       ❌       |           ❌            | 1.05 iters/sec |  18G   |
|       ❌       |           ✔            | 0.93 iters/sec |   9G   |
|       ✔       |           ❌            | 1.8 iters/sec  |  18G   |

`torch.compile` is applied to the indexing operations [here](https://github.com/CompVis/zigma/blob/1e78944ebce400d34a12efd4baba1daad0fae9f3/dis_mamba/mamba_ssm/modules/mamba_simple.py#L55) and [here](https://github.com/CompVis/zigma/blob/1e78944ebce400d34a12efd4baba1daad0fae9f3/dis_mamba/mamba_ssm/modules/mamba_simple.py#L60); a gradient-checkpointing sketch follows below.
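For readers unfamiliar with the gradient-checkpointing row above, the sketch below shows the general PyTorch mechanism it relies on: recomputing a block's activations during the backward pass instead of storing them. `Block` and `Net` are hypothetical stand-ins, not the repository's modules:

```python
# Hypothetical sketch of gradient checkpointing; the real integration
# lives in the repository's training code.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):  # stand-in for a ZigMa layer
    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)

class Net(nn.Module):
    def __init__(self, dim: int = 64, depth: int = 4, use_checkpoint: bool = True):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Activations inside `blk` are recomputed in backward,
                # trading extra forward compute for lower memory.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        return x

net = Net().train()
out = net(torch.randn(8, 64, requires_grad=True))
out.sum().backward()
```

This matches the trade-off in the table: roughly half the memory for a modest drop in iterations per second.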
## 🚀 Training

#### CelebaMM256

Sweep-2, 1 GPU:
```bash
accelerate launch --num_processes 1 --num_machines 1 --mixed_precision fp16 train_acc.py model=sweep2_b1 use_latent=1 data=celebamm256_uncond ckpt_every=10_000 data.sample_fid_n=5_000 data.sample_fid_bs=4 data.sample_fid_every=10_000 data.batch_size=8 note=_
```

Zigzag-8, 1 GPU:
```bash
CUDA_VISIBLE_DEVICES=4 accelerate launch --num_processes 1 --num_machines 1 --mixed_precision fp16 --main_process_ip 127.0.0.1 --main_process_port 8868 train_acc.py model=zigzag8_b1 use_latent=1 data=celebamm256_uncond ckpt_every=10_000 data.sample_fid_n=5_000 data.sample_fid_bs=4 data.sample_fid_every=10_000 data.batch_size=4 note=_
```

#### UCF101

Baseline, multi-GPU:
```bash
CUDA_VISIBLE_DEVICES="0,1,2,3" accelerate launch --num_processes 4 --num_machines 1 --multi_gpu --mixed_precision fp16 --main_process_ip 127.0.0.1 --main_process_port 8868 train_acc.py model=3d_sweep2_b2 use_latent=1 data=ucf101 ckpt_every=10_000 data.sample_fid_n=20_0 data.sample_fid_bs=4 data.sample_fid_every=10_000 data.batch_size=4 note=_
```

Factorized 3D zigzag (sst), multi-GPU:
```bash
CUDA_VISIBLE_DEVICES="0,1,2,3" accelerate launch --num_processes 4 --num_machines 1 --multi_gpu --mixed_precision fp16 --main_process_ip 127.0.0.1 --main_process_port 8868 train_acc.py model=3d_zigzag8sst_b2 use_latent=1 data=ucf101 ckpt_every=10_000 data.sample_fid_n=20_0 data.sample_fid_bs=4 data.sample_fid_every=10_000 data.batch_size=4 note=_
```

## 🚀 Sampling

#### FacesHQ 1024

You can download the checkpoints directly from the [Hugging Face model repo](https://huggingface.co/taohu/zigma), or from a Python script:

```python
from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="taohu/zigma",
    filename="faceshq1024_0090000.pt",
    local_dir="./checkpoints",
)
```

| Dataset | Checkpoint | Model | Data |
|---|---|---|---|
| FacesHQ1024 | faceshq1024_0090000.pt | model=s1024_zigzag8_b2_old | data=facehq_1024 |
| Landscape1024 | landscape1024_0210000.pt | model=s1024_zigzag8_b2_old | data=landscapehq_1024 |
| Churches256 | churches256_0280000.pt | model=zigzag8_b1_pe2 | data=churches256 |
| Coco256 | zigzagN8_b1_pe2_coco14_bs48_0400000.pt | model=zigzag8_b1_pe2 | data=coco14 (31.0) |

1-GPU sampling:
```bash
CUDA_VISIBLE_DEVICES="2" accelerate launch --num_processes 1 --num_machines 1 sample_acc.py model=s1024_zigzag8_b2_old use_latent=1 data=facehq_1024 ckpt_every=10_000 data.sample_fid_n=5_000 data.sample_fid_bs=4 data.sample_fid_every=10_000 data.batch_size=8 sample_mode=ODE likelihood=0 num_fid_samples=5_000 sample_debug=0 ckpt=checkpoints/faceshq1024_0060000.pt
```
The sampled images are saved both to wandb (disable with `use_wandb=False`) and to the **samples/** directory.
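For reference, `sample_mode=ODE` corresponds to integrating the learned flow with an ODE solver. Below is a minimal, hypothetical sketch of such a sampler built on `torchdiffeq` (installed in the environment step below); the repository's actual sampler lives in `sample_acc.py` and may differ in time parameterization and solver settings:

```python
# Hypothetical sketch of ODE sampling for a flow-matching model;
# assumes time runs from t=0 (noise) to t=1 (data), which may differ
# from the repository's convention.
import torch
from torchdiffeq import odeint

def sample_ode(velocity_model, shape, device="cuda", steps=50):
    """Integrate dx/dt = v(x, t) from noise at t=0 to a sample at t=1."""
    x0 = torch.randn(shape, device=device)

    def f(t, x):
        # Broadcast the scalar solver time to a per-sample time vector.
        return velocity_model(x, t.expand(x.shape[0]))

    ts = torch.linspace(0.0, 1.0, steps, device=device)
    xs = odeint(f, x0, ts, method="euler")
    return xs[-1]  # state at t=1

# Toy usage with a dummy velocity field (CPU, for illustration):
v = lambda x, t: -x
print(sample_ode(v, (2, 3, 8, 8), device="cpu", steps=10).shape)
```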
## 🛠️ Environment Preparation

cuda==11.8, python==3.11, torch==2.2.0, gcc==11.3 (for the SSM environment)

Python 3.11 is used because `torch.compile` does not yet support newer Python versions; see https://github.com/pytorch/pytorch/issues/120233#issuecomment-2041472137

```bash
conda create -n zigma python=3.11
conda activate zigma
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit
conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia
pip install torchdiffeq matplotlib h5py timm diffusers accelerate loguru blobfile ml_collections wandb
pip install hydra-core opencv-python torch-fidelity webdataset einops pytorch_lightning
pip install torchmetrics --upgrade
pip install opencv-python causal-conv1d
cd dis_causal_conv1d && pip install -e . && cd ..
cd dis_mamba && pip install -e . && cd ..
pip install moviepy imageio  # wandb.Video() needs these
pip install scikit-learn --upgrade
pip install transformers==4.36.2
pip install numpy-hilbert-curve  # (optional) for generating the Hilbert path
pip install av  # (optional) for UCF101 frame extraction
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers  # for FDD metrics
```

Installing Mamba can take considerable effort. If you encounter problems, the [issues in the Mamba repository](https://github.com/state-spaces/mamba/issues) may be very helpful.

Create a file under the directory ./config/wandb/default.yaml:

```yaml
key: YOUR_WANDB_KEY
entity: YOUR_ENTITY
project: YOUR_PROJECT_NAME
```

## Q&A

- If you hit issues installing the SSM packages, you may find a solution in the Mamba issue tracker: [https://github.com/state-spaces/mamba/issues](https://github.com/state-spaces/mamba/issues)

## 📷 Dataset Preparation

Due to privacy issues we cannot share the datasets here. We use the MM-CelebA-HQ dataset from [https://github.com/IIGROUP/MM-CelebA-HQ-Dataset](https://github.com/IIGROUP/MM-CelebA-HQ-Dataset), organized in the [webdataset](https://webdataset.github.io/) format to enable scalable multi-GPU training.

Webdataset format:
- image: image.jpg  # values in [-1, 1], shape [3, 256, 256]
- latent: img_feature256.npy  # latent feature for latent-space generation, shape [4, 32, 32]

The datasets we use include:
- MM-CelebA-HQ for 256 and 512 resolution training
- FacesHQ1024 for 1024 resolution
- UCF101 for 16x256x256 resolution

## Trend

[![Star History Chart](https://api.star-history.com/svg?repos=CompVis/zigma&type=Date)](https://star-history.com/#CompVis/zigma&Date)

## 🎫 License

This work is licensed under the Apache License, Version 2.0 (as defined in the [LICENSE](LICENSE.txt)).

By downloading and using the code and model you agree to the terms in the [LICENSE](LICENSE.txt).

[![License](https://img.shields.io/badge/License-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0)