https://github.com/ericjang/draw
TensorFlow Implementation of "DRAW: A Recurrent Neural Network For Image Generation"
- Host: GitHub
- URL: https://github.com/ericjang/draw
- Owner: ericjang
- License: apache-2.0
- Created: 2016-02-22T16:31:06.000Z (over 10 years ago)
- Default Branch: master
- Last Pushed: 2018-08-07T08:13:15.000Z (almost 8 years ago)
- Last Synced: 2025-03-29T08:06:09.199Z (over 1 year ago)
- Language: Python
- Size: 1.61 MB
- Stars: 527
- Watchers: 22
- Forks: 155
- Open Issues: 8
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# draw
TensorFlow implementation of [DRAW: A Recurrent Neural Network For Image Generation](http://arxiv.org/pdf/1502.04623.pdf) on the MNIST generation task.
| With Attention | Without Attention |
| ------------- | ------------- |
| (image not preserved) | (image not preserved) |
Although open-source implementations of this paper already exist (see links below), this implementation focuses on simplicity and ease of understanding. I tried to make the code resemble the raw equations as closely as possible.
For a gentle walkthrough of the paper and implementation, see the writeup here: [http://blog.evjang.com/2016/06/understanding-and-implementing.html](http://blog.evjang.com/2016/06/understanding-and-implementing.html).
## Usage
`python draw.py --data_dir=/tmp/draw` downloads the binarized MNIST dataset to `/tmp/draw/mnist` and trains the DRAW model with attention enabled for both reading and writing. After training, output data is written to `/tmp/draw/draw_data.npy`.
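You can visualize the results by running the script `python plot_data.py <prefix> <output_data>`, where `<prefix>` is a name prefix for the generated images and `<output_data>` is the `.npy` file written by training.
For example,
`python plot_data.py myattn /tmp/draw/draw_data.npy`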
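If you want to inspect the output file yourself rather than through `plot_data.py`, the following is a minimal NumPy sketch. It rests on assumptions not stated in this README: that `draw_data.npy` holds a pickled list whose first element is an array of canvases shaped `(T, batch, 784)`, and that the canvas values are pre-sigmoid. The helper names `load_canvases` and `tile_canvases` are hypothetical and do not come from the repository.

```python
import numpy as np


def sigmoid(x):
    """Map raw canvas values into (0, 1) pixel intensities."""
    return 1.0 / (1.0 + np.exp(-x))


def load_canvases(path):
    """Load the canvases from a training output file.

    Assumption: the file stores a pickled list [canvases, Lxs, Lzs],
    so allow_pickle=True is required on recent NumPy versions.
    """
    canvases, _, _ = np.load(path, allow_pickle=True)
    return np.asarray(canvases)


def tile_canvases(canvases_t, side=28, grid=10):
    """Arrange the first grid*grid flattened images into one 2-D mosaic.

    canvases_t: array of shape (batch, side*side) for a single time step.
    Returns an array of shape (grid*side, grid*side).
    """
    images = sigmoid(canvases_t[: grid * grid]).reshape(grid, grid, side, side)
    # (row, col, y, x) -> (row, y, col, x), then merge into a single image
    return images.transpose(0, 2, 1, 3).reshape(grid * side, grid * side)


if __name__ == "__main__":
    # Synthetic stand-in for real data: 5 time steps, 100 images of 28x28.
    fake = np.random.randn(5, 100, 28 * 28)
    mosaic = tile_canvases(fake[-1])  # final canvas of each sample
    print(mosaic.shape)
```

With a real file, `tile_canvases(load_canvases("/tmp/draw/draw_data.npy")[t])` would give the mosaic at time step `t`, provided the layout assumption above holds. The mosaic can then be shown with any image viewer, for example `matplotlib.pyplot.imshow(mosaic, cmap="gray")`.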
To run training without attention, do:
`python draw.py --data_dir=/tmp/draw --read_attn=False --write_attn=False`
## Restoring from Pre-trained Model
Instead of training from scratch, you can load pre-trained weights by uncommenting the following line in `draw.py` and editing the path to your checkpoint file as needed. Save electricity!
```python
saver.restore(sess, "/tmp/draw/drawmodel.ckpt")
```
This git repository contains the following pre-trained files in the `data/` folder:
| Filename | Description |
| ------------- | ------------- |
| draw_data_attn.npy | Training outputs for DRAW with attention |
| drawmodel_attn.ckpt | Saved weights for DRAW with attention |
| draw_data_noattn.npy | Training outputs for DRAW without attention |
| drawmodel_noattn.ckpt | Saved weights for DRAW without attention |
These were trained for 10000 iterations with a minibatch size of 100 on a GTX 970 GPU.
## Useful Resources
- https://github.com/vivanov879/draw
- https://github.com/jbornschein/draw
- https://github.com/ikostrikov/TensorFlow-VAE-GAN-DRAW (wish I had found this earlier)
- [Video Lecture on Variational Autoencoders and Image Generation](https://www.youtube.com/watch?v=P78QYjWh5sM&list=PLE6Wd9FR--EfW8dtjAuPoTuPcqmOV53Fu&index=3)