# Variational Autoencoder in TensorFlow and PyTorch
[![DOI](https://zenodo.org/badge/65744394.svg)](https://zenodo.org/badge/latestdoi/65744394)

Reference implementation for a variational autoencoder in TensorFlow and PyTorch.

I recommend the PyTorch version. It includes an example of a more expressive variational family, the [inverse autoregressive flow](https://arxiv.org/abs/1606.04934).

Variational inference is used to fit the model to binarized MNIST handwritten digit images. An inference network (encoder) is used to amortize inference and share parameters across datapoints. The likelihood is parameterized by a generative network (decoder).

Blog post: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/


## PyTorch implementation

(anaconda environment is in `environment-jax.yml`)

Importance sampling is used to estimate the marginal likelihood on Hugo Larochelle's binarized MNIST dataset. The final marginal log-likelihood on the test set was `-97.10` nats, which is comparable to published numbers.

```
$ python train_variational_autoencoder_pytorch.py --variational mean-field --use_gpu --data_dir $DAT --max_iterations 30000 --log_interval 10000
Step 0          Train ELBO estimate: -558.027   Validation ELBO estimate: -384.432      Validation log p(x) estimate: -355.430  Speed: 2.72e+06 examples/s
Step 10000      Train ELBO estimate: -111.323   Validation ELBO estimate: -109.048      Validation log p(x) estimate: -103.746  Speed: 2.64e+04 examples/s
Step 20000      Train ELBO estimate: -103.013   Validation ELBO estimate: -107.655      Validation log p(x) estimate: -101.275  Speed: 2.63e+04 examples/s
Step 29999      Test ELBO estimate: -106.642    Test log p(x) estimate: -100.309
Total time: 2.49 minutes
```


Using a non-mean-field, more expressive variational posterior approximation (inverse autoregressive flow, https://arxiv.org/abs/1606.04934), the test marginal log-likelihood improves to `-95.33` nats:

```
$ python train_variational_autoencoder_pytorch.py --variational flow
step:   0       train elbo: -578.35
step:   0               valid elbo: -407.06     valid log p(x): -367.88
step:   10000   train elbo: -106.63
step:   10000           valid elbo: -110.12     valid log p(x): -104.00
step:   20000   train elbo: -101.51
step:   20000           valid elbo: -105.02     valid log p(x): -99.11
step:   30000   train elbo: -98.70
step:   30000           valid elbo: -103.76     valid log p(x): -97.71
```

## JAX implementation

Using JAX (anaconda environment is in `environment-jax.yml`) gives a roughly 3x speedup over PyTorch (0.810 vs. 2.49 minutes total time for the mean-field runs):
```
$ python train_variational_autoencoder_jax.py --variational mean-field
Step 0          Train ELBO estimate: -566.059   Validation ELBO estimate: -565.755      Validation log p(x) estimate: -557.914  Speed: 2.56e+11 examples/s
Step 10000      Train ELBO estimate: -98.560    Validation ELBO estimate: -105.725      Validation log p(x) estimate: -98.973   Speed: 7.03e+04 examples/s
Step 20000      Train ELBO estimate: -109.794   Validation ELBO estimate: -105.756      Validation log p(x) estimate: -97.914   Speed: 4.26e+04 examples/s
Step 29999      Test ELBO estimate: -104.867    Test log p(x) estimate: -96.716
Total time: 0.810 minutes
```

Inverse autoregressive flow in JAX:
```
$ python train_variational_autoencoder_jax.py --variational flow
Step 0          Train ELBO estimate: -727.404   Validation ELBO estimate: -726.977      Validation log p(x) estimate: -713.389  Speed: 2.56e+11 examples/s
Step 10000      Train ELBO estimate: -100.093   Validation ELBO estimate: -106.985      Validation log p(x) estimate: -99.565   Speed: 2.57e+04 examples/s
Step 20000      Train ELBO estimate: -113.073   Validation ELBO estimate: -108.057      Validation log p(x) estimate: -98.841   Speed: 3.37e+04 examples/s
Step 29999      Test ELBO estimate: -106.803    Test log p(x) estimate: -97.620
Total time: 2.350 minutes
```

(The difference between the mean-field and inverse autoregressive flow results may be due to several factors, chief among them the lack of convolutions in this implementation. The IAF paper, https://arxiv.org/pdf/1606.04934.pdf, uses residual blocks to get the ELBO closer to -80 nats.)

## Generating the GIFs

1. Run `python train_variational_autoencoder_tensorflow.py`
2. Install ImageMagick (Homebrew on macOS: https://formulae.brew.sh/formula/imagemagick, or Chocolatey on Windows: https://community.chocolatey.org/packages/imagemagick.app)
3. Go to the directory where the jpg files are saved and run the ImageMagick command to generate the .gif: `convert -delay 20 -loop 0 *.jpg latent-space.gif`

## TODO (help needed - feel free to send a PR!)
- add multiple GPU / TPU option
- add jaxtyping support for the PyTorch and JAX implementations :) for runtime type checking (using @beartype decorators)

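
The "ELBO estimate" and "log p(x) estimate" columns in the training logs above are two summaries of the same importance weights. This is a minimal sketch of how they relate. It assumes the common choice of the variational posterior q(z | x) as the importance-sampling proposal; the scripts' own sample counts and code are not shown here. The sketch draws K reparameterized samples z_k and forms log weights log p(x, z_k) - log q(z_k | x). The mean of the log weights is the ELBO estimate, and the log of the mean weight is the log p(x) estimate. The one-dimensional Gaussian model and the names `log_joint`, `log_mean_exp`, and `estimates` are hypothetical. They are chosen so that log p(x) is known exactly.

```python
import math
import random


def log_normal(x: float, mean: float, var: float) -> float:
    """Log density of the univariate Gaussian N(mean, var) at x."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def log_joint(x: float, z: float) -> float:
    """Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1), so p(x) = N(0, 2)."""
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)


def log_mean_exp(values: list[float]) -> float:
    """Numerically stable log((1/K) * sum_k exp(values[k]))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values)) - math.log(len(values))


def estimates(x: float, q_mean: float, q_var: float, num_samples: int,
              rng: random.Random) -> tuple[float, float]:
    """Return (ELBO estimate, importance-sampled log p(x) estimate).

    Draws z_k ~ q(z | x) = N(q_mean, q_var) via the reparameterization
    z = mean + std * eps and uses log weights log p(x, z_k) - log q(z_k | x).
    """
    std = math.sqrt(q_var)
    log_w = []
    for _ in range(num_samples):
        z = q_mean + std * rng.gauss(0.0, 1.0)
        log_w.append(log_joint(x, z) - log_normal(z, q_mean, q_var))
    elbo = sum(log_w) / num_samples   # mean of log weights: a lower bound on log p(x)
    log_px = log_mean_exp(log_w)      # log of mean weight: tighter, in expectation, as K grows
    return elbo, log_px


if __name__ == "__main__":
    rng = random.Random(0)
    x = 1.5
    print("exact log p(x):", log_normal(x, 0.0, 2.0))
    # q equal to the exact posterior N(x/2, 1/2): every weight equals p(x).
    print("exact posterior as q:", estimates(x, x / 2, 0.5, 100, rng))
    # q equal to the prior N(0, 1): the ELBO is looser than the log p(x) estimate.
    print("prior as q:", estimates(x, 0.0, 1.0, 5000, rng))
```

By Jensen's inequality, the mean of the log weights can never exceed the log of their mean. That is why the "log p(x)" column in the logs is consistently higher than the "ELBO" column.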
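
The inverse autoregressive flow used by the `--variational flow` runs above makes the posterior more expressive than mean-field. It does this by pushing a sample through invertible autoregressive steps. This is a minimal sketch of one such step, assuming the gated update described in the IAF paper (https://arxiv.org/abs/1606.04934), which this repository may or may not follow exactly. An autoregressive network outputs m and s from z, and sigma = sigmoid(s). The step then computes z_new = sigma * z + (1 - sigma) * m, and log q(z | x) decreases by sum(log sigma). A toy linear network stands in for the autoregressive networks used in practice, such as MADE. The names `iaf_step`, `W`, `b`, `U`, `c`, and `log_q0` are hypothetical.

```python
import math


def log_sigmoid(v: float) -> float:
    """Numerically stable log(1 / (1 + exp(-v)))."""
    if v >= 0:
        return -math.log1p(math.exp(-v))
    return v - math.log1p(math.exp(v))


def iaf_step(z: list[float], W: list[list[float]], b: list[float],
             U: list[list[float]], c: list[float]) -> tuple[list[float], float]:
    """One gated IAF step with a toy *linear* autoregressive network (hypothetical).

    m_i and s_i read only z[0..i-1] (entries W[i][j], U[i][j] with j < i),
    so z_new[i] depends on z[j] only for j <= i: the Jacobian is lower
    triangular with diagonal sigma_i, and log|det J| = sum_i log sigma_i.
    """
    z_new, log_det = [], 0.0
    for i in range(len(z)):
        m = b[i] + sum(W[i][j] * z[j] for j in range(i))
        s = c[i] + sum(U[i][j] * z[j] for j in range(i))
        log_sigma = log_sigmoid(s)
        sigma = math.exp(log_sigma)
        z_new.append(sigma * z[i] + (1.0 - sigma) * m)  # gated update
        log_det += log_sigma
    return z_new, log_det


if __name__ == "__main__":
    z0 = [0.3, -1.2, 0.8]   # e.g. a sample from the mean-field q(z | x)
    log_q0 = -2.0           # hypothetical log q(z0 | x) from the encoder
    W = [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [-0.4, 1.1, 0.0]]
    U = [[0.0, 0.0, 0.0], [0.7, 0.0, 0.0], [0.2, -0.6, 0.0]]
    b, c = [0.1, -0.2, 0.3], [1.0, 0.5, -0.3]
    z1, log_det = iaf_step(z0, W, b, U, c)
    # Change of variables: log density of the transformed sample.
    log_q1 = log_q0 - log_det
    print(z1, log_q1)
```

Because m_i and s_i never read z_i, the step is invertible. Its log-determinant is also just a sum over the d gates, which is what keeps stacking several steps cheap.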