https://github.com/lockwo/isax
Ising (like) models in JAX
- Host: GitHub
- URL: https://github.com/lockwo/isax
- Owner: lockwo
- License: apache-2.0
- Created: 2025-08-24T18:47:55.000Z
- Default Branch: main
- Last Pushed: 2025-11-29T21:54:48.000Z
- Last Synced: 2025-12-02T02:52:39.757Z
- Topics: gibbs-sampling, ising-model, jax
- Language: Python
- Homepage: https://lockwo.github.io/isax/
- Size: 1.12 MB
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- Funding: .github/FUNDING.yml
- License: LICENSE
README
# isax
isax is a [JAX](https://github.com/google/jax)-based library for sampling from Ising models using blocked Gibbs sampling. It supports hypergraphs, flexible sampling/modeling, and all the usual JAX transformations. isax is heavily inspired by [thrml](https://docs.thrml.ai/en/latest/) and [isingtorch](https://github.com/lockwo/isingtorch).
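For background on what blocked Gibbs sampling does, here is a minimal, self-contained sketch in plain JAX. It does not use isax's API and covers only a 2D periodic square lattice with a uniform coupling `J` and field `h`. It assumes the common convention `E(s) = -J * sum_<ij> s_i s_j - h * sum_i s_i` with `p(s) ∝ exp(-beta * E(s))`. Under that convention each spin's conditional is `P(s_i = +1 | rest) = sigmoid(2 * beta * (J * m_i + h))`, where `m_i` is the sum of its four neighbors. A checkerboard coloring makes all spins of one color conditionally independent given the other color, so each half can be resampled in parallel. The names `checkerboard_gibbs` and `neighbor_sum` are illustrative only.

```python
import jax
import jax.numpy as jnp


def neighbor_sum(spins):
    # Sum of the four nearest neighbors, with periodic (wrap-around) boundaries.
    return (
        jnp.roll(spins, 1, axis=0)
        + jnp.roll(spins, -1, axis=0)
        + jnp.roll(spins, 1, axis=1)
        + jnp.roll(spins, -1, axis=1)
    )


def checkerboard_gibbs(key, spins, n_sweeps, beta=1.0, J=1.0, h=0.0):
    """Blocked Gibbs sampling on an L x L periodic lattice (L must be even)."""
    L = spins.shape[0]
    x, y = jnp.meshgrid(jnp.arange(L), jnp.arange(L), indexing="ij")
    even_mask = (x + y) % 2 == 0  # the two color classes of the checkerboard

    def update_color(spins, key, mask):
        # Conditional probability of +1 given the neighbors, which all have the other color.
        p_up = jax.nn.sigmoid(2.0 * beta * (J * neighbor_sum(spins) + h))
        proposal = jnp.where(jax.random.uniform(key, spins.shape) < p_up, 1, -1)
        # Resample only the spins of the current color; keep the others fixed.
        return jnp.where(mask, proposal, spins)

    def sweep(_, carry):
        spins, key = carry
        key, k_even, k_odd = jax.random.split(key, 3)
        spins = update_color(spins, k_even, even_mask)
        spins = update_color(spins, k_odd, ~even_mask)
        return spins, key

    spins, _ = jax.lax.fori_loop(0, n_sweeps, sweep, (spins, key))
    return spins


# Example usage: 200 sweeps on an 8 x 8 lattice from a random start.
key = jax.random.key(0)
key, init_key = jax.random.split(key)
init = jax.random.choice(init_key, jnp.array([-1, 1]), (8, 8))
samples = checkerboard_gibbs(key, init, n_sweeps=200, beta=0.6)
print(jnp.abs(samples.mean()))  # magnitude of the magnetization per spin
```

The sketch updates a whole color class at once with `jnp.where` masks instead of gathering indices. This keeps every array at a fixed shape, so the loop traces cleanly under `jax.lax.fori_loop`.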
## Documentation
Available at https://lockwo.github.io/isax
## Installation
```bash
git clone https://github.com/lockwo/isax
cd isax
pip install -e .
```
Requires Python 3.10+.
## Quick Example
```python
import jax
import jax.numpy as jnp
from isax import BlockGraph, Edge, Node, IsingModel, IsingSampler, SamplingArgs, sample_chain

# Build an L x L square lattice with periodic boundaries (a torus).
L = 4  # must be even so the checkerboard coloring below stays valid across the wrap-around
nodes = [Node() for _ in range(L * L)]
edges = []
for x in range(L):
    for y in range(L):
        i = x * L + y
        edges.append(Edge(nodes[i], nodes[x * L + (y + 1) % L]))  # right neighbor
        edges.append(Edge(nodes[i], nodes[((x + 1) % L) * L + y]))  # down neighbor

# Checkerboard partition: no edge joins two nodes of the same color,
# so every node in a block can be updated in parallel.
even = [nodes[x * L + y] for x in range(L) for y in range(L) if (x + y) % 2 == 0]
odd = [nodes[x * L + y] for x in range(L) for y in range(L) if (x + y) % 2 == 1]
graph = BlockGraph([even, odd], edges)
params = graph.get_sampling_params()

# Unit weights on every edge and zero biases.
model = IsingModel(weights=jnp.ones(len(edges)), biases=jnp.zeros(L * L))
sampler = IsingSampler()
sampling_args = SamplingArgs(gibbs_steps=100, blocks_to_sample=[0, 1], data=params)

# Use an independent PRNG key for each random draw instead of reusing one key.
key = jax.random.key(0)
key, k_even, k_odd = jax.random.split(key, 3)
init_state = [
    jax.random.choice(k_even, jnp.array([-1, 1]), (len(even),)),
    jax.random.choice(k_odd, jnp.array([-1, 1]), (len(odd),)),
]
samples = sample_chain(init_state, [sampler, sampler], model, sampling_args, key)
```
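Blocked Gibbs sampling relies on each block being conditionally independent given the rest. For a pairwise model, that means no edge may connect two nodes in the same block. The plain-Python sketch below rebuilds the Quick Example's lattice as integer index pairs and checks the even/odd partition. The helpers `torus_edges`, `checkerboard_blocks`, and `is_valid_coloring` are hypothetical and independent of isax's `Node`/`Edge` objects. The sketch also shows why `L` must be even once the lattice wraps around.

```python
def torus_edges(L):
    """Index pairs for the same periodic square lattice as the Quick Example."""
    edges = []
    for x in range(L):
        for y in range(L):
            i = x * L + y
            edges.append((i, x * L + (y + 1) % L))  # right neighbor (wraps)
            edges.append((i, ((x + 1) % L) * L + y))  # down neighbor (wraps)
    return edges


def checkerboard_blocks(L):
    """Even and odd sites of the checkerboard, as sets of flat indices."""
    even = {x * L + y for x in range(L) for y in range(L) if (x + y) % 2 == 0}
    odd = {x * L + y for x in range(L) for y in range(L) if (x + y) % 2 == 1}
    return even, odd


def is_valid_coloring(blocks, edges):
    """True if no edge has both endpoints in the same block."""
    block_of = {node: b for b, block in enumerate(blocks) for node in block}
    return all(block_of[i] != block_of[j] for i, j in edges)


print(is_valid_coloring(checkerboard_blocks(4), torus_edges(4)))  # True
print(is_valid_coloring(checkerboard_blocks(5), torus_edges(5)))  # False: wrap edges join same-color sites
```

With odd `L`, the wrap-around edge from `(x, L - 1)` to `(x, 0)` connects two sites of the same parity. A two-block checkerboard is then no longer a valid partition, and more blocks would be needed.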
## Future Additions (TODO)
- [x] annealing
- [x] cleaner interface
- [ ] improve example documentation/math background
- [ ] add tests
- [ ] runtime sampling params
- [ ] generic block typing
- [ ] generalize pytree typing for states
- [ ] support non-Gibbs samplers (Wolff, Metropolis-Hastings, etc.)
- [ ] add some ML examples