https://github.com/dirmeier/wgan-gp
A Wasserstein GAN with gradient penalty in Flax/NNX
- Host: GitHub
- URL: https://github.com/dirmeier/wgan-gp
- Owner: dirmeier
- License: apache-2.0
- Created: 2024-12-19T11:53:12.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2025-08-21T19:27:55.000Z (10 months ago)
- Last Synced: 2026-06-21T01:31:11.342Z (3 days ago)
- Topics: flax, jax, python, wgan-gp
- Language: Python
- Homepage:
- Size: 545 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# WGAN-GP
[CI](https://github.com/dirmeier/wgan/actions/workflows/ci.yaml)
## About
This repository implements the [Wasserstein GAN with gradient penalty](https://arxiv.org/abs/1704.00028) (WGAN-GP) loss for testing purposes.
The implementation uses JAX and Flax/NNX.
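For reference, the critic objective from the paper is

$$L = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big],$$

where $\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}$ with $\epsilon \sim U[0, 1]$, and the paper uses $\lambda = 10$. Below is a minimal sketch of this loss in plain JAX. It is not the repository's Flax/NNX code: the function names and the `critic` argument (assumed to be any callable mapping a batch to one score per example) are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def gradient_penalty(critic, real, fake, key):
    """Mean of (||grad_x D(x_hat)||_2 - 1)^2 over random interpolates x_hat."""
    # One interpolation coefficient per example, broadcast over feature dims.
    eps_shape = (real.shape[0],) + (1,) * (real.ndim - 1)
    eps = jax.random.uniform(key, eps_shape)
    x_hat = eps * real + (1.0 - eps) * fake

    # Differentiating the summed scores yields per-example input gradients,
    # assuming the critic scores examples independently (no batch norm).
    grads = jax.grad(lambda x: jnp.sum(critic(x)))(x_hat)
    flat = grads.reshape(grads.shape[0], -1)
    # Small constant keeps the sqrt differentiable at a zero gradient.
    norms = jnp.sqrt(jnp.sum(flat**2, axis=-1) + 1e-12)
    return jnp.mean((norms - 1.0) ** 2)


def critic_loss(critic, real, fake, key, lam=10.0):
    """WGAN-GP critic loss: E[D(fake)] - E[D(real)] + lam * penalty."""
    return (
        jnp.mean(critic(fake))
        - jnp.mean(critic(real))
        + lam * gradient_penalty(critic, real, fake, key)
    )


def generator_loss(critic, fake):
    """Generator loss: -E[D(fake)]."""
    return -jnp.mean(critic(fake))
```

As a quick sanity check, a linear critic `x @ w` has input gradient `w` everywhere, so its penalty is `(||w|| - 1)^2` regardless of the sampled interpolates; `w = [3, 4]` gives 16.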
## Example usage
An experiment where we train a WGAN-GP on MNIST can be found in [`experiments/mnist/`](experiments/mnist/).
To run the example, first download the latest release and install all dependencies via:
```bash
wget -qO- https://github.com/dirmeier/wgan-gp/archive/refs/tags/.tar.gz | tar zxvf -
uv sync --all-groups
```
To train a model and make visualizations, call:
```bash
cd experiments/mnist
python main.py
```
Below are the results from training the GAN with the hyperparameters defined in [`experiments/mnist/config.py`](experiments/mnist/config.py).
A sample after 20k training steps (i.e., gradient steps) is shown below.
## Installation
To install the latest GitHub release, call the following on the
command line:
```bash
pip install git+https://github.com/dirmeier/wgan@
```
## Author
Simon Dirmeier simd23 @ pm dot me