https://github.com/matthias-wright/jax-fid
FID computation in Jax/Flax.
- Host: GitHub
- URL: https://github.com/matthias-wright/jax-fid
- Owner: matthias-wright
- License: apache-2.0
- Created: 2021-07-29T04:05:43.000Z
- Default Branch: main
- Last Pushed: 2024-07-17T07:19:14.000Z
- Last Synced: 2025-02-27T23:09:57.130Z
- Topics: fid, flax, frechet-inception-distance, jax
- Language: Python
- Homepage:
- Size: 55.9 MB
- Stars: 27
- Watchers: 3
- Forks: 5
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- awesome-jax - FID computation - Port of [mseitzer/pytorch-fid](https://github.com/mseitzer/pytorch-fid) to Flax. (Models and Projects / Flax)
README
# FID computation in Jax/Flax
This is a port of [mseitzer/pytorch-fid](https://github.com/mseitzer/pytorch-fid), which is a port of the original FID implementation ([bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR)).
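For reference, FID fits a Gaussian to the Inception activations of each image set and measures the Fréchet distance between the two Gaussians. With means μ1, μ2 and covariances Σ1, Σ2 of the two sets of activations, the standard definition is:

```math
\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}\left(\Sigma_1 + \Sigma_2 - 2\left(\Sigma_1 \Sigma_2\right)^{1/2}\right)
```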
The parameters for the [InceptionV3](https://arxiv.org/abs/1512.00567) network are taken from [mseitzer/pytorch-fid](https://github.com/mseitzer/pytorch-fid), and the resulting FID scores are almost identical to those it produces (absolute difference around 1e-7).
The only difference is that [mseitzer/pytorch-fid](https://github.com/mseitzer/pytorch-fid) [resizes](https://github.com/mseitzer/pytorch-fid/blob/d042ab8a9f8e4b388c21bc7b38d9599c5fbcfe7b/src/pytorch_fid/inception.py#L146) images to 299x299 by default, whereas this implementation does not resize them unless you pass the `--img_size` argument.
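To illustrate what such a resize involves, here is a minimal NumPy sketch of bilinear resizing with half-pixel centers. It assumes pytorch-fid's default resize is PyTorch's bilinear interpolation with `align_corners=False`, and it works on a single `(H, W, C)` image rather than an NCHW batch. The function names are hypothetical, and this is not necessarily how jax-fid implements `--img_size`.

```python
import numpy as np


def _source_coords(out_size: int, in_size: int):
    """Map output pixel indices to input coordinates using half-pixel centers."""
    scale = in_size / out_size
    src = (np.arange(out_size) + 0.5) * scale - 0.5
    src = np.clip(src, 0.0, None)  # negative coordinates are clamped to 0
    i0 = np.minimum(np.floor(src).astype(np.int64), in_size - 1)
    i1 = np.minimum(i0 + 1, in_size - 1)  # right/bottom neighbour, clamped at the edge
    frac = src - i0  # interpolation weight of the i1 neighbour
    return i0, i1, frac


def resize_bilinear(img: np.ndarray, height: int, width: int) -> np.ndarray:
    """Bilinearly resize an (H, W, C) image to (height, width, C) as float32."""
    img = np.asarray(img, dtype=np.float32)
    in_h, in_w = img.shape[:2]
    y0, y1, wy = _source_coords(height, in_h)
    x0, x1, wx = _source_coords(width, in_w)
    wy = wy[:, None, None]  # broadcast over width and channels
    wx = wx[None, :, None]  # broadcast over height and channels
    # Interpolate horizontally on the two source rows, then vertically.
    top = img[y0][:, x0] * (1 - wx) + img[y0][:, x1] * wx
    bottom = img[y1][:, x0] * (1 - wx) + img[y1][:, x1] * wx
    return top * (1 - wy) + bottom * wy
```

For example, `resize_bilinear(np.random.rand(64, 64, 3), 299, 299)` returns an array of shape `(299, 299, 3)`.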
## Installation
You will need Python 3.7 or later.
1. For GPU usage, first install JAX with CUDA support by following the official JAX installation instructions.
2. Then install jax-fid:
```sh
> pip install jax-fid
```
For CPU-only usage, you can skip step 1.
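To check that JAX actually picked up your GPU after installation, you can query its default backend. This minimal check uses the public `jax.default_backend()` and `jax.devices()` functions; the exact device names it prints depend on your JAX version and hardware.

```python
import jax

# Report which backend JAX selected after installation.
# A working CUDA setup typically reports "gpu"; a CPU-only install reports "cpu".
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX backend: {backend}")
print(f"Visible devices: {devices}")
```

Running it with `CUDA_VISIBLE_DEVICES=N` set should list only GPU `N`.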
## Usage
### Compute FID score
```sh
> CUDA_VISIBLE_DEVICES=N python -m jax_fid --path1 /path/to/dataset1 --path2 /path/to/dataset2
```
where `N` is the GPU index.
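When `--path1` and `--path2` are `.npz` files, the score depends only on the stored means and covariances. The NumPy sketch below computes the Fréchet distance from two such files. It assumes the arrays are stored under the keys `mu` and `sigma` (pytorch-fid's convention, not confirmed here for jax-fid). It computes the trace of the matrix square root through an eigendecomposition instead of `scipy.linalg.sqrtm`, so its results may differ slightly from the tool's output. `frechet_distance` and `fid_from_npz` are hypothetical names.

```python
import numpy as np


def _sqrt_psd(mat: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative eigenvalues from round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})."""
    diff = mu1 - mu2
    # sigma1 @ sigma2 has the same eigenvalues as s1 @ sigma2 @ s1 with
    # s1 = sigma1^{1/2}, and the latter is symmetric PSD, so eigvalsh applies.
    s1 = _sqrt_psd(sigma1)
    inner = s1 @ sigma2 @ s1
    inner = (inner + inner.T) / 2  # symmetrize against round-off
    eigvals = np.clip(np.linalg.eigvalsh(inner), 0.0, None)
    tr_covmean = np.sum(np.sqrt(eigvals))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def fid_from_npz(path1: str, path2: str) -> float:
    """Load two statistics files (assumed keys: 'mu', 'sigma') and return the FID."""
    with np.load(path1) as f1, np.load(path2) as f2:
        return frechet_distance(f1["mu"], f1["sigma"], f2["mu"], f2["sigma"])
```

As a sanity check, identical statistics give a distance of 0, and `mu1 = [0, 0]`, `sigma1 = I`, `mu2 = [1, 0]`, `sigma2 = 4I` give `1 + (2 + 8) - 2 * 4 = 3`.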
### Pre-compute statistics for an image directory
```sh
> CUDA_VISIBLE_DEVICES=N python -m jax_fid --precompute --img_dir /path/to/dataset --out_dir /path/to/stats
```
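Pre-computing statistics amounts to reducing the Inception activations of all images to a mean vector and a covariance matrix. The following minimal NumPy sketch assumes the activations arrive as an `(N, D)` array and that the output file uses the keys `mu` and `sigma` (pytorch-fid's convention; jax-fid's exact file layout is not confirmed here). `compute_statistics` and `save_statistics` are hypothetical names.

```python
import numpy as np


def compute_statistics(activations: np.ndarray):
    """Return the mean and covariance of an (N, D) activation matrix."""
    acts = np.asarray(activations, dtype=np.float64)
    mu = acts.mean(axis=0)
    sigma = np.cov(acts, rowvar=False)  # unbiased estimate (divides by N - 1)
    return mu, sigma


def save_statistics(path: str, activations: np.ndarray) -> None:
    """Write the statistics to an .npz file under the keys 'mu' and 'sigma'."""
    mu, sigma = compute_statistics(activations)
    np.savez(path, mu=mu, sigma=sigma)
```

Storing only `mu` and `sigma` keeps the file small: its size depends on the feature dimension `D`, not on the number of images.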
### Arguments
- `--path1`: Path to an image directory or to an .npz file containing pre-computed statistics.
- `--path2`: Path to an image directory or to an .npz file containing pre-computed statistics.
- `--batch_size`: Batch size per device for computing the Inception activations.
- `--img_size`: Resize images to this size, given as (height, width).
- `--precompute`: If set, pre-compute statistics for the directory given by `--img_dir`.
- `--img_dir`: Path to the image directory for which statistics are pre-computed.
- `--out_dir`: Path where the pre-computed statistics are stored.
- `--mmap`: If enabled, use a memory-mapped file (mmap) when computing statistics.
- `--mmap_file`: Name of the mmap file. Only used if `--mmap` is enabled.
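The `--batch_size` and `--mmap` options both concern memory. Activations are computed batch by batch, with `--batch_size` images per device. Presumably, `--mmap` lets the activations live in a file on disk instead of in RAM. The sketch below illustrates that idea with `numpy.lib.format.open_memmap`: it writes activations into a memory-mapped `.npy` file, then streams over that file to compute the mean and covariance. `stats_via_memmap` and the `feature_fn` argument (a stand-in for the InceptionV3 feature extractor) are hypothetical, and jax-fid's own mmap handling may differ.

```python
import numpy as np
from numpy.lib.format import open_memmap


def stats_via_memmap(images, feature_fn, dim: int, batch_size: int, mmap_path: str):
    """Store activations in a memory-mapped .npy file batch by batch, then
    compute the mean and covariance by streaming over that file."""
    n = images.shape[0]
    acts = open_memmap(mmap_path, mode="w+", dtype=np.float32, shape=(n, dim))
    for start in range(0, n, batch_size):
        batch = images[start:start + batch_size]
        acts[start:start + len(batch)] = feature_fn(batch)  # (len(batch), dim)
    acts.flush()

    # Second pass: accumulate sums in float64 so only one chunk is in RAM at a time.
    total = np.zeros(dim, dtype=np.float64)
    sum_xxt = np.zeros((dim, dim), dtype=np.float64)
    for start in range(0, n, batch_size):
        chunk = np.asarray(acts[start:start + batch_size], dtype=np.float64)
        total += chunk.sum(axis=0)
        sum_xxt += chunk.T @ chunk
    mu = total / n
    # Unbiased covariance, matching np.cov(acts, rowvar=False).
    sigma = (sum_xxt - n * np.outer(mu, mu)) / (n - 1)
    return mu, sigma
```

The one-pass sum-of-outer-products formula is less numerically robust than centering the data first. It is used here only because it needs a single chunk in memory at a time.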
## License
Apache-2.0 License