{"id":17652652,"url":"https://github.com/snakers4/playing_with_vae","last_synced_at":"2025-07-13T23:08:51.880Z","repository":{"id":97711483,"uuid":"140070866","full_name":"snakers4/playing_with_vae","owner":"snakers4","description":"Comparing FC VAE / FCN VAE / PCA / UMAP on MNIST / FMNIST","archived":false,"fork":false,"pushed_at":"2018-07-07T11:31:27.000Z","size":4697,"stargazers_count":64,"open_issues_count":0,"forks_count":13,"subscribers_count":3,"default_branch":"master","last_synced_at":"2025-05-07T08:45:00.314Z","etag":null,"topics":["beginner","embedding","fashion-mnist","mnist","pca","python3","pytorch","tutorial","umap","variational-autoencoder"],"latest_commit_sha":null,"homepage":null,"language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/snakers4.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null}},"created_at":"2018-07-07T09:30:00.000Z","updated_at":"2025-04-14T21:53:15.000Z","dependencies_parsed_at":"2023-06-02T21:15:09.525Z","dependency_job_id":null,"html_url":"https://github.com/snakers4/playing_with_vae","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/snakers4/playing_with_vae","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/snakers4%2Fplaying_with_vae","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/snakers4%2Fplaying_with_vae/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/snakers4%2Fplaying_with_vae/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/snakers4%2Fplaying_with_vae/manifests","owner_url":
"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/snakers4","download_url":"https://codeload.github.com/snakers4/playing_with_vae/tar.gz/refs/heads/master","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/snakers4%2Fplaying_with_vae/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":265218750,"owners_count":23729528,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["beginner","embedding","fashion-mnist","mnist","pca","python3","pytorch","tutorial","umap","variational-autoencoder"],"created_at":"2024-10-23T11:47:34.164Z","updated_at":"2025-07-13T23:08:51.833Z","avatar_url":"https://github.com/snakers4.png","language":"Jupyter Notebook","funding_links":[],"categories":[],"sub_categories":[],"readme":"![Latent vector spaces](compare.png)\n\n\n\n![FMNIST reconstructions](reconstructions.png)\n\n# Intro\n\nThis is a test task I did for some reason.\nIt contains evaluation of:\n- FC VAE / FCN VAE on MNIST / FMNIST for image reconstruction;\n- Comparison of embeddings produced by VAE / PCA / UMAP for classification;\n\n# TLDR\n\nWhat you can find here:\n- A working VAE example on PyTorch with a lot of flags (both FC and FCN, as well as a number of failed experiments);\n- Some experiment boilerplate code;\n- Comparison between embeddings produced by PCA / UMAP / VAEs (**spoiler** - VAEs win);\n- A step-by step logic of what I did in `main.ipynb`\n\n\n# Docker environment\n\nTo build the docker image from the Dockerfile located in `dockerfile` please do:\n```\ncd dockerfile\ndocker build -t vae_docker .\n```\n(you 
can replace the public ssh key with yours, of course)\n\nAlso please make sure that [nvidia-docker2](https://github.com/nvidia/nvidia-docker/wiki/Installation-(version-2.0)) and proper nvidia drivers are installed.\n\nTo test the installation run\n```\ndocker run --runtime=nvidia --rm nvidia/cuda nvidia-smi\n```\n\nThen launch the container as follows:\n```\ndocker run --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=0 -it -v /your/folder/:/home/keras/notebook/your_folder -p 8888:8888 -p 6006:6006 --name vae --shm-size 16G vae_docker\n```\n\nPlease note that without `--shm-size 16G` PyTorch dataloader classes will not work.\nThe above command will start a container with a Jupyter notebook server available via port `8888`. \nPort `6006` is for tensorboard, if necessary.\n\nThen you can exec into the container like this. All the scripts were run as root, but they should also work under the user `keras`\n```\ndocker exec -it --user root REPLACE_WITH_CONTAINER_ID /bin/bash\n```\nor\n```\ndocker exec -it --user keras REPLACE_WITH_CONTAINER_ID /bin/bash\n```\n\nTo find out the container ID run\n```\ndocker container ls\n```\n\n\n# Most important dependencies (if you do not want docker)\n\nThese are the most important dependencies (you can install the others as you go):\n```\nUbuntu 16.04\ncuda 9.0\ncudnn 7\npython 3.6\npip\nPIL\ntensorflow-gpu (for tensorboard)\npandas\nnumpy\nmatplotlib\nseaborn\ntqdm\nscikit-learn\npytorch 0.4.0 (cuda90)\ntorchvision 2.0\ndatashader\numap\n```\nIf you have trouble with these, look up how I install them in the Dockerfile / jupyter notebook.\n\n\n# Results\n\n## VAE\n\n**The best model can be trained as follows**\n\n```\npython3 train.py \\\n\t--epochs 30 --batch-size 512 --seed 42 \\\n\t--model_type fc_conv --dataset_type fmnist --latent_space_size 10 \\\n\t--do_augs False \\\n\t--lr 1e-3 --m1 40 --m2 50 \\\n\t--optimizer adam \\\n\t--do_running_mean False --img_loss_weight 1.0 --kl_loss_weight 1.0 \\\n\t--image_loss_type bce --ssim_window_size 5 
\\\n\t--print-freq 10 \\\n\t--lognumber fmnist_fc_conv_l10_rebalance_no_norm \\\n\t--tensorboard True --tensorboard_images True\n```\n\nIf you launch this code, a copy of the `FMNIST` dataset will be downloaded automatically.\n\nSuggested alternative flag values to play with:\n- `dataset_type` - can be set to `mnist` or `fmnist`. The necessary dataset is downloaded in either case\n- `latent_space_size` - affects the latent space when `model_type` is `fc_conv` or `fc`. Other model types do not work properly\n- `m1` and `m2` control LR decay, but it did not really help here\n- `image_loss_type` can be set to `bce`, `mse` or `ssim`. In practice `bce` works best. `mse` is worse. I suppose that proper scaling is required to make `ssim` work (it does not train now)\n- `tensorboard` and `tensorboard_images` can also be set to `False`. They only write logs, so you may not need to bother\n\nThe flags `--tensorboard True --tensorboard_images True` are optional. To use them, you have to:\n- install tensorboard (installs with tensorflow)\n- launch tensorboard with the following command `tensorboard --logdir='path/to/tb_logs' --port=6006`\n\nYou can also resume from the best checkpoint using these flags:\n```\npython3 train.py \\\n\t--resume weights/fmnist_fc_conv_l10_rebalance_no_norm_best.pth.tar \\\n\t--epochs 60 --batch-size 512 --seed 42 \\\n\t--model_type fc_conv --dataset_type fmnist --latent_space_size 10 \\\n\t--do_augs False \\\n\t--lr 1e-3 --m1 50 --m2 100 \\\n\t--optimizer adam \\\n\t--do_running_mean False --img_loss_weight 1.0 --kl_loss_weight 1.0 \\\n\t--image_loss_type bce --ssim_window_size 5 \\\n\t--print-freq 10 \\\n\t--lognumber fmnist_resume \\\n\t--tensorboard True --tensorboard_images True\n```\n\nThe best reconstructions are supposed to look like this (top row - original images, bottom row - reconstructions):\n![](reconstructions.png)\n\n**Brief ablation analysis of the results**\n\n**✓ What 
worked**\n1. Using BCE loss + KLD loss\n2. Converting a plain FC model into a conv model in the most straightforward fashion possible, i.e. replacing this\n```\n        self.fc1 = nn.Linear(784, 400)\n        self.fc21 = nn.Linear(400, latent_space_size)\n        self.fc22 = nn.Linear(400, latent_space_size)\n        self.fc3 = nn.Linear(latent_space_size, 400)\n        self.fc4 = nn.Linear(400, 784)\n```\nwith this\n```\n        self.fc1 = nn.Conv2d(1,32, kernel_size=(28,28), stride=1, padding=0)\n        self.fc21 = nn.Conv2d(32,latent_space_size, kernel_size=(1,1), stride=1, padding=0)\n        self.fc22 = nn.Conv2d(32,latent_space_size, kernel_size=(1,1), stride=1, padding=0)\n        \n        self.fc3 = nn.ConvTranspose2d(latent_space_size,118, kernel_size=(1,1),  stride=1, padding=0)\n        self.fc4 = nn.ConvTranspose2d(118,1, kernel_size=(28,28),  stride=1, padding=0)\n```\n3. Using `SSIM` as a visualization metric. It correlates awesomely with perceived visual similarity of the image and its reconstruction\n\n\n**✗ What did not work**\n1. Extracting `mean` and `std` from images - removing this feature boosted SSIM on FMNIST 4-5x\n2. Doing any simple augmentations (unsurprisingly - it adds a complexity level to a simple task)\n3. Any architectures beyond the most obvious ones:\n    - UNet inspired architectures (my speculation - this is because image size is very small, and very global features work best, i.e. feature extraction cascade is overkill)\n    - I tried various combinations of convolution weights, none of them worked\n    - 1xN convolutions\n4. `MSE` loss performed poorly, `SSIM` loss did not work at all\n5. LR decay, as well as any LR besides `1e-3` (with adam), did not really help\n6. Increasing latent space to `20` or `100` did not really change much\n\n** ¯|_(ツ)_/¯ What I did not try**\n1. Ensembling or building meta-architectures\n2. Conditional VAEs\n3. Increasing network capacity\n\n## PCA vs. UMAP vs. 
VAE\n\nPlease refer to section 5 of `main.ipynb`.\n\nIt is notable that:\n- VAEs visually worked better than PCA;\n- Using the VAE embedding for classification produces higher accuracy (~80% vs. 73%);\n- A similar accuracy on train/val can be obtained using [UMAP](https://github.com/lmcinnes/umap);\n\nThe Jupyter notebook (.ipynb file) is best viewed with these Jupyter notebook extensions (install them with the command below, then turn them on in the **Jupyter GUI**)\n\n```\npip install git+https://github.com/ipython-contrib/jupyter_contrib_nbextensions\n# conda install html5lib==0.9999999\njupyter contrib nbextension install --system\n```\nSometimes there is an `html5lib` conflict.\nThe extensions are excluded from the Dockerfile because of this conflict (it occurs only sometimes).\n![](extensions.png)\n\n# Further reading\n\n- (EN) A small intuitive intro (super super cool and intuitive) https://towardsdatascience.com/intuitively-understanding-variational-autoencoders-1bfe67eb5daf\n- (EN) KL divergence explained https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained\n- (EN) A more formal write-up http://arxiv.org/abs/1606.05908\n- (RU) A cool post series on habr about auto-encoders https://habr.com/post/331382/\n- (EN) Converting a FC layer into a conv layer http://cs231n.github.io/convolutional-networks/#convert\n- (EN) A VAE post by Fchollet https://blog.keras.io/building-autoencoders-in-keras.html\n- (EN) Why VAEs are not used on larger datasets https://www.quora.com/Why-is-there-no-work-of-variational-auto-encoder-on-larger-data-sets-like-CIFAR-or-ImageNet\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fsnakers4%2Fplaying_with_vae","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fsnakers4%2Fplaying_with_vae","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fsnakers4%2Fplaying_with_vae/lists"}