https://github.com/dirmeier/wgan-gp

A Wasserstein GAN with gradient penalty in Flax/NNX

# WGAN-GP

[![ci](https://github.com/dirmeier/wgan-gp/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/wgan-gp/actions/workflows/ci.yaml)

## About

This repository implements the [Wasserstein GAN with gradient penalty](https://arxiv.org/abs/1704.00028) (WGAN-GP) loss for testing purposes.
The implementation uses JAX and Flax/NNX.
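
For readers new to the loss, here is a minimal sketch of the WGAN-GP objectives from the linked paper, written in plain JAX. It is not this repository's code. The names `gradient_penalty`, `critic_loss` and `generator_loss` are hypothetical. The critic is assumed to be any function that maps a batch to one score per example, for instance an NNX module's `__call__`. The default `lam=10.0` is the penalty weight used in the paper. The sketch takes per-example input gradients as the gradient of the summed scores, which is valid only if the critic has no batch-coupling layers such as batch normalization. The paper also avoids batch normalization in the critic.

```python
import jax
import jax.numpy as jnp


def gradient_penalty(critic, real, fake, key):
    """Hypothetical helper: E[(||grad_x D(x_hat)||_2 - 1)^2] on random interpolates."""
    batch = real.shape[0]
    # One interpolation coefficient eps ~ U[0, 1] per example, broadcast over other axes.
    eps = jax.random.uniform(key, (batch,) + (1,) * (real.ndim - 1))
    x_hat = eps * real + (1.0 - eps) * fake
    # Assumes the critic scores examples independently, so the gradient of the
    # summed scores w.r.t. x_hat yields each example's own input gradient.
    grads = jax.grad(lambda x: jnp.sum(critic(x)))(x_hat)
    # Flatten all non-batch axes (e.g. image pixels) before taking the L2 norm.
    norms = jnp.sqrt(jnp.sum(grads.reshape(batch, -1) ** 2, axis=-1) + 1e-12)
    return jnp.mean((norms - 1.0) ** 2)


def critic_loss(critic, real, fake, key, lam=10.0):
    # The critic maximises D(real) - D(fake); we minimise the negation plus the penalty.
    wasserstein = jnp.mean(critic(fake)) - jnp.mean(critic(real))
    return wasserstein + lam * gradient_penalty(critic, real, fake, key)


def generator_loss(critic, fake):
    # The generator tries to raise the critic's score on its own samples.
    return -jnp.mean(critic(fake))


if __name__ == "__main__":
    k_real, k_fake, k_eps = jax.random.split(jax.random.PRNGKey(0), 3)
    real = jax.random.normal(k_real, (8, 2))
    fake = jax.random.normal(k_fake, (8, 2))
    # A linear critic x @ w has input gradient w everywhere, so with ||w|| = 1
    # the penalty should be (numerically) close to zero.
    w = jnp.array([0.6, 0.8])
    print(gradient_penalty(lambda x: x @ w, real, fake, k_eps))
    print(critic_loss(lambda x: x @ w, real, fake, k_eps))
```

The linear critic in the usage block is only a sanity check: because its gradient norm is constant, the penalty depends on `||w||` alone. A real critic would be a neural network whose gradient norm varies across the interpolates.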

## Example usage

An experiment where we train a WGAN-GP on MNIST can be found in [`experiments/mnist/`](experiments/mnist/).
To run the example, first download the latest release and install all dependencies via:

```bash
wget -qO- https://github.com/dirmeier/wgan-gp/archive/refs/tags/.tar.gz | tar zxvf -
uv sync --all-groups
```

To train a model and make visualizations, call:

```bash
cd experiments/mnist
python main.py
```

Below are the results from training the GAN with the hyperparameters defined in [`experiments/mnist/config.py`](experiments/mnist/config.py).
The figure below shows a sample after 20k training steps (i.e., gradient steps).



## Installation

To install the latest GitHub release, call the following on the
command line:

```bash
pip install git+https://github.com/dirmeier/wgan-gp@
```

## Author

Simon Dirmeier simd23 @ pm dot me