{"id":20936644,"url":"https://github.com/vsimkus/vae-conditional-sampling","last_synced_at":"2025-10-30T09:41:18.280Z","repository":{"id":206242987,"uuid":"679768740","full_name":"vsimkus/vae-conditional-sampling","owner":"vsimkus","description":"[TMLR] Research code for the paper \"Conditional Sampling of Variational Autoencoders via Iterated Approximate Ancestral Sampling\". ","archived":false,"fork":false,"pushed_at":"2023-11-08T14:49:07.000Z","size":68534,"stargazers_count":3,"open_issues_count":0,"forks_count":0,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-01-19T20:16:48.580Z","etag":null,"topics":["conditional-sampling","data-science","importance-sampling","incomplete-data","mcmc","missing-data","vae"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/vsimkus.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null}},"created_at":"2023-08-17T15:16:31.000Z","updated_at":"2025-01-13T12:02:37.000Z","dependencies_parsed_at":null,"dependency_job_id":"1b37b968-9964-41de-91db-1491a50d022b","html_url":"https://github.com/vsimkus/vae-conditional-sampling","commit_stats":null,"previous_names":["vsimkus/vae-conditional-sampling"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/vsimkus%2Fvae-conditional-sampling","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/vsimkus%2Fvae-conditional-sampling/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/vsimkus%2Fvae-conditional-sampling/releases","manifests_url":"https://repos.ecosy
ste.ms/api/v1/hosts/GitHub/repositories/vsimkus%2Fvae-conditional-sampling/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/vsimkus","download_url":"https://codeload.github.com/vsimkus/vae-conditional-sampling/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":243330325,"owners_count":20274039,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["conditional-sampling","data-science","importance-sampling","incomplete-data","mcmc","missing-data","vae"],"created_at":"2024-11-18T22:22:31.843Z","updated_at":"2025-10-30T09:41:18.180Z","avatar_url":"https://github.com/vsimkus.png","language":"Python","readme":"# Conditional Sampling of Variational Autoencoders via Iterated Approximate Ancestral Sampling\n\nThis repository contains the research code for\n\n\u003e Vaidotas Simkus, Michael U. Gutmann. Conditional Sampling of Variational Autoencoders via Iterated Approximate Ancestral Sampling. Transactions on Machine Learning Research, 2023.\n\nThe paper can be found here: \u003chttps://openreview.net/forum?id=I5sJ6PU6JN\u003e.\n\nThe code is shared for reproducibility purposes and is not intended for production use. It should also serve as a reference implementation for anyone wanting to use LAIR or AC-MWG for conditional sampling of VAEs (e.g., for missing data imputation using pre-trained VAEs).\n\n## Abstract\n\nConditional sampling of variational autoencoders (VAEs) is needed in various applications, such as missing data imputation, but is computationally intractable. 
A principled choice for asymptotically exact conditional sampling is Metropolis-within-Gibbs (MWG). However, we observe that the tendency of VAEs to learn a structured latent space, a commonly desired property, can cause the MWG sampler to get “stuck” far from the target distribution. This paper mitigates the limitations of MWG: we systematically outline the pitfalls in the context of VAEs, propose two original methods that address these pitfalls, and demonstrate an improved performance of the proposed methods on a set of sampling tasks.\n\n## Dependencies\n\nInstall the Python dependencies via conda and set up the `irwg` project package with\n\n```bash\nconda env create -f environment.yml\nconda activate irwg\npython setup.py develop\n```\n\nIf the dependencies in `environment.yml` change, update the environment with\n\n```bash\nconda env update --file environment.yml\n```\n\n## Organisation of the code\n\n* `./irwg/data/` contains data loaders and missingness generators.\n* `./irwg/models/` contains the neural network model implementations.\n* `./irwg/sampling/` contains the code related to VAE sampling.\n  * `test_step_vae_sampling.py` contains the implementations of the methods in the paper.\n  (Note: some method names differ from those used in the paper.)\n  * LAIR is implemented in the class `TestVAELatentAdaptiveImportanceResampling`.\n  * AC-MWG is implemented in the class `TestVAEAdaptiveCollapsedMetropolisWithinGibbs`.\n* `./configs/` contains the YAML configuration files specifying all the settings for each experiment.\n* `./helpers/` contains various helper scripts for the analysis of the imputations.\n  * `compute_mnist_mog_posterior_probs.py` computes the metrics on MNIST-GMM data.\n  * `eval_large_uci_joint_imputed_dataset_divergences.py` computes the metrics on UCI data and stores them in a file.\n  * `eval_omniglot_joint_imputed_dataset_fids.py` computes the metrics on Omniglot data and stores them in a file.\n  * `create_marginal_vae_imputations.py` 
creates imputations by sampling from the marginal of the VAE (i.e., an unconditional imputation baseline).\n  * Configs for the helper scripts are also located in the `./configs/` directory.\n* `./notebooks/` contains analysis notebooks that produce the figures in the paper, using the outputs from the helper scripts.\n\n## Running the code\n\nActivate the conda environment\n\n```bash\nconda activate irwg\n```\n\n### VAE training\n\nTo train the VAE, which we later use for sampling, run e.g.\n\n```bash\npython train.py --config=configs/mnist_gmm/vae_convresnet3.yaml\n```\n\n### VAE sampling\n\nThen, to sample from the VAE using one of the methods, run\n\n```bash\npython test.py --config=configs/mnist_gmm/samples/vae_convresnet3_k4_irwg_i1_dmis_gr_mult_replenish1_finalresample.yaml\n```\n\n### Analysis helper scripts\n\nFinally, use `./helpers/compute_mnist_mog_posterior_probs.py` to compute the metrics and store them in a file, then plot them in a notebook.\n\nSimilarly, for UCI data use `./helpers/eval_large_uci_joint_imputed_dataset_divergences.py` to compute the metrics, then plot them in a notebook.\n","funding_links":[],"categories":[],"sub_categories":[],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fvsimkus%2Fvae-conditional-sampling","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fvsimkus%2Fvae-conditional-sampling","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fvsimkus%2Fvae-conditional-sampling/lists"}