{"id":13935705,"url":"https://github.com/EleutherAI/DALLE-mtf","last_synced_at":"2025-07-19T20:33:49.917Z","repository":{"id":52033688,"uuid":"328236233","full_name":"EleutherAI/DALLE-mtf","owner":"EleutherAI","description":"Open-AI's DALL-E for large scale training in mesh-tensorflow.","archived":false,"fork":false,"pushed_at":"2022-02-12T13:57:47.000Z","size":279,"stargazers_count":433,"open_issues_count":5,"forks_count":46,"subscribers_count":28,"default_branch":"main","last_synced_at":"2025-05-25T20:06:00.889Z","etag":null,"topics":["artificial-intelligence","autoregressive","multimodal","text-to-image","transformers","variational-autoencoder"],"latest_commit_sha":null,"homepage":"https://www.eleuther.ai/","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/EleutherAI.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":"CODEOWNERS","security":null,"support":null}},"created_at":"2021-01-09T20:02:03.000Z","updated_at":"2025-01-10T00:03:48.000Z","dependencies_parsed_at":"2022-08-19T21:50:53.067Z","dependency_job_id":null,"html_url":"https://github.com/EleutherAI/DALLE-mtf","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/EleutherAI/DALLE-mtf","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/EleutherAI%2FDALLE-mtf","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/EleutherAI%2FDALLE-mtf/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/EleutherAI%2FDALLE-mtf/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/EleutherAI%2FDALLE-mtf/manifests","owner_url":"https://repos.ecosyste.ms/ap
i/v1/hosts/GitHub/owners/EleutherAI","download_url":"https://codeload.github.com/EleutherAI/DALLE-mtf/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/EleutherAI%2FDALLE-mtf/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":266007905,"owners_count":23863533,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["artificial-intelligence","autoregressive","multimodal","text-to-image","transformers","variational-autoencoder"],"created_at":"2024-08-07T23:02:00.767Z","updated_at":"2025-07-19T20:33:49.554Z","avatar_url":"https://github.com/EleutherAI.png","language":"Python","funding_links":[],"categories":["Python"],"sub_categories":[],"readme":"# DALL-E in Mesh-Tensorflow [WIP]\n\nOpenAI's [DALL-E](https://openai.com/blog/dall-e/) in Mesh-Tensorflow.\n\nIf this is similarly efficient to [GPT-Neo](https://github.com/EleutherAI/gpt-neo/), this repo should be able to train models up to, and larger than, the size of OpenAI's DALL-E (12B params).\n\nNo pretrained models... Yet.\n\nThanks to [Ben Wang](https://github.com/kingoflolz) for the tf vae implementation as well as getting the mtf version working, and [Aran Komatsuzaki](https://github.com/AranKomat) for help building the mtf VAE and input pipeline.\n\n# Setup\n\n```bash\ngit clone https://github.com/EleutherAI/DALLE-mtf\ncd DALLE-mtf\npip3 install -r requirements.txt\n```\n## Training Setup\n\nRuns on TPUs; untested on GPUs, but should work *in theory*. 
\nThe example configs are designed to run on a TPU v3-32 pod.\n\nTo set up TPUs, sign up for [Google Cloud Platform](https://cloud.google.com/), and create a [storage bucket](https://cloud.google.com/storage). \n\nCreate your VM through a Google Cloud Shell (`https://ssh.cloud.google.com/`) with `ctpu up --vm-only` so that it can connect to your Google bucket and TPUs, then set up the repo as above.\n\n## VAE pretraining\n\nDALL-E needs a pretrained VAE to compress images to tokens. To run the VAE pretraining, set the dataset paths in `configs/vae_example.json` to glob paths pointing to a dataset of JPEGs, and set `image_size` to the appropriate size.\n\n```\n  \"dataset\": {\n    \"train_path\": \"gs://neo-datasets/CIFAR-10-images/train/**/*.jpg\",\n    \"eval_path\": \"gs://neo-datasets/CIFAR-10-images/test/**/*.jpg\",\n    \"image_size\": 32\n  }\n```\n\nOnce this is all set up, create your TPU, then run:\n\n```bash\npython train_vae_tf.py --tpu your_tpu_name --model vae_example\n```\n\n\nTraining logs image tensors and loss values. To check progress, run:\n\n```bash\ntensorboard --logdir your_model_dir\n```\n\n## Dataset Creation [DALL-E]\n\nOnce the VAE is pretrained, you can move on to DALL-E.\n\nCurrently we are training on a dummy dataset. A public, large-scale dataset for DALL-E is in the works. 
In the meantime, to generate some dummy data, run:\n\n```bash\npython src/data/create_tfrecords.py\n```\n\nThis should download CIFAR-10 and generate random captions to act as text inputs.\n\nCustom datasets should be formatted in a folder, with a jsonl file in the root folder containing caption data and paths to the respective images, as follows:\n\n```\nFolder structure:\n\n        data_folder\n            jsonl_file\n            folder_1\n                img1\n                img2\n                ...\n            folder_2\n                img1\n                img2\n                ...\n            ...\n\njsonl structure:\n    {\"image_path\": \"folder_1/img1\", \"caption\": \"some words\"}\n    {\"image_path\": \"folder_2/img2\", \"caption\": \"more words\"}\n    ...\n```\n\nYou can then use the `create_paired_dataset` function in `src/data/create_tfrecords.py` to encode the dataset into tfrecords for use in training.\n\nOnce the dataset is created, copy it over to your bucket with gsutil:\n\n```bash\ngsutil cp -r DALLE-tfrecords gs://neo-datasets/\n```\n\nFinally, run training with:\n\n```bash\npython train_dalle.py --tpu your_tpu_name --model dalle_example\n```\n\n## Config Guide\n\nVAE:\n\n```\n{\n  \"model_type\": \"vae\",\n  \"dataset\": {\n    \"train_path\": \"gs://neo-datasets/CIFAR-10-images/train/**/*.jpg\", # glob path to training images\n    \"eval_path\": \"gs://neo-datasets/CIFAR-10-images/test/**/*.jpg\", # glob path to eval images\n    \"image_size\": 32 # size of images (all images will be cropped / padded to this size)\n  },\n  \"train_batch_size\": 32, \n  \"eval_batch_size\": 32,\n  \"predict_batch_size\": 32,\n  \"steps_per_checkpoint\": 1000, # how often to save a checkpoint\n  \"iterations\": 500, # number of batches to infeed to the tpu at a time. 
Must be \u003c steps_per_checkpoint\n  \"train_steps\": 100000, # total training steps\n  \"eval_steps\": 0, # run evaluation for this many steps every steps_per_checkpoint\n  \"model_path\": \"gs://neo-models/vae_test2/\", # directory in which to save the model\n  \"mesh_shape\": \"data:16,model:2\", # mapping of processors to named dimensions - see mesh-tensorflow repo for more info\n  \"layout\": \"batch_dim:data\", # which named dimensions of the model to split across the mesh - see mesh-tensorflow repo for more info\n  \"num_tokens\": 512, # vocab size\n  \"dim\": 512, \n  \"hidden_dim\": 64, # size of hidden dim\n  \"n_channels\": 3, # number of input channels\n  \"bf_16\": false, # if true, the model is trained with bfloat16 precision\n  \"lr\": 0.001, # learning rate [by default learning rate starts at this value, then decays to 10% of this value over the course of the training]\n  \"num_layers\": 3, # number of blocks in the encoder / decoder\n  \"train_gumbel_hard\": true, # whether to use hard or soft gumbel_softmax\n  \"eval_gumbel_hard\": true\n}\n```\n\nDALL-E:\n\n```\n{\n  \"model_type\": \"dalle\",\n  \"dataset\": {\n    \"train_path\": \"gs://neo-datasets/DALLE-tfrecords/*.tfrecords\", # glob path to tfrecords data\n    \"eval_path\": \"gs://neo-datasets/DALLE-tfrecords/*.tfrecords\",\n    \"image_size\": 32 # size of images (all images will be cropped / padded to this size)\n  },\n  \"train_batch_size\": 32, # see above\n  \"eval_batch_size\": 32,\n  \"predict_batch_size\": 32,\n  \"steps_per_checkpoint\": 1000,\n  \"iterations\": 500,\n  \"train_steps\": 100000,\n  \"predict_steps\": 0,\n  \"eval_steps\": 0,\n  \"n_channels\": 3,\n  \"bf_16\": false,\n  \"lr\": 0.001,\n  \"model_path\": \"gs://neo-models/dalle_test/\",\n  \"mesh_shape\": \"data:16,model:2\",\n  \"layout\": \"batch_dim:data\",\n  \"n_embd\": 512, # size of embedding dim\n  \"text_vocab_size\": 50258, # vocabulary size of the text tokenizer\n  \"image_vocab_size\": 512, # 
vocabulary size of the vae - should equal num_tokens above\n  \"text_seq_len\": 256, # length of text inputs (all inputs longer / shorter will be truncated / padded)\n  \"n_layers\": 6, \n  \"n_heads\": 4, # number of attention heads. For best performance, n_embd / n_heads should equal 128\n  \"vae_model\": \"vae_example\" # path to or name of vae model config\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FEleutherAI%2FDALLE-mtf","html_url":"https://awesome.ecosyste.ms/projects/github.com%2FEleutherAI%2FDALLE-mtf","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FEleutherAI%2FDALLE-mtf/lists"}