https://github.com/JiaYaobo/fenbux
A Simple Statistical Distribution Library in JAX
- Host: GitHub
- URL: https://github.com/JiaYaobo/fenbux
- Owner: JiaYaobo
- License: apache-2.0
- Created: 2023-07-05T15:14:52.000Z (almost 3 years ago)
- Default Branch: main
- Last Pushed: 2024-03-30T07:10:41.000Z (about 2 years ago)
- Last Synced: 2024-11-12T12:48:39.764Z (over 1 year ago)
- Topics: jax, probabilistic-programming, statistical-learning
- Language: Python
- Homepage: https://jiayaobo.github.io/fenbux/
- Size: 800 KB
- Stars: 16
- Watchers: 2
- Forks: 0
- Open Issues: 5
Metadata Files:
- Readme: readme.md
- License: LICENSE
# FenbuX
*A Simple Probabilistic Distribution Library in JAX*
FenbuX (*fenbu*, 分布, "distribution" in Chinese, pronounced /fen'bu:/) is a simple probabilistic distribution library in JAX. It provides:
* A simple, easy-to-use interface in the style of **Distributions.jl**
* Bijectors in the style of **TensorFlow Probability** and **Bijectors.jl**
* PyTree input/output
* Multiple dispatch for different distributions based on [plum-dispatch](https://github.com/beartype/plum)
* Support for JAX transformations (`vmap`, `pmap`, `jit`, `grad`, etc.)
See the [documentation](https://jiayaobo.github.io/fenbux/) for details.
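In fenbux, statistics are free functions such as `mean(dist)` and `variance(dist)` that select an implementation from the distribution's type. The sketch below illustrates that design with the standard library's `functools.singledispatch`, which dispatches on the first argument only (plum-dispatch, which fenbux uses, dispatches on all arguments). The `Normal` and `Exponential` classes here are hypothetical stand-ins, not fenbux's classes.

```python
from functools import singledispatch


# Hypothetical stand-in distribution types (not fenbux's classes).
class Normal:
    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma


class Exponential:
    def __init__(self, rate):
        self.rate = rate


@singledispatch
def mean(dist):
    # Fallback used when no implementation is registered for this type.
    raise NotImplementedError(f"mean is not defined for {type(dist).__name__}")


@mean.register
def _(dist: Normal):
    return dist.mu


@mean.register
def _(dist: Exponential):
    return 1.0 / dist.rate


print(mean(Normal(1.0, 2.0)))  # 1.0
print(mean(Exponential(4.0)))  # 0.25
```

Adding a new distribution then means registering new implementations rather than editing a class hierarchy, which is the main appeal of the dispatch-based interface.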
## Examples
### Statistics of Distributions 🤔
```python
import jax.numpy as jnp
from fenbux import variance, skewness, mean
from fenbux.univariate import Normal
μ = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])}
σ = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}
dist = Normal(μ, σ)
mean(dist) # {'a': Array([1., 2., 3.], dtype=float32), 'b': Array([4., 5., 6.], dtype=float32)}
variance(dist) # {'a': Array([16., 25., 36.], dtype=float32), 'b': Array([49., 64., 81.], dtype=float32)}
skewness(dist) # {'a': Array([0., 0., 0.], dtype=float32), 'b': Array([0., 0., 0.], dtype=float32)}
```
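For a Normal these statistics have closed forms (mean μ, variance σ², skewness 0), and with pytree parameters they apply leaf by leaf. A minimal plain-JAX sketch that reproduces the commented variance and skewness outputs above with `jax.tree_util.tree_map`; it only mirrors the results and is not how fenbux computes them internally.

```python
import jax
import jax.numpy as jnp

μ = {'a': jnp.array([1., 2., 3.]), 'b': jnp.array([4., 5., 6.])}
σ = {'a': jnp.array([4., 5., 6.]), 'b': jnp.array([7., 8., 9.])}

# Closed-form Normal statistics, applied independently to every leaf of the pytree.
variance_tree = jax.tree_util.tree_map(lambda s: s ** 2, σ)   # Var = σ²
skewness_tree = jax.tree_util.tree_map(jnp.zeros_like, σ)     # a Normal is symmetric

print(variance_tree)
print(skewness_tree)
```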
### Random Variable Generation
```python
import jax.random as jr
from fenbux import rand
from fenbux.univariate import Normal
key = jr.PRNGKey(0)
x = {'a': {'c': {'d': {'e': 1.}}}}
y = {'a': {'c': {'d': {'e': 1.}}}}
dist = Normal(x, y)
rand(dist, key, shape=(3, )) # {'a': {'c': {'d': {'e': Array([1.6248107 , 0.69599575, 0.10169095], dtype=float32)}}}}
```
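A common way to sample a Normal is the location-scale form x = μ + σz with z ~ N(0, 1). The plain-JAX sketch below applies that idea over a pytree and splits the key so each leaf gets independent noise. `sample_normal` is a hypothetical helper; fenbux's own key handling is not described in this README, so its draws will generally differ from the output above.

```python
import jax
import jax.random as jr


def sample_normal(key, loc, scale, shape):
    """Hypothetical helper: draw loc + scale * z, z ~ N(0, 1), for every leaf of a pytree."""
    loc_leaves, treedef = jax.tree_util.tree_flatten(loc)
    # Assumes `scale` has the same pytree structure as `loc`.
    scale_leaves = jax.tree_util.tree_leaves(scale)
    # One independent key per leaf, so leaves do not share the same noise.
    keys = jr.split(key, len(loc_leaves))
    draws = [m + s * jr.normal(k, shape) for k, m, s in zip(keys, loc_leaves, scale_leaves)]
    return jax.tree_util.tree_unflatten(treedef, draws)


key = jr.PRNGKey(0)
loc = {'a': {'c': {'d': {'e': 1.}}}}
scale = {'a': {'c': {'d': {'e': 1.}}}}
samples = sample_normal(key, loc, scale, shape=(3,))
print(samples)
```

Reusing the same key reproduces the same draws, which is what makes JAX sampling deterministic and safe under `jit`.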
### Evaluations of Distributions 👩‍🎓
CDF, PDF, and more...
```python
import jax.numpy as jnp
from fenbux import cdf, logpdf
from fenbux.univariate import Normal
μ = jnp.array([1., 2., 3.])
σ = jnp.array([4., 5., 6.])
dist = Normal(μ, σ)
cdf(dist, jnp.array([1., 2., 3.])) # Array([0.5, 0.5, 0.5], dtype=float32)
logpdf(dist, jnp.array([1., 2., 3.])) # Array([-2.305233 , -2.5283763, -2.7106981], dtype=float32)
```
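The values above follow from the closed forms: at x = μ the Normal CDF is exactly 0.5, and the log-density reduces to -log σ - ½ log(2π). A standard-library sketch of both formulas, useful as an independent sanity check on fenbux's output (float32 results agree to about six digits):

```python
import math


def normal_logpdf(x, mu, sigma):
    # log N(x | mu, sigma) = -log(sigma) - 0.5*log(2*pi) - (x - mu)^2 / (2*sigma^2)
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z


def normal_cdf(x, mu, sigma):
    # Phi((x - mu) / sigma), written via the error function
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


for mu, sigma in [(1., 4.), (2., 5.), (3., 6.)]:
    print(normal_cdf(mu, mu, sigma), normal_logpdf(mu, mu, sigma))
```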
### Nested Transformations of Distributions 🤖
```python
import fenbux as fbx
import jax.numpy as jnp
from fenbux.univariate import Normal
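# truncate, then censor, then affine-transform the base distribution
d = Normal(0, 1)
d2 = fbx.affine(fbx.censor(fbx.truncate(d, 0, 1), 0, 1), 0, 1)
fbx.logpdf(d, 0.5)  # evaluates the base N(0, 1); this is the output shown below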
```
```
Array(-1.0439385, dtype=float32)
```
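For reference, truncating a density to [a, b] keeps its shape on that interval and renormalises by the retained mass: log p(x) - log(Φ(b) - Φ(a)) inside the interval, and -∞ outside. A standard-library sketch of that textbook formula for the Normal; it is not taken from fenbux's source and says nothing about how `censor` or `affine` are implemented.

```python
import math


def normal_logpdf(x, mu=0.0, sigma=1.0):
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z


def normal_cdf(x, mu=0.0, sigma=1.0):
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def truncated_normal_logpdf(x, lower, upper, mu=0.0, sigma=1.0):
    # Outside [lower, upper] the truncated density is zero.
    if x < lower or x > upper:
        return -math.inf
    # Inside, renormalise by the probability mass that was kept.
    mass = normal_cdf(upper, mu, sigma) - normal_cdf(lower, mu, sigma)
    return normal_logpdf(x, mu, sigma) - math.log(mass)


print(normal_logpdf(0.5))                      # base N(0, 1) at 0.5
print(truncated_normal_logpdf(0.5, 0.0, 1.0))  # N(0, 1) truncated to [0, 1]
```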
### Compatible with JAX transformations 😃
- vmap
```python
import jax.numpy as jnp
from jax import vmap
from fenbux import logpdf
from fenbux.univariate import Normal
dist = Normal({'a': jnp.zeros((2, 3))}, {'a': jnp.ones((2, 3, 5))})  # scale is mapped over axis 2, so each batch has shape (2, 3)
x = jnp.zeros((2, 3, 5))
# pass use_batch=True when building a Normal as the in_axes spec for vmap
vmap(logpdf, in_axes=(Normal(None, {'a': 2}, use_batch=True), 2))(dist, x)
```
- grad
```python
import jax.numpy as jnp
from jax import jit, grad
from fenbux import logpdf
from fenbux.univariate import Normal
dist = Normal(0., 1.)
grad(logpdf)(dist, 0.)
```
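Both examples rely on standard JAX transformation semantics. For comparison, here is the same batching pattern and the Normal log-density gradient written in plain JAX with `jax.scipy.stats.norm.logpdf` (a sketch, not fenbux code). In the fenbux `grad` example the gradient is taken with respect to the first argument, the distribution itself; since it is a pytree, the result presumably carries one gradient per parameter, which at x = μ = 0, σ = 1 should be 0 for μ and -1 for σ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# vmap: x and scale are mapped along axis 2, loc is shared across the batch (None).
loc = jnp.zeros((2, 3))
scale = jnp.ones((2, 3, 5))
x = jnp.zeros((2, 3, 5))
batched = jax.vmap(norm.logpdf, in_axes=(2, None, 2))(x, loc, scale)
print(batched.shape)  # (5, 2, 3): vmap moves the mapped axis to the front

# grad: differentiate log N(x | mu, sigma) w.r.t. mu and sigma (positional args 1 and 2).
# Analytically: d/dmu = (x - mu) / sigma**2, d/dsigma = -1/sigma + (x - mu)**2 / sigma**3.
dmu, dsigma = jax.grad(norm.logpdf, argnums=(1, 2))(0., 0., 1.)
print(dmu, dsigma)
```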
### Bijectors 🧙‍♂️
Evaluate a bijector
```python
import jax.numpy as jnp
from fenbux.bijector import Exp, evaluate
bij = Exp()
x = jnp.array([1., 2., 3.])
evaluate(bij, x)
```
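A bijector is an invertible, differentiable map paired with the log-determinant of its Jacobian, which is what allows it to transform densities. A minimal pure-Python sketch of the exponential case; `ExpBijector` and its method names are hypothetical, loosely following TensorFlow Probability's naming, and are not fenbux's API.

```python
import math


class ExpBijector:
    """Hypothetical minimal bijector: y = exp(x)."""

    def forward(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y)

    def forward_log_det_jacobian(self, x):
        # dy/dx = exp(x), so log|dy/dx| = x
        return x


bij = ExpBijector()
ys = [bij.forward(x) for x in [1., 2., 3.]]  # element-wise exp, as evaluate(bij, x) computes above
print(ys)
```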
Apply a bijector to a distribution
```python
import jax.numpy as jnp
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal
from fenbux import logpdf
dist = Normal(0, 1)
bij = Exp()
log_normal = transform(dist, bij)
x = jnp.array([1., 2., 3.])
logpdf(log_normal, x)
```
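Pushing N(0, 1) through `Exp` yields a log-normal, whose density follows from the change-of-variables formula log p_Y(y) = log p_X(log y) - log y. A standard-library sketch of that density, which `logpdf(log_normal, x)` above should agree with up to float32 rounding, assuming the bijector is applied as y = exp(x):

```python
import math


def normal_logpdf(x, mu=0.0, sigma=1.0):
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z


def lognormal_logpdf(y, mu=0.0, sigma=1.0):
    # Change of variables for y = exp(x):
    # log p_Y(y) = log p_X(log y) + log|d(log y)/dy| = log p_X(log y) - log y
    if y <= 0:
        return -math.inf
    return normal_logpdf(math.log(y), mu, sigma) - math.log(y)


print([lognormal_logpdf(y) for y in [1., 2., 3.]])
```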
### Speed 🔦
* Common Evaluations
```python
import numpy as np
from scipy.stats import norm
from jax import jit
from fenbux import logpdf, rand
from fenbux.univariate import Normal
from tensorflow_probability.substrates.jax.distributions import Normal as Normal2
dist = Normal(0, 1)
dist2 = Normal2(0, 1)
dist3 = norm(0, 1)
x = np.random.normal(size=100000)
%timeit jit(logpdf)(dist, x).block_until_ready()
%timeit jit(dist2.log_prob)(x).block_until_ready()
%timeit dist3.logpdf(x)
```
```
51.2 µs ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
11.1 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.12 ms ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
* Evaluations with Bijector Transformed Distributions
```python
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax import jit
from fenbux import logpdf
from fenbux.bijector import Exp, transform
from fenbux.univariate import Normal
x = jnp.asarray(np.random.uniform(size=100000))
dist = Normal(0, 1)
bij = Exp()
log_normal = transform(dist, bij)
dist2 = tfd.Normal(loc=0, scale=1)
bij2 = tfb.Exp()
log_normal2 = tfd.TransformedDistribution(dist2, bij2)
def log_prob(d, x):
    return d.log_prob(x)
%timeit jit(logpdf)(log_normal, x).block_until_ready()
%timeit jit(log_prob)(log_normal2, x).block_until_ready()
```
```
131 µs ± 514 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
375 µs ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
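The timings above use IPython's `%timeit`. Outside IPython, a rough harness might build the jitted function once, make a warm-up call so compilation is excluded, and call `block_until_ready()` so JAX's asynchronous dispatch is waited on. In the sketch below `jax.scipy.stats.norm.logpdf` stands in for the function under test; `bench` is a hypothetical helper, not the harness used to produce the numbers above.

```python
import time

import jax
from jax.scipy.stats import norm


def bench(fn, *args, repeats=100):
    """Hypothetical helper: average wall-clock seconds per call of a JAX function."""
    # Warm-up call: triggers compilation so it is not included in the timing.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the asynchronously dispatched computation to finish.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


x = jax.random.normal(jax.random.PRNGKey(0), (100_000,))
logpdf_jit = jax.jit(norm.logpdf)  # jitted once, outside the timed loop
print(f"{bench(logpdf_jit, x) * 1e6:.1f} µs per call")
```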
## Installation
* Install from source.
```bash
git clone https://github.com/JiaYaobo/fenbux.git
cd fenbux
pip install -e .
```
* Install from PyPI.
```bash
pip install -U fenbux
```
## Reference
* [Distributions.jl](https://github.com/JuliaStats/Distributions.jl)
* [Equinox](https://github.com/patrick-kidger/equinox)
## Citation
```bibtex
@software{fenbux,
author = {Jia, Yaobo},
title = {fenbux: A Simple Probabilistic Distribution Library in JAX},
url = {https://github.com/JiaYaobo/fenbux},
year = {2024}
}
```