https://github.com/zhang-liyi/tsc
Transport Score Climbing: Variational Inference using Inclusive KL and Adaptive Neural Transport
- Host: GitHub
- URL: https://github.com/zhang-liyi/tsc
- Owner: zhang-liyi
- Created: 2021-06-15T11:44:05.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2022-08-01T18:25:25.000Z (over 2 years ago)
- Last Synced: 2024-08-01T16:46:23.165Z (3 months ago)
- Language: Jupyter Notebook
- Homepage:
- Size: 6.1 MB
- Stars: 3
- Watchers: 1
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# Transport Score Climbing: Variational Inference using Inclusive KL and Adaptive Neural Transport
Packages used:
* TensorFlow 2.3.0
* TensorFlow-Probability 0.11.1
* TensorFlow Addons
* TensorFlow Datasets

To run:
`python main.py --dataset=[funnel/banana/mnist/mnist_dyn/cifar10] --method=[vi_klqp/vi_klpq/vae/vae_mcmc] --v_fam=[gaussian/flow] --space=[original/warped] --num_samp=xxx --epochs=xxx --lr=xxx --decay_rate=xxx --hmc_e=xxx --hmc_L=xxx --hmc_L_cap=xxx --cis=xxx --reinitialize_from_q=[true/false] --warm_up=[true/false]`

Some explanations:
* `--dataset=[mnist/mnist_dyn/cifar10]` must be paired with the VAE methods (`vae`, `vae_mcmc`); `--dataset=[funnel/banana]` must be paired with the VI methods (`vi_klqp`, `vi_klpq`).
* `--method`: `vae` covers VAE and IWAE; `vae_mcmc` covers CIS-MSC, NeutraHMC, and TSC. Set the `--space` argument accordingly.
* `--num_samp` is the number of samples used in VI. It defaults to 1; when it is greater than 1, the VAE experiments run IWAE.
* `--decay_rate` is the decay rate of the inverse time decay learning-rate schedule, which is used only in non-VAE experiments.
* `--cis` defaults to 0. For VAE-related methods, if `cis` is greater than 0, the program runs CIS-MSC with `cis` importance samples.
* `--reinitialize_from_q=[true/false]` controls whether the HMC chain is reinitialized from q at every epoch in VAE experiments: `true` gives NeutraHMC; `false` gives TSC.

To run the survey dataset experiment, use `survey_data.ipynb`.
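The dataset/method pairing rules and the flags that select an algorithm, as described in the explanations above, can be summarized as a small lookup. The sketch below is a hypothetical helper, not part of the repository; `describe_run` and its arguments are illustrative names. It assumes that `--cis` only matters for `vae_mcmc` and that a positive `cis` takes precedence over `--reinitialize_from_q`. The README does not state either point explicitly.

```python
# Hypothetical helper, not part of the repository: it encodes the flag rules
# from the README and names the algorithm a given flag combination runs.
VI_DATASETS = {"funnel", "banana"}
VAE_DATASETS = {"mnist", "mnist_dyn", "cifar10"}
VI_METHODS = {"vi_klqp", "vi_klpq"}
VAE_METHODS = {"vae", "vae_mcmc"}


def describe_run(dataset: str, method: str, num_samp: int = 1, cis: int = 0,
                 reinitialize_from_q: bool = False) -> str:
    """Return the algorithm implied by the flags; raise ValueError if invalid."""
    if dataset in VI_DATASETS:
        if method not in VI_METHODS:
            raise ValueError(f"--dataset={dataset} needs vi_klqp or vi_klpq")
        return method
    if dataset not in VAE_DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    if method not in VAE_METHODS:
        raise ValueError(f"--dataset={dataset} needs vae or vae_mcmc")
    if method == "vae":
        # --num_samp > 1 turns the VAE experiment into IWAE.
        return "IWAE" if num_samp > 1 else "VAE"
    # method == "vae_mcmc". Assumption: a positive --cis takes precedence.
    if cis > 0:
        return "CIS-MSC"
    # --reinitialize_from_q: true -> NeutraHMC, false -> TSC.
    return "NeutraHMC" if reinitialize_from_q else "TSC"


print(describe_run("mnist", "vae_mcmc", reinitialize_from_q=False))  # TSC
```

Raising on an invalid pairing, rather than silently falling back, mirrors the README's "must be paired" wording.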
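As noted under `--num_samp`, setting it to K > 1 in the VAE experiments switches to IWAE, whose objective is the log of the average of K importance weights $w_k = p(x, z_k) / q(z_k \mid x)$ with $z_k \sim q$. The sketch below computes that quantity from precomputed log-weights using the log-sum-exp trick. It is a generic illustration, not code from the repository; in TensorFlow, `tf.reduce_logsumexp` performs the same stable reduction.

```python
import math
from typing import Sequence


def iwae_bound(log_weights: Sequence[float]) -> float:
    """Return log((1/K) * sum_k exp(log_weights[k])) for K = len(log_weights).

    Subtracting the max before exponentiating avoids overflow and underflow.
    """
    k = len(log_weights)
    if k == 0:
        raise ValueError("need at least one log-weight")
    m = max(log_weights)
    return m + math.log(sum(math.exp(lw - m) for lw in log_weights)) - math.log(k)


# With K = 1 the bound is just the single-sample ELBO estimate log w_1.
print(iwae_bound([-2.0]))        # -2.0
# With K > 1 it is never below the average log-weight (Jensen's inequality).
print(iwae_bound([-1.0, -3.0]))  # about -1.566
```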
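The inverse time decay schedule mentioned under `--decay_rate` shrinks the learning rate as $\eta_t = \eta_0 / (1 + r\, t / T)$, where $r$ is the decay rate and $T$ is a decay-step constant. This is the formula used by `tf.keras.optimizers.schedules.InverseTimeDecay`; the sketch below reimplements it in plain Python. How the repository sets $T$ (`decay_steps` here) and whether it uses the staircase variant are assumptions, not taken from the README.

```python
import math


def inverse_time_decay(lr0: float, step: int, decay_rate: float,
                       decay_steps: int = 1, staircase: bool = False) -> float:
    """Learning rate lr0 / (1 + decay_rate * step / decay_steps).

    decay_steps and staircase are illustrative defaults, not the repo's settings.
    With staircase=True, step / decay_steps is floored, so the rate drops in steps.
    """
    t = step / decay_steps
    if staircase:
        t = math.floor(t)
    return lr0 / (1.0 + decay_rate * t)


for step in (0, 10, 100, 1000):
    print(step, inverse_time_decay(0.01, step, decay_rate=0.1))
```

Unlike exponential decay, this schedule decays polynomially (roughly as $1/t$), so the learning rate stays usable for long runs.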
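CIS-MSC, selected with `--cis` greater than 0, uses a conditional importance sampling (CIS) Markov kernel inside Markovian score climbing. One CIS step keeps the current state as one candidate, draws the remaining candidates from q, and resamples one candidate in proportion to its importance weight $p(x, z) / q(z)$; this kernel leaves the posterior invariant. The sketch below shows only that kernel on a toy one-dimensional target and omits the gradient step on q. It assumes `--cis` counts all candidates including the conditioned one; `cis_step` and the toy target and proposal are hypothetical.

```python
import math
import random
from typing import Callable


def cis_step(z_prev: float, num_samples: int,
             log_joint: Callable[[float], float],
             log_q: Callable[[float], float],
             sample_q: Callable[[random.Random], float],
             rng: random.Random) -> float:
    """One conditional importance sampling step targeting p(z | x).

    Candidate 0 is the previous state; the other num_samples - 1 are fresh
    draws from q. One candidate is returned with probability proportional
    to its importance weight p(x, z) / q(z).
    """
    candidates = [z_prev] + [sample_q(rng) for _ in range(num_samples - 1)]
    log_w = [log_joint(z) - log_q(z) for z in candidates]
    m = max(log_w)  # stabilize before exponentiating
    weights = [math.exp(lw - m) for lw in log_w]
    return rng.choices(candidates, weights=weights, k=1)[0]


# Toy example: unnormalized target N(0, 1), proposal q = N(1, 2^2).
def log_joint(z: float) -> float:
    return -0.5 * z * z


def log_q(z: float) -> float:
    # Normalizing constants are dropped; they cancel in the resampling step.
    return -0.5 * ((z - 1.0) / 2.0) ** 2


def sample_q(rng: random.Random) -> float:
    return rng.gauss(1.0, 2.0)


rng = random.Random(0)
z, chain = 5.0, []
for _ in range(5000):
    z = cis_step(z, num_samples=10, log_joint=log_joint, log_q=log_q,
                 sample_q=sample_q, rng=rng)
    chain.append(z)
print(sum(chain) / len(chain))  # should be close to the target mean, 0
```

Keeping the previous state among the candidates is what makes this an MCMC kernel rather than plain importance sampling, which is the property score climbing relies on.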