https://github.com/igreat/variational-autoencoders-pytorch
I'll try to implement variational autoencoders in pytorch :D
- Host: GitHub
- URL: https://github.com/igreat/variational-autoencoders-pytorch
- Owner: igreat
- Created: 2023-02-17T07:59:54.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2023-02-18T04:26:47.000Z (over 2 years ago)
- Last Synced: 2025-06-26T10:01:45.628Z (4 months ago)
- Language: Jupyter Notebook
- Size: 3.32 MB
- Stars: 1
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
# Variational Autoencoder PyTorch 🌌🎲🧠
This is a PyTorch implementation of a Variational Autoencoder (VAE). [Paper reference](https://arxiv.org/pdf/1312.6114.pdf) by Kingma & Welling.
*currently under development*
# Autoencoders 🤖
**Autoencoders** are neural networks that try to reconstruct their input after it has been *reduced to a lower-dimensional representation*. An autoencoder consists of two parts:
- **Encoder**: compresses the input into a lower-dimensional representation (a point in the ***latent space***).
- **Decoder**: tries to reconstruct the input from that lower-dimensional representation (the *latent code*).
The key idea is that if the network can reconstruct the input from a lower-dimensional representation, then it must have learned non-trivial, rich features of the input. The latent code can then be used for downstream tasks (e.g. classification, clustering, etc.).
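For concreteness, here is a minimal sketch of a plain autoencoder in PyTorch (the layer sizes are illustrative, chosen for flattened 28×28 MNIST images; this is not the exact architecture used in this repo):

```python
import torch
import torch.nn as nn

class Autoencoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        # Encoder: compress the input down to the latent code
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(),
            nn.Linear(256, latent_dim),
        )
        # Decoder: reconstruct the input from the latent code
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, input_dim),
        )

    def forward(self, x):
        z = self.encoder(x)      # latent code
        x_hat = self.decoder(z)  # reconstruction
        return x_hat
```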
# Variational Autoencoders (VAE) 🌌
Unfortunately, the latent space learned by a regular autoencoder is highly irregular, and we can't sample from it. Think of the original image space as a manifold inside a higher-dimensional space [**(check out my video on the manifold hypothesis)**](https://www.youtube.com/watch?v=pdNYw6qwuNc). The latent space is a projection of that manifold into a lower-dimensional space.
The problem is that we can't just pick a random point in the latent space and expect to get a valid image, since there is technically probability 0 of landing on the original manifold *(think of the probability of hitting one specific point inside a square when sampling a point uniformly from it)*.
**Variational Autoencoders** try to solve this problem by enforcing a **Gaussian distribution** over the latent space. This way, we can sample from a Gaussian (which is our latent space) and get a valid image.
*(Figure: plot of a 3D Gaussian function; image not available.)*
Thus, the loss function for a VAE is composed of two parts:
- **Reconstruction loss**: The reconstruction loss can be thought of as being the same as the loss function of a regular autoencoder. It tries to reconstruct the input from the latent code. In our case, we use [MSE](https://en.wikipedia.org/wiki/Mean_squared_error).
- **KL Divergence**: The [KL Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) is a measure of how different two distributions are. In our case, we want to measure how far the latent space distribution is from a unit diagonal Gaussian distribution.

# VAE architecture 🧠
I use a simple CNN-based architecture for the encoder and decoder.
- Encoder consists of:
  - Convolutional layers
  - Residual connections
  - Max pooling (downsampling)
- Decoder consists of:
  - Transpose convolutional layers
  - Nearest-neighbor upsampling
*(Figure: diagram of a similar architecture; image not available.)*
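A rough sketch of what such an encoder/decoder pair could look like (the channel counts, depths, and final activation here are illustrative assumptions, not the exact ones used in this repo):

```python
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """A convolution with a residual (skip) connection."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(x + self.conv(x))  # residual connection

# Encoder: conv layers + residual blocks + max pooling for downsampling
encoder = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
    ResBlock(32),
    nn.MaxPool2d(2),  # 28x28 -> 14x14
    nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
    ResBlock(64),
    nn.MaxPool2d(2),  # 14x14 -> 7x7
    # (a flatten + linear head would map these features to the latent code)
)

# Decoder: transpose convolutions + nearest-neighbor upsampling
decoder = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="nearest"),  # 7x7 -> 14x14
    nn.ConvTranspose2d(64, 32, 3, padding=1), nn.ReLU(),
    nn.Upsample(scale_factor=2, mode="nearest"),  # 14x14 -> 28x28
    nn.ConvTranspose2d(32, 1, 3, padding=1), nn.Sigmoid(),
)
```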
**The encoder now outputs the `mean` and `log variance` of the latent space distribution.** We sample from this distribution and pass the sample to the decoder. The decoder then tries to **reconstruct the input** from the sampled latent code.
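Concretely, this sampling is usually done with the reparameterization trick from the Kingma & Welling paper, so gradients can flow back through the sampling step. A minimal sketch, assuming `mu` and `logvar` are the encoder's outputs:

```python
import torch

def sample_latent(mu, logvar):
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
    # Writing the sample this way keeps it differentiable w.r.t. mu and logvar.
    std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)    # noise from a standard normal
    return mu + std * eps
```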
*(Figure: VAE diagram from the University of Michigan EECS 498 lecture slides; image not available.)*
The KL divergence is computed as follows:
```python
import torch

def kl_divergence(mu, logvar):
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over the batch
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
```
And the reconstruction loss is computed as follows:
```python
import torch.nn.functional as F

def reconstruction_loss(x_hat, x):
    # Sum of squared errors between the reconstruction and the input
    return F.mse_loss(x_hat, x, reduction="sum")
```
**The total loss is just the sum of the two.** The encoder and decoder are trained jointly to minimize both the reconstruction loss and the KL divergence. From this perspective, the KL divergence can be thought of as a **regularization term** that pulls the latent space distribution toward a unit diagonal Gaussian distribution.
**Note**: we use a unit diagonal Gaussian distribution because it's simple and convenient (we can compute the KL divergence in closed form, as above). If we didn't restrict ourselves to the diagonal, the encoder would have to predict an entire covariance matrix, and its output layer would become astronomically large.
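Putting the pieces together, a training step might look like the following sketch; it reuses the two loss functions above, and assumes a hypothetical `model` whose forward pass returns the reconstruction along with `mu` and `logvar`, plus the usual PyTorch `loader` and `optimizer`:

```python
def train_epoch(model, loader, optimizer):
    # model is assumed to return (x_hat, mu, logvar) from its forward pass
    for x, _ in loader:  # labels are unused
        x_hat, mu, logvar = model(x)
        # Total loss = reconstruction term + KL regularization term
        loss = reconstruction_loss(x_hat, x) + kl_divergence(mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```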
# Results 🎲
The VAE implemented in this repository is able to reconstruct images pretty well. Here are some examples (left is the original image, right is the reconstructed image):
*(Figure: original vs. reconstructed image pairs; images not available.)*
Also, the latent space is now a Gaussian distribution, so we can sample from it and generate new images. Here I sample 64 images from the latent space:
*(Figure: 64 images sampled from the latent space; image not available.)*
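Generating new images reduces to decoding draws from the prior. A sketch, assuming a trained `decoder` that takes flat latent vectors and a latent dimension `latent_dim`:

```python
import torch

with torch.no_grad():
    z = torch.randn(64, latent_dim)  # 64 draws from the unit Gaussian prior
    samples = decoder(z)             # decode each latent point into an image
```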
I also tried encoding a real image into the latent space, and then slightly tweaking its latent code to generate a new image. Here are some examples where I vary one dimension of the latent code vertically and another dimension horizontally:
*(Figure: grid of images generated by varying two latent dimensions; image not available.)*
As you can see, the vertical latent dimension seems to encode how much of a `1` the image is, and the horizontal one encodes how much of a `0` it is.
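A sketch of how such a traversal grid can be produced; the trained `encoder` (assumed here to return `mu` and `logvar`), `decoder`, input image `x`, dimension indices, and sweep range are all illustrative assumptions:

```python
import torch

i, j = 0, 1  # the two latent dimensions to sweep (illustrative choice)

with torch.no_grad():
    mu, logvar = encoder(x)          # encode a real image, x has shape (1, C, H, W)
    z = mu.clone()                   # start the traversal from the image's mean code
    sweep = torch.linspace(-3, 3, 8) # values to sweep along each dimension
    rows = []
    for a in sweep:                  # vary dimension i vertically
        row = []
        for b in sweep:              # vary dimension j horizontally
            z_ab = z.clone()
            z_ab[0, i] = a
            z_ab[0, j] = b
            row.append(decoder(z_ab))
        rows.append(torch.cat(row, dim=-1))  # concatenate images along width
    grid = torch.cat(rows, dim=-2)           # stack rows along height
```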
TODO: experiment on datasets other than MNIST (e.g. CIFAR-10, CelebA, etc).