# Vector Quantized VAE
A PyTorch implementation of [Continuous Relaxation Training of Discrete Latent Variable Image Models](http://bayesiandeeplearning.org/2017/papers/54.pdf).

Ensure you have Python 3.7 or later and PyTorch 1.2 or later installed.
To train the `VQVAE` model with 8 categorical dimensions and 128 codes per dimension,
run the following command:
```
python train.py --model=VQVAE --latent-dim=8 --num-embeddings=128
```
To train the `GS-Soft` model use `--model=GSSOFT`.
Pretrained weights for the `VQVAE` and `GS-Soft` models can be found
[here](https://github.com/bshall/VectorQuantizedVAE/releases/tag/v0.1).
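The core of a VQ-VAE is a quantization bottleneck: each encoder output vector is snapped to its nearest codebook entry, and gradients are copied straight through the non-differentiable lookup. The sketch below is a simplified illustration of that idea, not the repository's exact implementation; the `vq_straight_through` helper and the tensor shapes are assumptions.

```python
import torch

def vq_straight_through(z, codebook):
    """Quantize z against a codebook with a straight-through gradient estimator.

    z:        (B, D, H, W) continuous encoder output
    codebook: (K, D) learnable embedding vectors
    """
    B, D, H, W = z.shape
    flat = z.permute(0, 2, 3, 1).reshape(-1, D)        # (B*H*W, D)
    dists = torch.cdist(flat, codebook)                # pairwise distances (B*H*W, K)
    indices = dists.argmin(dim=1)                      # nearest code index per vector
    quantized = codebook[indices].reshape(B, H, W, D).permute(0, 3, 1, 2)
    # Straight-through: forward pass uses the quantized values,
    # backward pass copies gradients from the quantized output to z.
    st = z + (quantized - z).detach()
    return st, indices.reshape(B, H, W)
```

For example, with `--latent-dim=8 --num-embeddings=128` the codebook would have shape `(128, 8)` and each spatial position of the latent map would be assigned one of the 128 codes.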


*Figure: VQVAE reconstructions*

The `VQVAE` model achieves ~4.82 bits per dimension (bpd), while the `GS-Soft` model achieves ~4.6 bpd.
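For reference, bits per dimension converts a model's negative log-likelihood (in nats) into bits per pixel value, which makes likelihoods comparable across image sizes. A minimal sketch, where the `bits_per_dim` helper and the example pixel count are illustrative assumptions:

```python
import math

def bits_per_dim(nll_nats, num_pixels):
    """Convert a total negative log-likelihood in nats to bits per dimension."""
    return nll_nats / (num_pixels * math.log(2))

# e.g. a 32x32 RGB image has 3 * 32 * 32 = 3072 dimensions
num_pixels = 3 * 32 * 32
```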

# Analysis of the Codebooks

As demonstrated in the paper, the codebook matrices are effectively low-rank, with most of their variance concentrated in only a few dimensions:


*Figure: explained variance ratio of the codebook principal components*

Projecting the codes onto the first three principal components shows that they typically tile
continuous 1- or 2-D manifolds:
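This kind of analysis can be reproduced with a plain SVD of the mean-centered codebook. The sketch below uses NumPy; the helper names `explained_variance_ratio` and `project_onto_pcs` are hypothetical, not functions from this repository.

```python
import numpy as np

def explained_variance_ratio(codebook):
    """Fraction of total variance captured by each principal component.

    codebook: (K, D) array of code vectors.
    """
    centered = codebook - codebook.mean(axis=0)
    # Singular values of the centered matrix give per-component variances.
    s = np.linalg.svd(centered, compute_uv=False)
    var = s ** 2
    return var / var.sum()

def project_onto_pcs(codebook, n_components=3):
    """Project code vectors onto their first n_components principal components."""
    centered = codebook - codebook.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T
```

A codebook whose explained variance ratio is dominated by the first two or three components is, in the paper's sense, low-dimensional.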


*Figure: codebook entries projected onto the first three principal components*