https://github.com/bshall/vectorquantizedvae
A PyTorch implementation of "Continuous Relaxation Training of Discrete Latent Variable Image Models"
- Host: GitHub
- URL: https://github.com/bshall/vectorquantizedvae
- Owner: bshall
- License: mit
- Created: 2019-09-10T08:17:55.000Z (about 6 years ago)
- Default Branch: master
- Last Pushed: 2020-03-25T18:26:31.000Z (over 5 years ago)
- Last Synced: 2025-04-08T14:45:48.776Z (6 months ago)
- Topics: generative-models, pytorch, vae, vq-vae
- Language: Jupyter Notebook
- Size: 975 KB
- Stars: 73
- Watchers: 1
- Forks: 16
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Vector Quantized VAE
A PyTorch implementation of [Continuous Relaxation Training of Discrete Latent Variable Image Models](http://bayesiandeeplearning.org/2017/papers/54.pdf).

Ensure you have Python 3.7 and PyTorch 1.2 or greater.
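The paper trains discrete latent variable models by relaxing the hard codebook assignment with a Gumbel-Softmax (Concrete) distribution, which is where the `GS-Soft` name comes from. Below is a minimal sketch of such a relaxed bottleneck, assuming one codebook per categorical dimension; the module, argument names, and tensor shapes are illustrative, not the code in this repository:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class GumbelSoftmaxBottleneck(nn.Module):
    # Relaxed discrete bottleneck: during training each latent position holds a
    # softmax-weighted mixture of codebook vectors instead of a hard assignment.
    def __init__(self, latent_dim, num_embeddings, embedding_dim, tau=1.0):
        super().__init__()
        self.tau = tau
        # one codebook of `num_embeddings` vectors per categorical dimension
        self.codebook = nn.Parameter(torch.randn(latent_dim, num_embeddings, embedding_dim))

    def forward(self, logits):
        # logits: (batch, latent_dim, num_embeddings, height, width)
        B, D, K, H, W = logits.shape
        flat = logits.permute(0, 1, 3, 4, 2).reshape(-1, K)
        if self.training:
            # differentiable sample from the Gumbel-Softmax relaxation ("soft" estimator)
            weights = F.gumbel_softmax(flat, tau=self.tau, hard=False)
        else:
            # at evaluation time, commit to the most likely code
            weights = F.one_hot(flat.argmax(dim=-1), K).float()
        weights = weights.view(B, D, H, W, K)
        # convex combination of codebook vectors for each categorical dimension
        return torch.einsum("bdhwk,dke->bdhwe", weights, self.codebook)
```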
To train the `VQVAE` model with 8 categorical dimensions and 128 codes per dimension, run the following command:
```
python train.py --model=VQVAE --latent-dim=8 --num-embeddings=128
```
To train the `GS-Soft` model, use `--model=GSSOFT`.
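For example, reusing the hyperparameters from the `VQVAE` command above:
```
python train.py --model=GSSOFT --latent-dim=8 --num-embeddings=128
```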
Pretrained weights for the `VQVAE` and `GS-Soft` models can be found
[here](https://github.com/bshall/VectorQuantizedVAE/releases/tag/v0.1).
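A hypothetical loading snippet is below; the checkpoint file name, module path, and constructor arguments are placeholders rather than the actual release layout:
```
import torch

from model import VQVAE  # assumed module and class name; check the repository source

# "vqvae.pt" stands in for the asset downloaded from the release page
state_dict = torch.load("vqvae.pt", map_location="cpu")

# constructor arguments assumed to mirror the training flags above
model = VQVAE(latent_dim=8, num_embeddings=128)
model.load_state_dict(state_dict)
model.eval()
```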
The `VQVAE` model gets ~4.82 bits per dimension (bpd), while the `GS-Soft` model gets ~4.6 bpd.
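For context, bits per dimension is the negative log-likelihood in nats averaged over all pixel dimensions and converted to base 2. A minimal helper (the function name and arguments are illustrative):
```
import math

def bits_per_dim(nll_nats, num_dims):
    # Convert a total negative log-likelihood in nats for one image into
    # bits per dimension, e.g. num_dims = 3 * 32 * 32 for a 32x32 RGB image.
    return nll_nats / (num_dims * math.log(2))
```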
# Analysis of the Codebooks
As demonstrated in the paper, the codebook matrices are low-dimensional, spanning only a few dimensions.
Projecting the codes onto the first 3 principal components shows that they typically tile continuous 1- or 2-D manifolds.
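A rough sketch of this kind of analysis, assuming the learned codebook is available as a `(num_embeddings, embedding_dim)` tensor (the function and variable names are illustrative, not taken from the repository):
```
import torch

def project_codebook(codebook, k=3):
    # codebook: (num_embeddings, embedding_dim) tensor of code vectors.
    # Returns the coordinates in the top-k principal-component basis and the
    # fraction of variance explained by each singular direction.
    centered = codebook - codebook.mean(dim=0, keepdim=True)
    _, S, Vh = torch.linalg.svd(centered, full_matrices=False)
    projections = centered @ Vh[:k].T          # (num_embeddings, k) coordinates
    explained = (S ** 2) / (S ** 2).sum()      # explained-variance ratio per component
    return projections, explained[:k]
```
A large explained-variance fraction in the first two or three components is what "spanning only a few dimensions" means here.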