https://github.com/inventwithdean/wgan-gp
Implementation of Wasserstein GAN with Gradient Penalty in PyTorch
generative-adversarial-network generative-ai wasserstein-gan wgan-gp
- Host: GitHub
- URL: https://github.com/inventwithdean/wgan-gp
- Owner: inventwithdean
- Created: 2024-08-13T13:22:14.000Z (about 1 year ago)
- Default Branch: main
- Last Pushed: 2024-08-13T13:33:14.000Z (about 1 year ago)
- Last Synced: 2025-01-10T23:20:11.942Z (9 months ago)
- Topics: generative-adversarial-network, generative-ai, wasserstein-gan, wgan-gp
- Language: Python
- Homepage:
- Size: 5.18 MB
- Stars: 1
- Watchers: 2
- Forks: 0
- Open Issues: 0
# WGAN-GP on 50k CelebA Dataset
This repository contains the implementation of a Wasserstein GAN with Gradient Penalty (WGAN-GP) trained on the CelebA dataset. The model is designed to generate 64x64 images of celebrity faces, using a deep convolutional architecture for both the generator and the critic.
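The contents of `generator.py` and `critic.py` are not reproduced in this README, so the following is only a minimal sketch of what a DCGAN-style generator and critic for 64x64 RGB images could look like. The class names, the latent size `Z_DIM`, the `width` parameter, and all layer sizes are assumptions made for illustration, not the repository's actual architecture. The critic omits batch normalization, following the WGAN-GP paper's recommendation, because the gradient penalty is defined per sample.

```python
import torch
from torch import nn

Z_DIM = 100  # assumed latent size; not taken from the repository


class Generator(nn.Module):
    """Maps latent vectors (N, z_dim, 1, 1) to images (N, 3, 64, 64) in [-1, 1]."""

    def __init__(self, z_dim: int = Z_DIM, width: int = 64):
        super().__init__()

        def up(c_in, c_out, stride, pad):
            # Transposed conv + batch norm + ReLU, as in DCGAN.
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        self.net = nn.Sequential(
            *up(z_dim, width * 8, 1, 0),            # 1x1   -> 4x4
            *up(width * 8, width * 4, 2, 1),        # 4x4   -> 8x8
            *up(width * 4, width * 2, 2, 1),        # 8x8   -> 16x16
            *up(width * 2, width, 2, 1),            # 16x16 -> 32x32
            nn.ConvTranspose2d(width, 3, 4, 2, 1),  # 32x32 -> 64x64
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)


class Critic(nn.Module):
    """Maps images (N, 3, 64, 64) to one unbounded score per image, shape (N,)."""

    def __init__(self, width: int = 64):
        super().__init__()

        def down(c_in, c_out):
            # Strided conv + LeakyReLU; no batch norm in the critic.
            return [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.LeakyReLU(0.2)]

        self.net = nn.Sequential(
            *down(3, width),                   # 64x64 -> 32x32
            *down(width, width * 2),           # 32x32 -> 16x16
            *down(width * 2, width * 4),       # 16x16 -> 8x8
            *down(width * 4, width * 8),       # 8x8   -> 4x4
            nn.Conv2d(width * 8, 1, 4, 1, 0),  # 4x4   -> 1x1
        )

    def forward(self, x):
        return self.net(x).view(-1)
```

Under these assumptions, `Generator()(torch.randn(16, Z_DIM, 1, 1))` yields a batch of 16 images, and passing that batch to `Critic()` yields 16 scalar scores. The critic has no final sigmoid because it estimates a Wasserstein distance rather than a probability.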
## Directory Structure
- **checkpoints/**: Contains the saved model checkpoints during training.
- **output/**: Stores the images generated after every training epoch.
- **critic.py**: Defines the architecture of the critic network.
- **dataset_loader.py**: Handles loading and preprocessing of the CelebA dataset.
- **generator.py**: Defines the architecture of the generator network.
- **image_saver.py**: Utility to save generated images during training.
- **showcase.png**: A sample image showcasing the results of the model after training.
- **train_wdcgan.py**: Script to train the WGAN-GP model.
## Prerequisites
- Python 3.9-3.11
- PyTorch
- Matplotlib
- torchvision
- tqdm
## Dataset
The CelebA dataset is not included in this repository. You can download the dataset from [Kaggle](https://www.kaggle.com/datasets/therealcyberlord/50k-celeba-dataset-64x64). Once downloaded, ensure that the images are resized to 64x64 pixels before training.
## Training the Model
To train the WGAN-GP model, run the following command:
```bash
python train_wdcgan.py
```
You can modify the hyperparameters like `N_EPOCHS`, `BATCH_SIZE`, and `C_LAMBDA` according to your needs.
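To make the role of `C_LAMBDA` concrete, here is a minimal sketch of the gradient penalty term from the WGAN-GP paper and how it would enter the critic loss. It assumes `C_LAMBDA` is the penalty weight (the paper uses 10) and that the critic returns one score per image. The function names `gradient_penalty` and `critic_loss` are hypothetical and may not match `train_wdcgan.py`.

```python
import torch

C_LAMBDA = 10.0  # assumed meaning: gradient penalty weight (10 in the paper)


def gradient_penalty(critic, real, fake):
    """Mean of (||grad_x critic(x)||_2 - 1)^2 at random real/fake interpolates."""
    n = real.size(0)
    # One mixing coefficient per sample, broadcast over channels and pixels.
    eps = torch.rand(n, 1, 1, 1, device=real.device)
    mixed = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(mixed)
    # create_graph=True keeps the penalty differentiable w.r.t. critic weights.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )
    grad_norm = grads.reshape(n, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


def critic_loss(critic, real, fake):
    """Wasserstein critic loss plus the weighted gradient penalty."""
    fake = fake.detach()  # do not backpropagate into the generator here
    wasserstein = critic(fake).mean() - critic(real).mean()
    return wasserstein + C_LAMBDA * gradient_penalty(critic, real, fake)
```

A larger `C_LAMBDA` enforces the unit gradient norm constraint more strictly. The matching generator loss in this formulation is `-critic(fake).mean()`, computed without detaching `fake`.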
## Outputs
- **checkpoints/**: Model checkpoints are saved here every few epochs.
- **output/**: This directory contains the generated images at various stages of training.
## Showcase
Below is a sample output from the trained model:

![Sample output from the trained model](showcase.png)

## Acknowledgments
This implementation is based on the original WGAN-GP paper: ["Improved Training of Wasserstein GANs"](https://arxiv.org/abs/1704.00028).
## License
This project is licensed under the MIT License.