{"id":13784731,"url":"https://github.com/ChrisWaites/jax-flows","last_synced_at":"2025-05-11T20:31:10.455Z","repository":{"id":40960260,"uuid":"249319428","full_name":"ChrisWaites/jax-flows","owner":"ChrisWaites","description":"Normalizing Flows in JAX 🌊","archived":false,"fork":false,"pushed_at":"2023-06-18T17:13:40.000Z","size":6013,"stargazers_count":283,"open_issues_count":10,"forks_count":19,"subscribers_count":7,"default_branch":"master","last_synced_at":"2025-04-27T01:16:47.821Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/ChrisWaites.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":".github/CONTRIBUTING.md","funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null}},"created_at":"2020-03-23T02:42:43.000Z","updated_at":"2025-04-20T13:43:57.000Z","dependencies_parsed_at":"2024-01-17T04:19:47.144Z","dependency_job_id":null,"html_url":"https://github.com/ChrisWaites/jax-flows","commit_stats":{"total_commits":44,"total_committers":8,"mean_commits":5.5,"dds":0.6136363636363636,"last_synced_commit":"26dce814478c656b2ed7e3295ec17b09cad200ee"},"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ChrisWaites%2Fjax-flows","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ChrisWaites%2Fjax-flows/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ChrisWaites%2Fjax-flows/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ChrisWaites%2Fjax-flows/manifests","owner_url":"https://repos.ecosyste
.ms/api/v1/hosts/GitHub/owners/ChrisWaites","download_url":"https://codeload.github.com/ChrisWaites/jax-flows/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":253631926,"owners_count":21939368,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-08-03T19:00:51.808Z","updated_at":"2025-05-11T20:31:09.334Z","avatar_url":"https://github.com/ChrisWaites.png","language":"Python","funding_links":[],"categories":["Libraries","Python","📦 Packages \u003csmall\u003e(15)\u003c/small\u003e"],"sub_categories":["Inactive Libraries","\u003cimg src=\"assets/jax.svg\" alt=\"JAX\" height=\"20px\"\u003e \u0026nbsp;JAX Packages","New Libraries"],"readme":"\u003cimg align=\"right\" width=\"300\" src=\"assets/flows.gif\"\u003e\n\n# Normalizing Flows in JAX\n\n\u003c!--\n\u003ca href=\"https://circleci.com/gh/ChrisWaites/jax-flows\"\u003e\n    \u003cimg alt=\"Build\" src=\"https://img.shields.io/circleci/build/github/ChrisWaites/jax-flows/master\"\u003e\n\u003c/a\u003e\n--\u003e\n\u003ca href=\"https://github.com/ChrisWaites/jax-flows/blob/master/LICENSE\"\u003e\n    \u003cimg alt=\"GitHub\" src=\"https://img.shields.io/github/license/ChrisWaites/jax-flows.svg?color=blue\"\u003e\n\u003c/a\u003e\n\u003ca href=\"https://jax-flows.readthedocs.io/en/latest/\"\u003e\n    \u003cimg alt=\"Documentation\" src=\"https://img.shields.io/website/http/jax-flows.readthedocs.io.svg?down_color=red\u0026down_message=offline\u0026up_message=online\"\u003e\n\u003c/a\u003e\n\n\u003cp\u003eImplementations of normalizing flows (RealNVP, 
Glow, MAF) in the \u003ca href=\"https://github.com/google/jax/\"\u003eJAX\u003c/a\u003e deep learning framework.\u003c/p\u003e\n\n## What are normalizing flows?\n\n[Normalizing flow models](http://akosiorek.github.io/ml/2018/04/03/norm_flows.html) are _generative models_, i.e. they infer the underlying probability distribution of an observed dataset. With that distribution we can do a number of interesting things, such as sampling new realistic points and querying probability densities.\n\n## Why JAX?\n\nA few reasons!\n\n1) JAX encourages a functional style. When writing a layer, I didn't want people to worry about PyTorch or TensorFlow boilerplate and how their code has to fit into \"the system\" (e.g., do I have to keep track of `self.training` here?). _All_ you have to worry about is writing a vanilla Python function which, given an ndarray, returns the correct set of outputs. You could develop your own layers with effectively no knowledge of the encompassing framework.\n\n2) JAX's [random number generation system](https://github.com/google/jax/blob/master/design_notes/prng.md) places reproducibility first. To see why this matters: once you parallelize a system, centralized, stateful PRNG models à la `torch.manual_seed()` or `tf.random.set_seed()` can start to yield inconsistent results. Given that randomness is so central to work in this area, I thought that uncompromising reproducibility would be a nice feature.\n\n3) JAX has a really flexible automatic differentiation system. So flexible, in fact, that you can (basically) write arbitrary Python functions (including for loops, if statements, etc.) and automatically compute their Jacobian with a call to `jax.jacfwd`. So, in theory, you could write a normalizing flow layer and automatically compute the log determinant of its Jacobian without having to derive it by hand (although we're not quite there yet).\n\n## How do things work?\n\nHere's an introduction! 
But for a more comprehensive description, check out the [documentation](https://jax-flows.readthedocs.io/).\n\n### Bijections\n\nA `bijection` is a parameterized invertible function.\n\n```python\nimport flows\nimport jax.numpy as np\nfrom jax import random\n\n# Example PRNG key and a batch of 10 five-dimensional inputs\nrng, input_rng = random.split(random.PRNGKey(0))\ninputs = random.normal(input_rng, (10, 5))\n\ninit_fun = flows.InvertibleLinear()\n\nparams, direct_fun, inverse_fun = init_fun(rng, input_dim=5)\n\n# Transform inputs\ntransformed_inputs, log_det_jacobian_direct = direct_fun(params, inputs)\n\n# Reconstruct original inputs\nreconstructed_inputs, log_det_jacobian_inverse = inverse_fun(params, transformed_inputs)\n\n# The round trip is exact only up to floating-point error, so compare approximately\nassert np.allclose(inputs, reconstructed_inputs, atol=1e-5)\n```\n\nWe can construct a sequence of bijections using `flows.Serial`. The result is just another bijection and adheres to exactly the same interface.\n\n```python\ninit_fun = flows.Serial(\n    flows.AffineCoupling(transformation),\n    flows.InvertibleLinear(),\n    flows.ActNorm(),\n)\n\nparams, direct_fun, inverse_fun = init_fun(rng, input_dim=5)\n```\n\n### Distributions\n\nA `distribution` is characterized by its parameters, a function that queries the probability density, and a sampling function.\n\n```python\ninit_fun = flows.Normal()\n\nparams, log_pdf, sample = init_fun(rng, input_dim=5)\n\n# Query the log probability density of points\nlog_pdfs = log_pdf(params, inputs)\n\n# Draw new points\nsamples = sample(rng, params, num_samples)\n```\n\n### Normalizing Flow Models\n\nUnder this definition, a normalizing flow model is just a `distribution`. To construct one, we give `flows.Flow` a `bijection` and another `distribution` to act as the prior.\n\n```python\nbijection = flows.Serial(\n    flows.AffineCoupling(transformation),\n    flows.InvertibleLinear(),\n    flows.ActNorm(),\n)\n\nprior = flows.Normal()\n\ninit_fun = flows.Flow(bijection, prior)\n\nparams, log_pdf, sample = init_fun(rng, input_dim=5)\n```\n\n### How do I train a model?\n\nThe same as you always would in JAX! 
First, set up an optimizer, then define a loss function (the negative log-likelihood of a batch) and a parameter update step.\n\n```python\nfrom jax import grad, jit\nfrom jax.example_libraries import optimizers\n\n# Any JAX optimizer works here; Adam with this step size is only an example choice\nopt_init, opt_update, get_params = optimizers.adam(step_size=1e-4)\nopt_state = opt_init(params)\n\ndef loss(params, inputs):\n    return -log_pdf(params, inputs).mean()\n\n@jit\ndef step(i, opt_state, inputs):\n    params = get_params(opt_state)\n    gradient = grad(loss)(params, inputs)\n    return opt_update(i, gradient, opt_state)\n```\n\nThen execute a standard JAX training loop.\n\n```python\nimport itertools\nimport numpy.random as npr\n\n# X holds the training data as a NumPy array of shape (num_points, input_dim)\nbatch_size, itercount = 32, itertools.count()\n\nfor epoch in range(num_epochs):\n    npr.shuffle(X)\n    for batch_index in range(0, X.shape[0], batch_size):\n        opt_state = step(\n            next(itercount),\n            opt_state,\n            X[batch_index:batch_index+batch_size]\n        )\n\noptimized_params = get_params(opt_state)\n```\n\nNow that we have trained model parameters, we can query densities and draw samples as before.\n\n```python\nlog_pdfs = log_pdf(optimized_params, inputs)\n\nsamples = sample(rng, optimized_params, num_samples)\n```\n\n_Magic!_\n\n## Interested in contributing?\n\nYay! Check out our [contributing guidelines](https://github.com/ChrisWaites/jax-flows/blob/master/.github/CONTRIBUTING.md).\n\n## Inspiration\n\nThis repository is largely modeled after the [`pytorch-flows`](https://github.com/ikostrikov/pytorch-flows) repository by [Ilya Kostrikov](https://github.com/ikostrikov), the [`nf-jax`](https://github.com/ericjang/nf-jax) repository by [Eric Jang](http://evjang.com/), and the [`normalizing-flows`](https://github.com/tonyduan/normalizing-flows) repository by [Tony Duan](https://github.com/tonyduan).\n\nThe implementations are based on the following papers:\n\n  \u003e [NICE: Non-linear Independent Components Estimation](https://arxiv.org/abs/1410.8516)\\\n  \u003e Laurent Dinh, David Krueger, Yoshua Bengio\\\n  \u003e _arXiv:1410.8516_\n\n  \u003e [Density estimation using Real NVP](https://arxiv.org/abs/1605.08803)\\\n  \u003e Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio\\\n  \u003e _arXiv:1605.08803_\n\n  \u003e [Improving 
Variational Inference with Inverse Autoregressive Flow](https://arxiv.org/abs/1606.04934)\\\n  \u003e Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling\\\n  \u003e _arXiv:1606.04934_\n\n  \u003e [Glow: Generative Flow with Invertible 1x1 Convolutions](https://arxiv.org/abs/1807.03039)\\\n  \u003e Diederik P. Kingma, Prafulla Dhariwal\\\n  \u003e _arXiv:1807.03039_\n\n  \u003e [Flow++: Improving Flow-Based Generative Models with Variational Dequantization and Architecture Design](https://openreview.net/forum?id=Hyg74h05tX)\\\n  \u003e Jonathan Ho, Xi Chen, Aravind Srinivas, Yan Duan, Pieter Abbeel\\\n  \u003e _OpenReview:Hyg74h05tX_\n\n  \u003e [Masked Autoregressive Flow for Density Estimation](https://arxiv.org/abs/1705.07057)\\\n  \u003e George Papamakarios, Theo Pavlakou, Iain Murray\\\n  \u003e _arXiv:1705.07057_\n\n  \u003e [Neural Spline Flows](https://arxiv.org/abs/1906.04032)\\\n  \u003e Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios\\\n  \u003e _arXiv:1906.04032_\n\nAnd, by association, the following surveys:\n\n  \u003e [Normalizing Flows: An Introduction and Review of Current Methods](https://arxiv.org/abs/1908.09257)\\\n  \u003e Ivan Kobyzev, Simon Prince, Marcus A. Brubaker\\\n  \u003e _arXiv:1908.09257_\n\n  \u003e [Normalizing Flows for Probabilistic Modeling and Inference](https://arxiv.org/abs/1912.02762)\\\n  \u003e George Papamakarios, Eric Nalisnick, Danilo Jimenez Rezende, Shakir Mohamed, Balaji Lakshminarayanan\\\n  \u003e _arXiv:1912.02762_\n\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FChrisWaites%2Fjax-flows","html_url":"https://awesome.ecosyste.ms/projects/github.com%2FChrisWaites%2Fjax-flows","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FChrisWaites%2Fjax-flows/lists"}