doodleVAE
Yet another disentangled VAE ... but for Quick, Draw! doodles.
[![GIF shot][product-screenshot]](https://github.com/yassa9/doodleVAE)
## The Model
This project implements a `Variational Autoencoder (VAE)` trained to generate hand-drawn doodles. It learns a compressed latent representation of doodles from the `Quick, Draw! dataset` and uses it to generate new, human-like sketches.
### The model consists of:
- [x] Convolutional encoder that compresses input images into a latent vector
- [x] Reparameterization layer to sample from the latent space
- [x] Convolutional decoder that reconstructs images from latent vectors
- [x] Latent space exploration, saved as an animation
- [x] Loss plotting
The training pipeline supports configurable hyperparameters (e.g. latent dimension, beta, batch size, epochs) through a configuration file or command-line arguments.
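The reparameterization step and the beta-weighted loss mentioned above can be sketched numerically. This is a minimal NumPy illustration of the math, not the repo's actual PyTorch code; the latent size and the all-zero encoder outputs are assumptions chosen for clarity.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim = 20

# Encoder outputs for one input (assumed values, for illustration only):
mu = np.zeros(latent_dim)
log_var = np.zeros(latent_dim)  # log of the latent variance

# Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
# so sampling stays differentiable w.r.t. mu and log_var.
eps = rng.standard_normal(latent_dim)
z = mu + np.exp(0.5 * log_var) * eps

# beta-VAE objective: reconstruction term plus a beta-weighted KL divergence
# between the approximate posterior N(mu, sigma^2) and the N(0, I) prior.
kl = -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var))
beta = 4.0
recon_loss = 0.0  # placeholder; the real term compares x to its reconstruction
loss = recon_loss + beta * kl
print(loss)  # 0.0 here: mu = 0 and log_var = 0 make the KL term vanish
```

With `beta > 1` the KL term is weighted more heavily, which is what encourages the disentangled latent space the project aims for.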
(Back to Top)
## Getting Started
Follow these steps to set up and run the project locally.
### Prerequisites
Ensure you have Python installed (>= 3.8 recommended).
You can install it from [python.org](https://www.python.org/).
### Installation
1. **Clone the repository**
```bash
git clone https://github.com/yassa9/doodleVAE.git
cd doodleVAE
```
2. **Install dependencies**
You can install everything with pip:
```bash
pip install torch torchvision matplotlib numpy
```
3. **Prepare your dataset**
- Provide the path to a `.npy` file using `--data-path`.
- You can get data from [Quick, Draw!](https://github.com/googlecreativelab/quickdraw-dataset).
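The Quick, Draw! `numpy_bitmap` files store each doodle as one flattened 28x28 grayscale row. A quick sanity check of such a file might look like the following; the filename and sample count here are made up for illustration.

```python
import numpy as np

# Simulate a Quick, Draw! numpy_bitmap file: each row is a flattened
# 28x28 grayscale doodle (the real files store uint8 pixel values).
fake = np.random.randint(0, 256, size=(100, 784), dtype=np.uint8)
np.save("cat.npy", fake)

doodles = np.load("cat.npy")          # shape: (num_samples, 784)
images = doodles.reshape(-1, 28, 28)  # one 28x28 image per row
print(images.shape)  # (100, 28, 28)
```

Checking the shape before training is a cheap way to catch a file in the wrong format (e.g. stroke-based `.ndjson` data instead of rasterized bitmaps).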
### Training
To train the model, run:
```bash
python train.py --data-path path/to/data.npy
```
You can customize training with command-line arguments:
```bash
python train.py --data-path cat.npy --epochs 50 --latent-dim 20 --beta 4
```
| Argument       | Description                                   |
|----------------|-----------------------------------------------|
| `--data-path`  | Path to the `.npy` training data (required)   |
| `--epochs`     | Number of training epochs                     |
| `--batch-size` | Training batch size                           |
| `--latent-dim` | Dimensionality of the latent space            |
| `--beta`       | Weight on the KL divergence term (beta-VAE)   |
| `--lr`         | Learning rate                                 |
| `--save-dir`   | Directory to save the model and plots         |
| `--no-explore` | Skip the final latent interpolation animation |
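For reference, flags like those in the table could be declared with `argparse` roughly as follows. This is a hypothetical sketch, not the repo's actual `train.py`; the default values shown are assumptions.

```python
import argparse

# Hypothetical declaration of the CLI flags from the table above.
# Defaults are illustrative guesses, not the repo's real defaults.
parser = argparse.ArgumentParser(description="Train doodleVAE")
parser.add_argument("--data-path", required=True, help="Path to .npy data")
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--latent-dim", type=int, default=20)
parser.add_argument("--beta", type=float, default=1.0)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--save-dir", default="runs")
parser.add_argument("--no-explore", action="store_true")

# argparse maps "--data-path" to args.data_path, "--no-explore" to
# args.no_explore, and so on (dashes become underscores).
args = parser.parse_args(
    ["--data-path", "cat.npy", "--epochs", "50",
     "--latent-dim", "20", "--beta", "4"]
)
print(args.beta)  # 4.0
```

Because `--no-explore` uses `action="store_true"`, it is `False` unless the flag is passed, matching its role as an opt-out switch.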
(Back to Top)
[product-screenshot]: images/gifshot.gif