Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/google-research/hit-gan
Tensorflow implementation for "Improved Transformer for High-Resolution GANs" (NeurIPS 2021).
- Host: GitHub
- URL: https://github.com/google-research/hit-gan
- Owner: google-research
- License: apache-2.0
- Created: 2021-12-13T02:35:09.000Z (about 3 years ago)
- Default Branch: main
- Last Pushed: 2024-07-30T21:38:41.000Z (6 months ago)
- Last Synced: 2025-01-14T15:12:58.416Z (14 days ago)
- Topics: generative-adversarial-network, tensorflow, vision-transformer
- Language: Python
- Homepage: https://arxiv.org/abs/2106.07631
- Size: 43.9 KB
- Stars: 92
- Watchers: 4
- Forks: 9
- Open Issues: 5
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
README
# [HiT-GAN](https://arxiv.org/pdf/2106.07631.pdf) Official TensorFlow Implementation
HiT-GAN presents a Transformer-based generator trained within the Generative Adversarial Network (GAN) framework. It achieves state-of-the-art performance for high-resolution image synthesis. Please see our NeurIPS 2021 paper "[Improved Transformer for High-Resolution GANs](https://arxiv.org/pdf/2106.07631.pdf)" for more details.
This implementation is based on TensorFlow 2.x. We use `tf.keras` layers for building the model and use `tf.data` for our input pipeline. The model is trained using a custom training loop with `tf.distribute` on multiple TPUs/GPUs.
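A custom training loop under `tf.distribute` typically splits each global batch evenly across replicas. It also divides every replica's summed loss by the *global* batch size, so the cross-replica sum of gradients matches a plain mean over the whole batch. The framework-free sketch below illustrates that general convention, which mirrors `tf.nn.compute_average_loss`. It is not taken from this repository, and treating `--train_batch_size` as the global batch size is an assumption.

```python
import math


def per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
    """Split a global batch evenly across replicas (devices)."""
    if global_batch_size % num_replicas != 0:
        raise ValueError(
            f"global batch {global_batch_size} is not divisible by "
            f"{num_replicas} replicas")
    return global_batch_size // num_replicas


def replica_loss(per_example_losses, global_batch_size):
    """Like tf.nn.compute_average_loss: sum of local losses / GLOBAL batch size."""
    return sum(per_example_losses) / global_batch_size


def distributed_loss(per_example_losses, num_replicas):
    """Shard one global batch, compute each replica's loss, then sum them.

    Summing the replica losses mirrors the cross-replica all-reduce (SUM)
    that is applied to gradients; both are linear, so the results agree.
    """
    global_batch_size = len(per_example_losses)
    shard = per_replica_batch_size(global_batch_size, num_replicas)
    shards = [per_example_losses[i * shard:(i + 1) * shard]
              for i in range(num_replicas)]
    return sum(replica_loss(s, global_batch_size) for s in shards)


if __name__ == "__main__":
    # --train_batch_size=256 on 8 GPUs -> 32 examples per GPU.
    print(per_replica_batch_size(256, 8))
    losses = [float(i % 5) for i in range(256)]
    # The sharded-and-summed loss equals the plain mean over the global batch.
    print(math.isclose(distributed_loss(losses, 8), sum(losses) / len(losses)))
```

Dividing by the global batch size, rather than the per-replica one, keeps the effective learning rate independent of how many devices share the batch.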
## Environment setup
We recommend training the model with distributed training on TPUs and evaluating it on GPUs. The code is compatible with TensorFlow 2.x. See `requirements.txt` for all prerequisites; you can install them with the following command:
```
pip install -r requirements.txt
```
## ImageNet
Before the first run, download ImageNet by following the `tensorflow_datasets` instructions in the [official guide](https://www.tensorflow.org/datasets/catalog/imagenet2012).
### Train on ImageNet
To train the model on ImageNet with Cloud TPUs, first see the [Google Cloud TPU tutorial](https://cloud.google.com/tpu/docs/tutorials/mnist) for basic information on using Google Cloud TPUs.
Once you have created a virtual machine with Cloud TPUs and pre-downloaded the ImageNet data for [tensorflow_datasets](https://www.tensorflow.org/datasets/catalog/imagenet2012), set the following environment variables:
```
TPU_NAME=
STORAGE_BUCKET=gs://
DATA_DIR=$STORAGE_BUCKET/
MODEL_DIR=$STORAGE_BUCKET/
```
The following command trains a model on ImageNet on TPUv2 4x4 with the default hyperparameters from our paper:
```
python run.py --mode=train --dataset=imagenet2012 \
--train_batch_size=256 --train_steps=1000000 \
--image_crop_size=128 --image_crop_proportion=0.875 \
--save_every_n_steps=2000 \
--latent_dim=256 --generator_lr=0.0001 \
--discriminator_lr=0.0001 --channel_multiplier=1 \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--use_tpu=True --master=$TPU_NAME
```
To train the model on ImageNet with multiple GPUs, try the following command:
```
python run.py --mode=train --dataset=imagenet2012 \
--train_batch_size=256 --train_steps=1000000 \
--image_crop_size=128 --image_crop_proportion=0.875 \
--save_every_n_steps=2000 \
--latent_dim=256 --generator_lr=0.0001 \
--discriminator_lr=0.0001 --channel_multiplier=1 \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--use_tpu=False --use_ema_model=False
```
Please set `train_batch_size` according to the number of GPUs used for training. __Note that storing Exponential Moving Average (EMA) models is currently not supported on GPUs (`--use_ema_model=False`), so training on GPUs leads to a slight performance drop.__
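An EMA model keeps a shadow copy of the generator weights. After every training step the copy is updated as `shadow = decay * shadow + (1 - decay) * weights`, and this smoothed copy is what is typically evaluated when `--use_ema_model=True`. The sketch below represents weights as plain Python lists and uses a hypothetical default decay of 0.999. The README states neither the decay HiT-GAN uses nor how it stores the EMA weights.

```python
def ema_update(ema_weights, weights, decay=0.999):
    """One EMA step, element-wise: decay * ema + (1 - decay) * weights."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


class EMATracker:
    """Shadow copy of model weights, refreshed after every training step."""

    def __init__(self, initial_weights, decay=0.999):  # hypothetical default decay
        self.decay = decay
        self.shadow = list(initial_weights)

    def update(self, weights):
        """Blend the current (live) weights into the shadow copy."""
        self.shadow = ema_update(self.shadow, weights, self.decay)
        return self.shadow


if __name__ == "__main__":
    tracker = EMATracker([0.0], decay=0.5)
    for _ in range(3):
        tracker.update([1.0])  # live weights held constant at 1.0
    print(tracker.shadow)  # [0.875]
```

With a constant weight of 1.0, a zero-initialized shadow, and `decay=0.5`, three updates give 0.5, 0.75 and 0.875, showing how the shadow lags behind the live weights. A decay close to 1 averages over many more steps.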
### Evaluate on ImageNet
Run the following command to evaluate the model on GPUs:
```
python run.py --mode=eval --dataset=imagenet2012 \
--eval_batch_size=128 --train_steps=1000000 \
--image_crop_size=128 --image_crop_proportion=0.875 \
--latent_dim=256 --channel_multiplier=1 \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--use_tpu=False --use_ema_model=True
```
This command evaluates the model on 8 P100 GPUs. Please set `eval_batch_size` according to the number of GPUs used for evaluation. Also note that `train_steps` and `use_ema_model` must match the values used for training.
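One way to keep the evaluation flags consistent with training is to derive them from the training command. The helper below is hypothetical and not part of `run.py`. It copies the flags that appear in both the training and evaluation commands of this README. When the training command does not set `use_ema_model`, it assumes the value was `True`, as in the TPU examples.

```python
import shlex

# Flags that appear in both the training and the evaluation commands in this
# README, so their values should be identical between the two runs.
SHARED_FLAGS = (
    "dataset", "train_steps", "image_crop_size", "image_crop_proportion",
    "latent_dim", "channel_multiplier", "data_dir", "model_dir",
)


def parse_flags(command):
    """Parse the --name=value tokens of a run.py command line into a dict."""
    command = command.replace("\\\n", " ")  # drop shell line continuations
    flags = {}
    for token in shlex.split(command):
        if token.startswith("--") and "=" in token:
            name, value = token[2:].split("=", 1)
            flags[name] = value
    return flags


def eval_command_from_train(train_command, eval_batch_size=128):
    """Hypothetical helper: build an eval command that reuses training values."""
    train_flags = parse_flags(train_command)
    flags = {k: train_flags[k] for k in SHARED_FLAGS if k in train_flags}
    flags["mode"] = "eval"
    flags["eval_batch_size"] = str(eval_batch_size)
    flags["use_tpu"] = "False"  # the README evaluates on GPUs
    # Assumption: EMA was on when training did not set the flag (TPU examples).
    flags["use_ema_model"] = train_flags.get("use_ema_model", "True")
    return ["python", "run.py"] + [f"--{k}={v}" for k, v in flags.items()]


if __name__ == "__main__":
    gpu_train = ("python run.py --mode=train --dataset=imagenet2012 "
                 "--train_batch_size=256 --train_steps=1000000 "
                 "--image_crop_size=128 --image_crop_proportion=0.875 "
                 "--latent_dim=256 --channel_multiplier=1 "
                 "--data_dir=$DATA_DIR --model_dir=$MODEL_DIR "
                 "--use_tpu=False --use_ema_model=False")
    # Plain join (not shlex.join) so $DATA_DIR / $MODEL_DIR still expand in a shell.
    print(" ".join(eval_command_from_train(gpu_train)))
```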
## CelebA-HQ
Before the first run, download CelebA-HQ by following the `tensorflow_datasets` instructions in the [official guide](https://www.tensorflow.org/datasets/catalog/celeb_a_hq).
### Train on CelebA-HQ
The following command trains a model on CelebA-HQ on TPUv2 4x4 with the default hyperparameters our paper uses at resolution 256:
```
python run.py --mode=train --dataset=celeb_a_hq/256 \
--train_batch_size=256 --train_steps=250000 \
--image_crop_size=256 --image_crop_proportion=1.0 \
--save_every_n_steps=1000 \
--latent_dim=512 --generator_lr=0.00005 \
--discriminator_lr=0.00005 --channel_multiplier=2 \
--use_consistency_regularization=True \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--use_tpu=True --master=$TPU_NAME
```
### Evaluate on CelebA-HQ
Run the following command to evaluate the model on 8 P100 GPUs:
```
python run.py --mode=eval --dataset=celeb_a_hq/256 \
--eval_batch_size=128 --train_steps=250000 \
--image_crop_size=256 --image_crop_proportion=1.0 \
--latent_dim=512 --channel_multiplier=2 \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--use_tpu=False --use_ema_model=True
```
## FFHQ
Before the first run, download the FFHQ TFRecords from the [official site](https://github.com/NVlabs/ffhq-dataset) and put them in `$DATA_DIR`.
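NVlabs distributes FFHQ as multi-resolution TFRecords named `ffhq-rXX.tfrecords`, where `XX` is the base-2 logarithm of the resolution. The sketch below maps a resolution to that file name and checks whether the file is present in a local directory. Which file `--dataset=ffhq/256` actually reads is an assumption here. For a `gs://` bucket, `tf.io.gfile.exists` would replace `os.path.exists`.

```python
import math
import os


def ffhq_tfrecord_name(resolution: int) -> str:
    """Return the NVlabs FFHQ TFRecord file name for a power-of-two resolution."""
    log2 = int(math.log2(resolution))
    if 2 ** log2 != resolution or not 4 <= resolution <= 1024:
        raise ValueError(f"FFHQ TFRecords exist only for 4..1024 px, got {resolution}")
    return f"ffhq-r{log2:02d}.tfrecords"


def has_ffhq_tfrecord(data_dir: str, resolution: int) -> bool:
    """Check a *local* directory; for gs:// paths use tf.io.gfile.exists instead."""
    return os.path.exists(os.path.join(data_dir, ffhq_tfrecord_name(resolution)))


if __name__ == "__main__":
    print(ffhq_tfrecord_name(256))  # ffhq-r08.tfrecords
```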
### Train on FFHQ
The following command trains a model on FFHQ on TPUv2 4x4 with the default hyperparameters our paper uses at resolution 256:
```
python run.py --mode=train --dataset=ffhq/256 \
--train_batch_size=256 --train_steps=500000 \
--image_crop_size=256 --image_crop_proportion=1.0 \
--save_every_n_steps=1000 \
--latent_dim=512 --generator_lr=0.00005 \
--discriminator_lr=0.00005 --channel_multiplier=2 \
--use_consistency_regularization=True \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--use_tpu=True --master=$TPU_NAME
```
### Evaluate on FFHQ
Run the following command to evaluate the model on 8 P100 GPUs:
```
python run.py --mode=eval --dataset=ffhq/256 \
--eval_batch_size=128 --train_steps=500000 \
--image_crop_size=256 --image_crop_proportion=1.0 \
--latent_dim=512 --channel_multiplier=2 \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--use_tpu=False --use_ema_model=True
```
## Cite
```
@inproceedings{zhao2021improved,
title = {Improved Transformer for High-Resolution {GANs}},
author = {Long Zhao and Zizhao Zhang and Ting Chen and Dimitris Metaxas and Han Zhang},
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
year = {2021}
}
```
## Disclaimer
This is not an officially supported Google product.