{"id":13869993,"url":"https://github.com/didriknielsen/survae_flows","last_synced_at":"2025-07-15T20:30:58.134Z","repository":{"id":43797172,"uuid":"276901631","full_name":"didriknielsen/survae_flows","owner":"didriknielsen","description":"Code for paper \"SurVAE Flows: Surjections to Bridge the Gap between VAEs and Flows\"","archived":false,"fork":false,"pushed_at":"2021-02-01T10:33:24.000Z","size":4219,"stargazers_count":283,"open_issues_count":9,"forks_count":34,"subscribers_count":28,"default_branch":"master","last_synced_at":"2024-08-06T21:22:55.493Z","etag":null,"topics":["flows","survae-flows"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/didriknielsen.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2020-07-03T13:04:08.000Z","updated_at":"2024-06-25T14:38:52.000Z","dependencies_parsed_at":"2022-09-24T09:40:49.553Z","dependency_job_id":null,"html_url":"https://github.com/didriknielsen/survae_flows","commit_stats":null,"previous_names":[],"tags_count":1,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/didriknielsen%2Fsurvae_flows","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/didriknielsen%2Fsurvae_flows/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/didriknielsen%2Fsurvae_flows/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/didriknielsen%2Fsurvae_flows/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/didriknielsen","download_url":"https://codeload.github.com/didr
iknielsen/survae_flows/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":226068132,"owners_count":17568703,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["flows","survae-flows"],"created_at":"2024-08-05T20:01:24.599Z","updated_at":"2024-11-23T16:30:56.924Z","avatar_url":"https://github.com/didriknielsen.png","language":"Python","funding_links":[],"categories":["Python"],"sub_categories":[],"readme":"# SurVAE Flows\n\n\u003e Official code for [SurVAE Flows: Surjections to Bridge the Gap between VAEs and Flows](https://arxiv.org/abs/2007.02731)  \nby Didrik Nielsen, Priyank Jaini, Emiel Hoogeboom, Ole Winther, Max Welling.\n\nSurVAE Flows is a framework of composable transformations that extends the framework of normalizing flows.  \nSurVAE Flows make use of not only **bijective** transformations, but also **surjective** and **stochastic** transformations.  \nFor more details, see the [paper](https://arxiv.org/abs/2007.02731) or check out this [talk](https://www.youtube.com/watch?v=bXp8fk4MRXQ) by Max Welling.\n\n\u003cimg src=\"assets/illustrations/transforms_fig.png\" width=\"800\"\u003e  \n\n\u003c!-- Composable building blocks of SurVAE flows include:  \n* **Bijective:** Invertible deterministic transformations. The usual building blocks in normalizing flows.\n* **Stochastic:** Stochastic transformations with stochastic inverses. 
VAEs are an important example.\n* **Surjective (Gen.):** Deterministic transformations in the generative direction with a stochastic right-inverse in the inference direction.\n* **Surjective (Inf.):** Deterministic transformations in the inference direction with a stochastic right-inverse in the generative direction.\n\n\nFor more details, see [the paper](https://arxiv.org/abs/2007.02731) or check out this talk by Max Welling:  \n[![Talk](https://img.youtube.com/vi/bXp8fk4MRXQ/0.jpg)](https://www.youtube.com/watch?v=bXp8fk4MRXQ) --\u003e\n\n## Contents\n\n* `/survae/`: Code for the SurVAE library. See description below.\n* `/examples/`: Runnable examples using the SurVAE library.\n* `/experiments/`: Code to reproduce the experiments in the paper.\n\n**Pretrained models** can be downloaded from [releases](https://github.com/didriknielsen/survae_flows/releases/tag/v1.0.0).\n\n## The SurVAE Library\n\n\nThe SurVAE library is a Python package, built on top of [PyTorch](https://pytorch.org/).  \nThe SurVAE library allows straightforward construction of SurVAE flows.\n\n#### Installation\n\nIn the folder containing `setup.py`, run\n```\npip install .\n```\n\n#### Example 1: Normalizing Flow\n\nWe can construct a simple *normalizing flow* by stacking **bijective transformations**.  
\nIn this case, we model 2d data using a flow of 4 affine coupling layers.\n\n```python\nimport torch.nn as nn\nfrom survae.flows import Flow\nfrom survae.distributions import StandardNormal\nfrom survae.transforms import AffineCouplingBijection, ActNormBijection, Reverse\nfrom survae.nn.layers import ElementwiseParams\n\ndef net():\n  return nn.Sequential(nn.Linear(1, 200), nn.ReLU(),\n                       nn.Linear(200, 100), nn.ReLU(),\n                       nn.Linear(100, 2), ElementwiseParams(2))\n\nmodel = Flow(base_dist=StandardNormal((2,)),\n             transforms=[\n               AffineCouplingBijection(net()), ActNormBijection(2), Reverse(2),\n               AffineCouplingBijection(net()), ActNormBijection(2), Reverse(2),\n               AffineCouplingBijection(net()), ActNormBijection(2), Reverse(2),\n               AffineCouplingBijection(net()), ActNormBijection(2),\n             ])\n```\nSee [here](https://github.com/didriknielsen/survae_flows/blob/master/examples/toy_flow.py) for a runnable example.\n\n#### Example 2: VAE\n\nWe can further build *VAEs* using **stochastic transformations**.  \nWe here construct a simple VAE for binary images of shape (1,28,28), such as binarized MNIST.  \nWe can easily extend this simple VAE by adding more layers to obtain e.g. hierarchical VAEs or VAEs with flow priors.  
\nWe can also use conditional flows in the encoder and/or decoder to obtain a more expressive VAE transformation.\n\n```python\nfrom survae.flows import Flow\nfrom survae.transforms import VAE\nfrom survae.distributions import StandardNormal, ConditionalNormal, ConditionalBernoulli\nfrom survae.nn.nets import MLP\n\n# Latent dimensionality. The value 20 is illustrative only, not prescribed by the library.\nlatent_size = 20\n\n# Encoder: flattens the image, rescales binary pixels from {0,1} to {-1,1},\n# and outputs 2*latent_size parameters for the conditional normal.\nencoder = ConditionalNormal(MLP(784, 2*latent_size,\n                                hidden_units=[512,256],\n                                activation='relu',\n                                in_lambda=lambda x: 2 * x.view(x.shape[0], 784).float() - 1))\n# Decoder: maps a latent vector to 784 Bernoulli parameters, reshaped to (1,28,28).\ndecoder = ConditionalBernoulli(MLP(latent_size, 784,\n                                   hidden_units=[512,256],\n                                   activation='relu',\n                                   out_lambda=lambda x: x.view(x.shape[0], 1, 28, 28)))\n\nmodel = Flow(base_dist=StandardNormal((latent_size,)),\n             transforms=[\n                VAE(encoder=encoder, decoder=decoder)\n             ])\n```\nSee [here](https://github.com/didriknielsen/survae_flows/blob/master/examples/mnist_vae.py) for a runnable example.\n\n#### Example 3: Multi-Scale Augmented Flow\n\nWe can implement e.g. dequantization, augmentation and multi-scale flows using **surjective transformations**.  \nHere, we use these layers in a *multi-scale augmented flow* for (3,32,32) images such as CIFAR-10.  \n\nNotice that this makes use of 3 types of surjective layers:\n1. **Generative rounding:** Implemented using `UniformDequantization`. Allows conversion to continuous variables. Useful for training continuous flows on ordinal discrete data.\n1. **Generative slicing:** Implemented using `Augment`. Allows increasing dimensionality towards the latent space. Useful for constructing augmented normalizing flows.\n1. **Inference slicing:** Implemented using `Slice`. Allows decreasing dimensionality towards the latent space. 
Useful for constructing multi-scale architectures.\n\n\n\n```python\nimport torch.nn as nn\nfrom survae.flows import Flow\nfrom survae.distributions import StandardNormal, StandardUniform\nfrom survae.transforms import AffineCouplingBijection, ActNormBijection2d, Conv1x1\nfrom survae.transforms import UniformDequantization, Augment, Squeeze2d, Slice\nfrom survae.nn.layers import ElementwiseParams2d\nfrom survae.nn.nets import DenseNet\n\ndef net(channels):\n  return nn.Sequential(DenseNet(in_channels=channels//2,\n                                out_channels=channels,\n                                num_blocks=1,\n                                mid_channels=64,\n                                depth=8,\n                                growth=16,\n                                dropout=0.0,\n                                gated_conv=True,\n                                zero_init=True),\n                        ElementwiseParams2d(2))\n\nmodel = Flow(base_dist=StandardNormal((24,8,8)),\n             transforms=[\n               UniformDequantization(num_bits=8),\n               Augment(StandardUniform((3,32,32)), x_size=3),\n               AffineCouplingBijection(net(6)), ActNormBijection2d(6), Conv1x1(6),\n               AffineCouplingBijection(net(6)), ActNormBijection2d(6), Conv1x1(6),\n               AffineCouplingBijection(net(6)), ActNormBijection2d(6), Conv1x1(6),\n               AffineCouplingBijection(net(6)), ActNormBijection2d(6), Conv1x1(6),\n               Squeeze2d(), Slice(StandardNormal((12,16,16)), num_keep=12),\n               AffineCouplingBijection(net(12)), ActNormBijection2d(12), Conv1x1(12),\n               AffineCouplingBijection(net(12)), ActNormBijection2d(12), Conv1x1(12),\n               AffineCouplingBijection(net(12)), ActNormBijection2d(12), Conv1x1(12),\n               AffineCouplingBijection(net(12)), ActNormBijection2d(12), Conv1x1(12),\n               Squeeze2d(), Slice(StandardNormal((24,8,8)), num_keep=24),\n               
AffineCouplingBijection(net(24)), ActNormBijection2d(24), Conv1x1(24),\n               AffineCouplingBijection(net(24)), ActNormBijection2d(24), Conv1x1(24),\n               AffineCouplingBijection(net(24)), ActNormBijection2d(24), Conv1x1(24),\n               AffineCouplingBijection(net(24)), ActNormBijection2d(24), Conv1x1(24),\n             ])\n```\nSee [here](https://github.com/didriknielsen/survae_flows/blob/master/examples/cifar10_aug_flow.py) for a runnable example.\n\n\n#### Acknowledgements\n\nThis code base builds on several other repositories. The biggest sources of inspiration are:\n\n* https://github.com/bayesiains/nsf\n* https://github.com/pclucas14/pytorch-glow\n* https://github.com/karpathy/pytorch-made\n\nThanks to the authors of these and the many other useful repositories!\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdidriknielsen%2Fsurvae_flows","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fdidriknielsen%2Fsurvae_flows","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdidriknielsen%2Fsurvae_flows/lists"}