{"id":18843455,"url":"https://github.com/ai-hypercomputer/maxdiffusion","last_synced_at":"2025-04-13T02:17:11.603Z","repository":{"id":222892269,"uuid":"693349437","full_name":"AI-Hypercomputer/maxdiffusion","owner":"AI-Hypercomputer","description":null,"archived":false,"fork":false,"pushed_at":"2025-04-10T21:53:51.000Z","size":59206,"stargazers_count":199,"open_issues_count":14,"forks_count":24,"subscribers_count":14,"default_branch":"main","last_synced_at":"2025-04-13T02:17:05.630Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/AI-Hypercomputer.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":"docs/contributing.md","funding":null,"license":"LICENSE","code_of_conduct":"docs/code-of-conduct.md","threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2023-09-18T21:16:15.000Z","updated_at":"2025-04-10T06:50:20.000Z","dependencies_parsed_at":"2024-03-01T01:29:49.407Z","dependency_job_id":"4ee793d5-12dd-43fe-bf67-3384df68c1bb","html_url":"https://github.com/AI-Hypercomputer/maxdiffusion","commit_stats":{"total_commits":83,"total_committers":9,"mean_commits":9.222222222222221,"dds":0.6265060240963856,"last_synced_commit":"f4b904248a49cc8b5e114cf996be2e817bc862f8"},"previous_names":["google/maxdiffusion","ai-hypercomputer/maxdiffusion"],"tags_count":1,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/AI-Hypercomputer%2Fmaxdiffusion","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/AI-Hypercomputer%2Fmaxdiffusion/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/
GitHub/repositories/AI-Hypercomputer%2Fmaxdiffusion/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/AI-Hypercomputer%2Fmaxdiffusion/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/AI-Hypercomputer","download_url":"https://codeload.github.com/AI-Hypercomputer/maxdiffusion/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248654104,"owners_count":21140237,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-11-08T02:57:50.950Z","updated_at":"2025-04-13T02:17:11.581Z","avatar_url":"https://github.com/AI-Hypercomputer.png","language":"Python","readme":"\u003c!--\n Copyright 2024 Google LLC\n\n Licensed under the Apache License, Version 2.0 (the \"License\");\n you may not use this file except in compliance with the License.\n You may obtain a copy of the License at\n\n      https://www.apache.org/licenses/LICENSE-2.0\n\n Unless required by applicable law or agreed to in writing, software\n distributed under the License is distributed on an \"AS IS\" BASIS,\n WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n See the License for the specific language governing permissions and\n limitations under the License.\n --\u003e\n\n[![Unit Tests](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)\n\n# What's new?\n- **`2025/02/12`**: Flux LoRA for inference.\n- **`2025/02/08`**: Flux schnell \u0026 dev inference.\n- 
**`2024/12/12`**: Load multiple LoRAs for inference.\n- **`2024/10/22`**: LoRA support for Hyper SDXL.\n- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.\n- **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x and 2.x is now supported.\n\n# Overview\n\nMaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python/JAX that run on XLA devices, including Cloud TPUs and GPUs. MaxDiffusion aims to be a launching point for ambitious diffusion projects in both research and production. We encourage you to start by experimenting with MaxDiffusion out of the box and then fork and modify it to meet your needs.\n\nThe goal of this project is to provide reference implementations of latent diffusion models that help developers get started with training, tuning, and serving solutions on XLA devices, including Cloud TPUs and GPUs. We started with Stable Diffusion inference on TPUs and welcome code contributions to expand from there.\n\nMaxDiffusion supports:\n* Stable Diffusion 2 base (training and inference)\n* Stable Diffusion 2.1 (training and inference)\n* Stable Diffusion XL (training and inference)\n* SDXL Lightning (inference)\n* Hyper-SD XL LoRA loading (inference)\n* Loading multiple LoRAs (SDXL inference)\n* ControlNet inference (Stable Diffusion 1.4 \u0026 SDXL)\n* Dreambooth training for Stable Diffusion 1.x and 2.x\n* Flux dev \u0026 schnell, including LoRA loading (inference)\n\n**WARNING: The training code is purely experimental and is under development.**\n\n# Table of Contents\n\n- [What's new?](#whats-new)\n- [Overview](#overview)\n- [Table of Contents](#table-of-contents)\n- [Getting Started](#getting-started)\n  - [Getting Started:](#getting-started-1)\n  - [Training](#training)\n  - [Dreambooth](#dreambooth)\n  - [Inference](#inference)\n  - [Flux](#flux)\n    - [Fused Attention for GPU:](#fused-attention-for-gpu)\n    - [Flux LoRA](#flux-lora)\n  - [Hyper SDXL 
LoRA](#hyper-sdxl-lora)\n  - [Load Multiple LoRA](#load-multiple-lora)\n  - [SDXL Lightning](#sdxl-lightning)\n  - [ControlNet](#controlnet)\n  - [Getting Started: Multihost development](#getting-started-multihost-development)\n- [Comparison to Alternatives](#comparison-to-alternatives)\n- [Development](#development)\n\n# Getting Started\n\nWe recommend starting with a single TPU host and then moving to multihost.\n\nMinimum requirements: Ubuntu 22.04, Python 3.10, and TensorFlow \u003e= 2.12.0.\n\n## Getting Started:\n\nIf this is your first time running MaxDiffusion, follow the [first-run instructions](docs/getting_started/first_run.md).\n\n## Training\n\nAfter installation completes, run the training script.\n\n- **Stable Diffusion XL**\n\n  ```bash\n  export LIBTPU_INIT_ARGS=\"\"\n  python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml run_name=\"my_xl_run\" output_dir=\"gs://your-bucket/\" per_device_batch_size=1\n  ```\n\n  On GPUs with fused attention:\n\n  First, install Transformer Engine by following the [instructions here](#fused-attention-for-gpu).\n\n  ```bash\n  NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_new_unet=true train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=bfloat16 resolution=512 per_device_batch_size=1 attention=\"cudnn_flash_te\" jit_initializers=False\n  ```\n\n  To generate images with a trained checkpoint, run:\n\n  ```bash\n  python -m src.maxdiffusion.generate src/maxdiffusion/configs/base_xl.yml run_name=\"my_run\" pretrained_model_name_or_path=\u003cyour_saved_checkpoint_path\u003e from_pt=False attention=dot_product\n  ```\n\n- **Stable Diffusion 2 base**\n\n  ```bash\n  export LIBTPU_INIT_ARGS=\"\"\n  python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=\"my_run\" jax_cache_dir=gs://your-bucket/cache_dir 
activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash\n  ```\n\n- **Stable Diffusion 1.4**\n\n  ```bash\n  export LIBTPU_INIT_ARGS=\"\"\n  python -m src.maxdiffusion.train src/maxdiffusion/configs/base14.yml run_name=\"my_run\" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash\n  ```\n\n  To generate images with a trained checkpoint, run:\n\n  ```bash\n  python -m src.maxdiffusion.generate src/maxdiffusion/configs/base_2_base.yml run_name=\"my_run\" output_dir=gs://your-bucket/ from_pt=False attention=dot_product\n  ```\n\n## Dreambooth\n\n  **Stable Diffusion 1.x and 2.x**\n\n  ```bash\n  python src/maxdiffusion/dreambooth/train_dreambooth.py src/maxdiffusion/configs/base14.yml class_data_dir=\u003cyour-class-dir\u003e instance_data_dir=\u003cyour-instance-dir\u003e instance_prompt=\"a photo of ohwx dog\" class_prompt=\"a photo of a dog\" max_train_steps=150 jax_cache_dir=\u003cyour-cache-dir\u003e activations_dtype=bfloat16 weights_dtype=float32 per_device_batch_size=1 enable_profiler=False precision=DEFAULT cache_dreambooth_dataset=False learning_rate=4e-6 num_class_images=100 run_name=\u003cyour-run-name\u003e output_dir=gs://\u003cyour-bucket-name\u003e\n  ```\n\n## Inference\n\nTo generate images, run one of the following commands:\n\n- **Stable Diffusion XL**\n\n  Single-host and multi-host inference are supported with sharding annotations:\n\n  ```bash\n  python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name=\"my_run\"\n  ```\n\n  Single-host `pmap` version:\n\n  ```bash\n  python -m src.maxdiffusion.generate_sdxl_replicated\n  ```\n\n- **Stable Diffusion 2 base**\n  ```bash\n  python -m src.maxdiffusion.generate 
src/maxdiffusion/configs/base_2_base.yml run_name=\"my_run\"\n  ```\n\n- **Stable Diffusion 2.1**\n  ```bash\n  python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name=\"my_run\"\n  ```\n\n  ## Flux\n\n  First, make sure you have access to the Flux repositories on Hugging Face.\n\n  Expected generation times for 1024 x 1024 images with flash attention and bfloat16:\n\n  | Model | Accelerator | Sharding Strategy | Batch Size | Steps | Time (s) |\n  | --- | --- | --- | --- | --- | --- |\n  | Flux-dev | v4-8 | DDP | 4 | 28 | 23 |\n  | Flux-schnell | v4-8 | DDP | 4 | 4 | 2.2 |\n  | Flux-dev | v6e-4 | DDP | 4 | 28 | 5.5 |\n  | Flux-schnell | v6e-4 | DDP | 4 | 4 | 0.8 |\n  | Flux-schnell | v6e-4 | FSDP | 4 | 4 | 1.2 |\n\n  Schnell:\n\n  ```bash\n  python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt=\"photograph of an electronics chip in the shape of a race car with trillium written on its side\" per_device_batch_size=1\n  ```\n\n  Dev:\n\n  ```bash\n  python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt=\"photograph of an electronics chip in the shape of a race car with trillium written on its side\" per_device_batch_size=1\n  ```\n\n  If you are using a TPU v6e (Trillium), you can use optimized flash block sizes for faster inference. To enable them, uncomment the flash block sizes in the Flux-dev [config](src/maxdiffusion/configs/base_flux_dev.yml#L60) and the Flux-schnell [config](src/maxdiffusion/configs/base_flux_schnell.yml#L68).\n\n  To keep the text encoders, VAE, and transformer in HBM at all times, the following command shards the model across devices: 
\n\n  ```bash\n  python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt=\"photograph of an electronics chip in the shape of a race car with trillium written on its side\" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False\n  ```\n\n    ## Fused Attention for GPU:\n\n    Fused attention on GPUs is supported via Transformer Engine. To install it:\n\n    ```bash\n    cd maxdiffusion\n    pip install -U \"jax[cuda12]\"\n    pip install -r requirements.txt\n    pip install --upgrade torch torchvision\n    pip install \"transformer_engine[jax]\"\n    pip install .\n    ```\n\n    Then run:\n\n    ```bash\n    NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention=\"cudnn_flash_te\" hardware=gpu\n    ```\n\n    ## Flux LoRA\n\n    Disclaimer: not all LoRA formats have been tested. If a specific LoRA doesn't load, please let us know.\n\n    Tested with the [Amateur Photography](https://civitai.com/models/652699/amateur-photography-flux-dev) LoRA and the [XLabs-AI](https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main) LoRA collection.\n\n    First, download the LoRA file to a local directory, for example, `/home/jfacevedo/anime_lora.safetensors`. 
Then run as follows:\n\n    ```bash\n    python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{\"lora_model_name_or_path\" : [\"/home/jfacevedo/anime_lora.safetensors\"], \"weight_name\" : [\"anime_lora.safetensors\"], \"adapter_name\" : [\"anime\"], \"scale\": [0.8], \"from_pt\": [\"true\"]}'\n    ```\n\n    Loading multiple LoRAs is supported as follows:\n\n    ```bash\n    python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{\"lora_model_name_or_path\" : [\"/home/jfacevedo/anime_lora.safetensors\", \"/home/jfacevedo/amateurphoto-v6-forcu.safetensors\"], \"weight_name\" : [\"anime_lora.safetensors\",\"amateurphoto-v6-forcu.safetensors\"], \"adapter_name\" : [\"anime\",\"realistic\"], \"scale\": [0.6, 0.6], \"from_pt\": [\"true\",\"true\"]}'\n    ```\n\n  ## Hyper SDXL LoRA\n\n  Supports Hyper-SDXL models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD)\n\n  ```bash\n  python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name=\"test-lora\" output_dir=/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=2 do_classifier_free_guidance=False prompt=\"a photograph of a cat wearing a hat riding a skateboard in a park.\" per_device_batch_size=1 pretrained_model_name_or_path=\"Lykon/AAM_XL_AnimeMix\" from_pt=True revision=main diffusion_scheduler_config='{\"_class_name\" : \"FlaxDDIMScheduler\", \"timestep_spacing\" : \"trailing\"}' lora_config='{\"lora_model_name_or_path\" : [\"ByteDance/Hyper-SD\"], 
\"weight_name\" : [\"Hyper-SDXL-2steps-lora.safetensors\"], \"adapter_name\" : [\"hyper-sdxl\"], \"scale\": [0.7], \"from_pt\": [\"true\"]}'\n  ```\n\n  ## Load Multiple LoRA\n\n    Supports loading multiple LoRAs for inference, either from local files or from the Hugging Face Hub.\n\n    ```bash\n    python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name=\"test-lora\" output_dir=/tmp/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=30 do_classifier_free_guidance=True prompt=\"ultra detailed diagram blueprint of a papercut Sitting MaineCoon cat, wide canvas, ampereart, electrical diagram, bl3uprint, papercut\" per_device_batch_size=1 diffusion_scheduler_config='{\"_class_name\" : \"FlaxDDIMScheduler\", \"timestep_spacing\" : \"trailing\"}' lora_config='{\"lora_model_name_or_path\" : [\"/home/jfacevedo/blueprintify-sd-xl-10.safetensors\",\"TheLastBen/Papercut_SDXL\"], \"weight_name\" : [\"/home/jfacevedo/blueprintify-sd-xl-10.safetensors\",\"papercut.safetensors\"], \"adapter_name\" : [\"blueprint\",\"papercut\"], \"scale\": [0.8, 0.7], \"from_pt\": [\"true\", \"true\"]}'\n    ```\n\n  ## SDXL Lightning\n\n  Single-host and multi-host inference are supported with sharding annotations:\n\n    ```bash\n    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl_lightning.yml run_name=\"my_run\" lightning_repo=\"ByteDance/SDXL-Lightning\" lightning_ckpt=\"sdxl_lightning_4step_unet.safetensors\"\n    ```\n\n  ## ControlNet\n\n  ControlNet may require extra system libraries for OpenCV: `apt-get update \u0026\u0026 apt-get install -y ffmpeg libsm6 libxext6`\n\n  - Stable Diffusion 1.4\n\n    ```bash\n    python src/maxdiffusion/controlnet/generate_controlnet_replicated.py\n    ```\n\n  - Stable Diffusion XL\n\n    ```bash\n    python src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py\n    ```\n\n## Getting Started: Multihost development\n\nMultihost training for Stable Diffusion 2 base can be run using the following 
command:\n```bash\nTPU_NAME=\u003cyour-tpu-name\u003e\nZONE=\u003cyour-zone\u003e\nPROJECT_ID=\u003cyour-project-id\u003e\ngcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --project $PROJECT_ID --worker=all --command=\"\nexport LIBTPU_INIT_ARGS=''\ngit clone https://github.com/google/maxdiffusion\ncd maxdiffusion\npip3 install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html\npip3 install -r requirements.txt\npip3 install .\npython -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run output_dir=gs://your-bucket/\"\n```\n\n# Comparison to Alternatives\n\nMaxDiffusion started as a fork of [Diffusers](https://github.com/huggingface/diffusers), a Hugging Face diffusion library written in Python, PyTorch, and JAX. MaxDiffusion is compatible with Hugging Face JAX models. MaxDiffusion is more complex and was designed to run distributed across TPU Pods.\n\n# Development\n\nWhether you are forking MaxDiffusion for your own needs or intending to contribute back to the community, a full suite of tests, including end-to-end tests that we run nightly, can be found in `tests` and `src/maxdiffusion/tests`.\n\nTo run the unit tests:\n```bash\npython -m pytest\n```\n\nThis project uses `pylint` and `pyink` to enforce code style. Before submitting a pull request, please ensure your code passes these checks by running:\n\n```bash\nbash code_style.sh\n```\n\nThis script automatically formats your code with `pyink` and helps you identify any remaining style issues.\n","funding_links":[],"categories":[],"sub_categories":[],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fai-hypercomputer%2Fmaxdiffusion","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fai-hypercomputer%2Fmaxdiffusion","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fai-hypercomputer%2Fmaxdiffusion/lists"}