{"id":17714319,"url":"https://github.com/kvfrans/shortcut-models","last_synced_at":"2025-05-16T08:05:04.730Z","repository":{"id":258738750,"uuid":"866085775","full_name":"kvfrans/shortcut-models","owner":"kvfrans","description":null,"archived":false,"fork":false,"pushed_at":"2024-12-05T17:41:18.000Z","size":4702,"stargazers_count":438,"open_issues_count":9,"forks_count":11,"subscribers_count":10,"default_branch":"main","last_synced_at":"2025-04-09T08:01:40.564Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/kvfrans.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-10-01T16:13:28.000Z","updated_at":"2025-04-09T07:21:26.000Z","dependencies_parsed_at":"2024-12-22T12:12:31.565Z","dependency_job_id":"c8d69f83-88c0-462d-88d1-e4f5fc527233","html_url":"https://github.com/kvfrans/shortcut-models","commit_stats":null,"previous_names":["kvfrans/shortcut-models"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kvfrans%2Fshortcut-models","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kvfrans%2Fshortcut-models/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kvfrans%2Fshortcut-models/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kvfrans%2Fshortcut-models/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/kvfrans","download_url":"https://codeload.github.com/kvfrans/shortcut-models/tar.gz/refs/heads/main",
"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":254493378,"owners_count":22080126,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-10-25T11:02:21.396Z","updated_at":"2025-05-16T08:04:59.722Z","avatar_url":"https://github.com/kvfrans.png","language":"Python","funding_links":[],"categories":["Accelerate"],"sub_categories":[],"readme":"## One-Step Diffusion via Shortcut Models \n\nKevin Frans, Danijar Hafner, Sergey Levine, Pieter Abbeel\n\n[Paper Link](https://arxiv.org/abs/2410.12557)\n[Website Link](https://kvfrans.com/shortcut-models/)\n\n### Abstract\nDiffusion models and flow-matching models have enabled generating diverse and realistic images by learning to transfer noise to data.\nHowever, sampling from these models involves iterative denoising over many neural network passes, making generation slow and expensive.\nPrevious approaches for speeding up sampling require complex training regimes, such as multiple training phases, multiple networks, or fragile scheduling.\nWe introduce shortcut models, a family of generative models that use a single network and training phase to produce high-quality samples in a single or multiple sampling steps.\nShortcut models condition the network not only on the current noise level but also on the desired step size, allowing the model to skip ahead in the generation process.\nAcross a wide range of sampling step budgets, shortcut models consistently produce higher quality samples than previous approaches, such as consistency 
models and reflow.\nCompared to distillation, shortcut models reduce complexity to a single network and training phase and additionally allow varying step budgets at inference time.\n\n![Showcase Figure](data/fig-showcase4.png)\n\n### Overview\n\nShortcut models can use standard diffusion architectures (e.g., DiT) and condition the network on both the noise level `t` and the step size `d`. At `d ≈ 0`, the shortcut objective is equivalent to the flow-matching objective and can be trained by regressing onto empirical samples of `E[v_t | x_t]`. Targets for larger `d` are constructed by concatenating two consecutive `d/2` shortcuts. Both objectives can be trained jointly, so shortcut models do not require a two-stage procedure or a discretization schedule.\n\n![Method Figure](data/fig-method5.png)\n\n### Using the code\n\nThis codebase is written in JAX and was developed on TPU-v3 machines. Start by installing the conda dependencies from `environment.yml` and `requirements.txt`. We load datasets with TFDS, and our specific dataloaders are available at [https://github.com/kvfrans/tfds_builders](https://github.com/kvfrans/tfds_builders). You are also free to use your own dataloader. 
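The shortcut objective from the Overview can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not the implementation in `train.py`. It makes three assumptions:

* `toy_apply` is a hypothetical stand-in for the DiT.
* It uses the linear interpolation `x_t = (1 - t) * x0 + t * x1`, with `x0` as noise and `x1` as data, so the flow-matching velocity target is `x1 - x0`.
* The target for step size `d` is the average of two consecutive `d/2` shortcut velocities, held fixed with `jax.lax.stop_gradient`.

Sampling with `M` steps then uses `d = 1/M`.

```python
import jax
import jax.numpy as jnp


def toy_apply(params, x, t, d):
    # Hypothetical stand-in for a DiT conditioned on noise level t and step size d.
    # x: (B, D); t and d: (B, 1); returns a velocity of shape (B, D).
    return x * params['w'] + params['b'] + t * params['wt'] + d * params['wd']


def shortcut_target(apply_fn, params, x_t, t, d):
    # Target for a step of size d, built from two consecutive d/2 shortcuts.
    half = d / 2
    s1 = apply_fn(params, x_t, t, half)
    x_mid = x_t + half * s1  # follow the first half-step
    s2 = apply_fn(params, x_mid, t + half, half)
    # Average of the two half-step velocities; no gradient flows into the target.
    return jax.lax.stop_gradient((s1 + s2) / 2)


def shortcut_loss(params, apply_fn, x0, x1, t, d):
    # Assumed convention: x0 is noise, x1 is data, x_t interpolates linearly.
    x_t = (1 - t) * x0 + t * x1
    # Flow-matching term: at d = 0, regress onto the empirical velocity x1 - x0.
    v_pred = apply_fn(params, x_t, t, jnp.zeros_like(d))
    fm_loss = jnp.mean((v_pred - (x1 - x0)) ** 2)
    # Self-consistency term: one d-step should match two d/2-steps.
    s_pred = apply_fn(params, x_t, t, d)
    sc_loss = jnp.mean((s_pred - shortcut_target(apply_fn, params, x_t, t, d)) ** 2)
    return fm_loss + sc_loss


def sample(apply_fn, params, x0, num_steps):
    # Euler integration from noise (t = 0) to data (t = 1) with step size d = 1 / num_steps.
    d = 1.0 / num_steps
    x = x0
    for i in range(num_steps):
        t = jnp.full((x.shape[0], 1), i * d)
        x = x + d * apply_fn(params, x, t, jnp.full_like(t, d))
    return x


# Usage on random toy data (batch of 8, feature dimension 4).
key0, key1 = jax.random.split(jax.random.PRNGKey(0))
x0 = jax.random.normal(key0, (8, 4))  # noise
x1 = jax.random.normal(key1, (8, 4))  # stand-in for data
params = {k: jnp.full((4,), 0.1) for k in ('w', 'b', 'wt', 'wd')}
t = jnp.full((8, 1), 0.25)
d = jnp.full((8, 1), 0.5)  # chosen so that t + d <= 1
grads = jax.grad(lambda p: shortcut_loss(p, toy_apply, x0, x1, t, d))(params)
x_1step = sample(toy_apply, params, x0, num_steps=1)
x_4step = sample(toy_apply, params, x0, num_steps=4)
```

The step size is a network input, so the same `params` can be sampled with 128, 4, or 1 steps by changing only `num_steps`. For simplicity, the sketch applies both loss terms to the whole batch and expects `t + d <= 1`. The repository may sample `t` and `d` and split the batch differently.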
\n\nTo train a DiT-B scale model on CelebA:\n```\npython train.py --model.hidden_size 768 --model.patch_size 2 --model.depth 12 --model.num_heads 12 --model.mlp_ratio 4 --dataset_name celebahq256 --fid_stats data/celeba256_fidstats_ours.npz --model.cfg_scale 0 --model.class_dropout_prob 1 --model.num_classes 1 --batch_size 64 --max_steps 410_000 --model.train_type shortcut\n```\nor on ImageNet-256:\n```\npython train.py --model.hidden_size 768 --model.patch_size 2 --model.depth 12 --model.num_heads 12 --model.mlp_ratio 4 --dataset_name imagenet256 --fid_stats data/imagenet256_fidstats_ours.npz --model.cfg_scale 1.5 --model.class_dropout_prob 0.1 --model.bootstrap_cfg 1 --batch_size 256 --max_steps 810_000 --model.train_type shortcut\n```\n\nA larger DiT-XL scale model can be trained with:\n```\npython train.py --model.hidden_size 1152 --model.patch_size 2 --model.depth 28 --model.num_heads 16 --model.mlp_ratio 4 --dataset_name imagenet256 --fid_stats data/imagenet256_fidstats_ours.npz --model.cfg_scale 1.5 --model.class_dropout_prob 0.1 --model.bootstrap_cfg 1 --batch_size 256 --max_steps 810_000 --model.train_type shortcut\n```\n\nTo train a regular flow model instead, use `--model.train_type naive`. 
This code also supports `--model.sharding fsdp` for fully sharded data parallelism (FSDP), which is recommended when training on a multi-GPU or TPU machine.\n\n### Sanity Checking\n\nShortcut models trained with the commands above should achieve approximately the following FID-50k scores (lower is better).\n\n| Dataset (Model)           | 128-Step| 4-Step  | 1-Step  |\n| --------                  | ------- | ------- | ------- |\n| CelebA (DiT-B)            | 6.9     | 13.8    | 20.5    |\n| ImageNet-256 (DiT-B)      | 15.5    | 28.3    | 40.3    |\n| ImageNet-256 (DiT-XL)     | 3.8     | 7.8     | 10.6    |\n\n### Checkpoints and FID Stats\n\nPretrained model checkpoints and precomputed reference FID statistics for CelebA and ImageNet can be downloaded from [this Google Drive folder](https://drive.google.com/drive/folders/1g665i0vMxm8qqqcp5mAiexnL919-gMwW?usp=sharing). To load a checkpoint, use the `--load_dir` flag.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fkvfrans%2Fshortcut-models","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fkvfrans%2Fshortcut-models","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fkvfrans%2Fshortcut-models/lists"}