https://github.com/marvinsxtr/jax-flow-matching
2D Flow Matching in JAX with equinox and diffrax
https://github.com/marvinsxtr/jax-flow-matching
diffrax educational equinox flow-matching jax
Last synced: 2 months ago
JSON representation
2D Flow Matching in JAX with equinox and diffrax
- Host: GitHub
- URL: https://github.com/marvinsxtr/jax-flow-matching
- Owner: marvinsxtr
- License: mit
- Created: 2025-08-01T20:32:17.000Z (4 months ago)
- Default Branch: main
- Last Pushed: 2025-08-01T21:52:01.000Z (4 months ago)
- Last Synced: 2025-08-12T08:48:27.658Z (4 months ago)
- Topics: diffrax, educational, equinox, flow-matching, jax
- Language: Jupyter Notebook
- Homepage:
- Size: 1.16 MB
- Stars: 2
- Watchers: 0
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# 2D Flow Matching in JAX

This notebook implements a simple 2D flow matching model using JAX, equinox and diffrax. It trains a neural network to learn the velocity field that transforms samples from a standard Gaussian distribution to a target checkerboard distribution.
[](https://colab.research.google.com/github/marvinsxtr/jax-flow-matching/blob/main/jax_flow_matching.ipynb)
## Overview
This notebook implements flow matching on a 2D checkerboard dataset using an affine probability path $x_t = (1-t) \cdot x_0 + t \cdot x_1$ with $x_0 \sim \mathcal{N}(0, I)$. An MLP learns the velocity field $v_\theta(x, t)$ for ODE-based sampling and likelihood estimation via Hutchinson's trace estimator.
## Dependencies
This notebook requires the following JAX ecosystem packages:
- **[JAX](https://github.com/google/jax)**: Core library for high-performance numerical computing
- **[Equinox](https://github.com/patrick-kidger/equinox)**: Neural network library built on JAX
- **[Diffrax](https://github.com/patrick-kidger/diffrax)**: Differential equation solver for JAX
- **[jaxtyping](https://github.com/google/jaxtyping)**: Type annotations for JAX arrays
- **[Optax](https://github.com/deepmind/optax)**: Gradient processing and optimization library
Additional standard packages: `numpy`, `matplotlib`
## Container Setup
This implementation can be run using Docker or Apptainer containers for reproducible results.
### Docker (Local Machine)
1. **Install VSCode Dev Containers Extension**
First, install the [Dev Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension in VSCode.
2. **Open the Repository in the Dev Container**
Click the `Reopen in Container` button in the pop-up that appears once you open the repository in VSCode.
Alternatively, open the command palette in VSCode by pressing `Shift+Alt+P` (Windows/Linux) or `Shift+Cmd+P` (Mac), and type `Dev Containers: Reopen in Container`.
### Apptainer (Cluster)
1. **Install VSCode Remote Tunnels Extension**
First, install the [Remote Tunnels](https://marketplace.visualstudio.com/items?itemName=ms-vscode.remote-server) extension in VSCode.
2. **Launch container**
To open a tunnel to connect your local VSCode to the container on the cluster:
```bash
apptainer run --nv --writable-tmpfs oras://ghcr.io/marvinsxtr/jax-flow-matching:latest-sif code tunnel
```
In VSCode press `Shift+Alt+P` (Windows/Linux) or `Shift+Cmd+P` (Mac), type "connect to tunnel", select GitHub and select your named node on the cluster. Your IDE is now connected to the cluster.
## Attribution
This implementation is based on concepts from the following sources:
- [Flow Matching Guide and Code](https://arxiv.org/abs/2412.06264)
- [Reference PyTorch Implementation](https://github.com/facebookresearch/flow_matching/blob/main/examples/2d_flow_matching.ipynb)
- [Flow Matching for Generative Modeling](https://arxiv.org/abs/2210.02747)