{"id":15601019,"url":"https://github.com/lucidrains/muse-maskgit-pytorch","last_synced_at":"2025-05-15T19:07:03.379Z","repository":{"id":65182552,"uuid":"584843932","full_name":"lucidrains/muse-maskgit-pytorch","owner":"lucidrains","description":"Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch","archived":false,"fork":false,"pushed_at":"2024-02-29T18:59:03.000Z","size":292,"stargazers_count":893,"open_issues_count":12,"forks_count":84,"subscribers_count":33,"default_branch":"main","last_synced_at":"2025-04-07T23:11:16.440Z","etag":null,"topics":["artificial-intelligence","attention-mechanisms","deep-learning","text-to-image","transformers"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/lucidrains.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2023-01-03T16:54:57.000Z","updated_at":"2025-04-03T09:45:42.000Z","dependencies_parsed_at":"2024-10-22T21:32:35.567Z","dependency_job_id":null,"html_url":"https://github.com/lucidrains/muse-maskgit-pytorch","commit_stats":{"total_commits":80,"total_committers":4,"mean_commits":20.0,"dds":0.08750000000000002,"last_synced_commit":"6df7f33bcd33ba28a2f682d5bd293e4f8a513e6c"},"previous_names":[],"tags_count":42,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fmuse-maskgit-pytorch","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fmuse-maskgit-pytorch/tags","releases_url":"ht
tps://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fmuse-maskgit-pytorch/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fmuse-maskgit-pytorch/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/lucidrains","download_url":"https://codeload.github.com/lucidrains/muse-maskgit-pytorch/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":254404356,"owners_count":22065641,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["artificial-intelligence","attention-mechanisms","deep-learning","text-to-image","transformers"],"created_at":"2024-10-03T02:11:56.876Z","updated_at":"2025-05-15T19:07:03.349Z","avatar_url":"https://github.com/lucidrains.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"\u003cimg src=\"./muse.png\" width=\"450px\"\u003e\u003c/img\u003e\n\n## Muse - Pytorch\n\nImplementation of \u003ca href=\"https://muse-model.github.io/\"\u003eMuse\u003c/a\u003e: Text-to-Image Generation via Masked Generative Transformers, in Pytorch\n\nPlease join \u003ca href=\"https://discord.gg/xBPBXfcFHd\"\u003e\u003cimg alt=\"Join us on Discord\" src=\"https://img.shields.io/discord/823813159592001537?color=5865F2\u0026logo=discord\u0026logoColor=white\"\u003e\u003c/a\u003e if you are interested in helping out with the replication with the \u003ca href=\"https://laion.ai/\"\u003eLAION\u003c/a\u003e community\n\n## Install\n\n```bash\n$ pip install muse-maskgit-pytorch\n```\n\n## Usage\n\nFirst train 
your VAE - `VQGanVAE`\n\n```python\nimport torch\nfrom muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer\n\nvae = VQGanVAE(\n    dim = 256,\n    codebook_size = 65536\n)\n\n# train on folder of images, as many images as possible\n\ntrainer = VQGanVAETrainer(\n    vae = vae,\n    image_size = 128,             # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it\n    folder = '/path/to/images',\n    batch_size = 4,\n    grad_accum_every = 8,\n    num_train_steps = 50000\n).cuda()\n\ntrainer.train()\n```\n\nThen pass the trained `VQGanVAE` and a `Transformer` to `MaskGit`\n\n```python\nimport torch\nfrom muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer\n\n# first instantiate your vae\n\nvae = VQGanVAE(\n    dim = 256,\n    codebook_size = 65536\n).cuda()\n\nvae.load('/path/to/vae.pt') # you will want to load the exponentially moving averaged VAE\n\n# then you plug the vae and transformer into your MaskGit as so\n\n# (1) create your transformer / attention network\n\ntransformer = MaskGitTransformer(\n    num_tokens = 65536,       # must be same as codebook size above\n    seq_len = 256,            # must be equivalent to fmap_size ** 2 in vae\n    dim = 512,                # model dimension\n    depth = 8,                # depth\n    dim_head = 64,            # attention head dimension\n    heads = 8,                # attention heads,\n    ff_mult = 4,              # feedforward expansion factor\n    t5_name = 't5-small',     # name of your T5\n)\n\n# (2) pass your trained VAE and the base transformer to MaskGit\n\nbase_maskgit = MaskGit(\n    vae = vae,                 # vqgan vae\n    transformer = transformer, # transformer\n    image_size = 256,          # image size\n    cond_drop_prob = 0.25,     # conditional dropout, for classifier free guidance\n).cuda()\n\n# ready your training text and images\n\ntexts 
= [\n    'a child screaming at finding a worm within a half-eaten apple',\n    'lizard running across the desert on two feet',\n    'waking up to a psychedelic landscape',\n    'seashells sparkling in the shallow waters'\n]\n\nimages = torch.randn(4, 3, 256, 256).cuda()\n\n# feed it into your maskgit instance, with return_loss set to True\n\nloss = base_maskgit(\n    images,\n    texts = texts\n)\n\nloss.backward()\n\n# do this for a long time on much data\n# then...\n\nimages = base_maskgit.generate(texts = [\n    'a whale breaching from afar',\n    'young girl blowing out candles on her birthday cake',\n    'fireworks with blue and green sparkles'\n], cond_scale = 3.) # conditioning scale for classifier free guidance\n\nimages.shape # (3, 3, 256, 256)\n```\n\n\nTraining the super-resolution maskgit requires changing one field at `MaskGit` instantiation: you now need to pass in `cond_image_size`, the size of the lower-resolution image being conditioned on.\n\nOptionally, you can pass in a different `VAE` as `cond_vae` for the conditioning low-resolution image. 
By default it will use the `vae` to tokenize both the super-resolution and low-resolution images.\n\n```python\nimport torch\nimport torch.nn.functional as F\nfrom muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer\n\n# first instantiate your ViT VQGan VAE\n# a VQGan VAE made of transformers\n\nvae = VQGanVAE(\n    dim = 256,\n    codebook_size = 65536\n).cuda()\n\nvae.load('./path/to/vae.pt') # you will want to load the exponentially moving averaged VAE\n\n# then you plug the VQGan VAE into your MaskGit as so\n\n# (1) create your transformer / attention network\n\ntransformer = MaskGitTransformer(\n    num_tokens = 65536,       # must be same as codebook size above\n    seq_len = 1024,           # must be equivalent to fmap_size ** 2 in vae\n    dim = 512,                # model dimension\n    depth = 2,                # depth\n    dim_head = 64,            # attention head dimension\n    heads = 8,                # attention heads,\n    ff_mult = 4,              # feedforward expansion factor\n    t5_name = 't5-small',     # name of your T5\n)\n\n# (2) pass your trained VAE and the base transformer to MaskGit\n\nsuperres_maskgit = MaskGit(\n    vae = vae,\n    transformer = transformer,\n    cond_drop_prob = 0.25,\n    image_size = 512,                     # larger image size\n    cond_image_size = 256,                # conditioning image size \u003c- this must be set\n).cuda()\n\n# ready your training text and images\n\ntexts = [\n    'a child screaming at finding a worm within a half-eaten apple',\n    'lizard running across the desert on two feet',\n    'waking up to a psychedelic landscape',\n    'seashells sparkling in the shallow waters'\n]\n\nimages = torch.randn(4, 3, 512, 512).cuda()\n\n# feed it into your maskgit instance, with return_loss set to True\n\nloss = superres_maskgit(\n    images,\n    texts = texts\n)\n\nloss.backward()\n\n# do this for a long time on much data\n# then...\n\nimages = superres_maskgit.generate(\n    texts = [\n        
'a whale breaching from afar',\n        'young girl blowing out candles on her birthday cake',\n        'fireworks with blue and green sparkles',\n        'waking up to a psychedelic landscape'\n    ],\n    cond_images = F.interpolate(images, 256),  # conditioning images must be passed in for generating from superres\n    cond_scale = 3.\n)\n\nimages.shape # (4, 3, 512, 512)\n```\n\nAll together now\n\n```python\nfrom muse_maskgit_pytorch import Muse\n\nbase_maskgit.load('./path/to/base.pt')\n\nsuperres_maskgit.load('./path/to/superres.pt')\n\n# pass in the trained base_maskgit and superres_maskgit from above\n\nmuse = Muse(\n    base = base_maskgit,\n    superres = superres_maskgit\n)\n\nimages = muse([\n    'a whale breaching from afar',\n    'young girl blowing out candles on her birthday cake',\n    'fireworks with blue and green sparkles',\n    'waking up to a psychedelic landscape'\n])\n\nimages # List[PIL.Image.Image]\n```\n\n## Appreciation\n\n- \u003ca href=\"https://stability.ai/\"\u003eStabilityAI\u003c/a\u003e for the sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.\n\n- \u003ca href=\"https://huggingface.co/\"\u003e🤗 Huggingface\u003c/a\u003e for the transformers and accelerate libraries, both of which are wonderful\n\n## Todo\n\n- [x] test end-to-end\n- [x] separate cond_images_or_ids, it is not done right\n- [x] add training code for vae\n- [x] add optional self-conditioning on embeddings\n- [x] combine with token critic paper, already implemented at \u003ca href=\"https://github.com/lucidrains/phenaki-pytorch\"\u003ePhenaki\u003c/a\u003e\n\n- [ ] hook up accelerate training code for maskgit\n\n## Citations\n\n```bibtex\n@inproceedings{Chang2023MuseTG,\n    title   = {Muse: Text-To-Image Generation via Masked Generative Transformers},\n    author  = {Huiwen Chang and Han Zhang and Jarred Barber and AJ Maschinot and Jos{\\'e} Lezama and Lu Jiang and Ming-Hsuan Yang and Kevin P. 
Murphy and William T. Freeman and Michael Rubinstein and Yuanzhen Li and Dilip Krishnan},\n    year    = {2023}\n}\n```\n\n```bibtex\n@article{Chen2022AnalogBG,\n    title   = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},\n    author  = {Ting Chen and Ruixiang Zhang and Geoffrey E. Hinton},\n    journal = {ArXiv},\n    year    = {2022},\n    volume  = {abs/2208.04202}\n}\n```\n\n```bibtex\n@misc{jabri2022scalable,\n    title   = {Scalable Adaptive Computation for Iterative Generation},\n    author  = {Allan Jabri and David Fleet and Ting Chen},\n    year    = {2022},\n    eprint  = {2212.11972},\n    archivePrefix = {arXiv},\n    primaryClass = {cs.LG}\n}\n```\n\n```bibtex\n@article{Lezama2022ImprovedMI,\n    title   = {Improved Masked Image Generation with Token-Critic},\n    author  = {Jos{\\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},\n    journal = {ArXiv},\n    year    = {2022},\n    volume  = {abs/2209.04439}\n}\n```\n\n```bibtex\n@inproceedings{Nijkamp2021SCRIPTSP,\n    title   = {SCRIPT: Self-Critic PreTraining of Transformers},\n    author  = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},\n    booktitle = {North American Chapter of the Association for Computational Linguistics},\n    year    = {2021}\n}\n```\n\n```bibtex\n@inproceedings{dao2022flashattention,\n    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},\n    author  = {Dao, Tri and Fu, Daniel Y. 
and Ermon, Stefano and Rudra, Atri and R{\\'e}, Christopher},\n    booktitle = {Advances in Neural Information Processing Systems},\n    year    = {2022}\n}\n```\n\n```bibtex\n@misc{mentzer2023finite,\n    title   = {Finite Scalar Quantization: VQ-VAE Made Simple},\n    author  = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},\n    year    = {2023},\n    eprint  = {2309.15505},\n    archivePrefix = {arXiv},\n    primaryClass = {cs.CV}\n}\n```\n\n```bibtex\n@misc{yu2023language,\n    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},\n    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},\n    year    = {2023},\n    eprint  = {2310.05737},\n    archivePrefix = {arXiv},\n    primaryClass = {cs.CV}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Fmuse-maskgit-pytorch","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Flucidrains%2Fmuse-maskgit-pytorch","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Fmuse-maskgit-pytorch/lists"}