Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/kevinzakka/recurrent-visual-attention
A PyTorch Implementation of "Recurrent Models of Visual Attention"
- Host: GitHub
- URL: https://github.com/kevinzakka/recurrent-visual-attention
- Owner: kevinzakka
- License: MIT
- Archived: true
- Created: 2017-11-22T15:44:04.000Z (almost 7 years ago)
- Default Branch: master
- Last Pushed: 2023-02-24T04:25:37.000Z (over 1 year ago)
- Last Synced: 2024-08-01T16:35:31.351Z (3 months ago)
- Topics: attention, pytorch, ram, recurrent-attention-model, recurrent-models
- Language: Python
- Size: 20.5 MB
- Stars: 468
- Watchers: 14
- Forks: 123
- Open Issues: 18
- Metadata Files:
  - Readme: README.md
  - License: LICENSE
# Recurrent Visual Attention
This is a **PyTorch** implementation of [Recurrent Models of Visual Attention](https://arxiv.org/abs/1406.6247) by *Volodymyr Mnih, Nicolas Heess, Alex Graves and Koray Kavukcuoglu*.
The *Recurrent Attention Model* (RAM) is a neural network that processes inputs sequentially, attending to different locations within the image one at a time, and incrementally combining information from these fixations to build up a dynamic internal representation of the image.
## Model Description
In this paper, the attention problem is modeled as the sequential decision process of a goal-directed agent interacting with a visual environment. The agent is built around a recurrent neural network: at each time step, it processes the sensor data, integrates information over time, and chooses how to act and how to deploy its sensor at the next time step.
- **glimpse sensor**: a retina that extracts a foveated glimpse `phi` around location `l` from an image `x`. It encodes the region around `l` at high resolution but uses progressively lower resolution for pixels farther from `l`, yielding a compressed representation of the original image `x`.
- **glimpse network**: a network that combines the "what" (`phi`) and the "where" (`l`) into a glimpse feature vector `g_t`.
- **core network**: an RNN that maintains an internal state that integrates information extracted from the history of past observations. It encodes the agent's knowledge of the environment through a state vector `h_t` that gets updated at every time step `t`.
- **location network**: uses the internal state `h_t` of the core network to produce the location coordinates `l_t` for the next time step.
- **action network**: after a fixed number of time steps, uses the internal state `h_t` of the core network to produce the final output classification `y`. A minimal code sketch of how these pieces fit together follows this list.
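To make the data flow concrete, here is a minimal, hedged sketch of how these components might fit together in PyTorch. All class names, layer sizes, the `std` value, and the `extract_glimpse` helper are illustrative assumptions, not this repository's actual code:

```python
# Illustrative sketch only: module names, sizes, and the glimpse-extraction
# helper are assumptions; the repository's own implementation may differ.
import torch
import torch.nn as nn
import torch.nn.functional as F

def extract_glimpse(x, l, g_size=8):
    # Glimpse sensor: crop a g_size x g_size patch of `x` centred at `l`
    # (coordinates in [-1, 1]). A full foveated retina would repeat this at
    # progressively coarser scales and resize each crop back to g_size.
    B, C, H, W = x.shape
    theta = torch.zeros(B, 2, 3, device=x.device)
    theta[:, 0, 0] = g_size / W   # horizontal zoom
    theta[:, 1, 1] = g_size / H   # vertical zoom
    theta[:, :, 2] = l            # translation to the fixation point
    grid = F.affine_grid(theta, (B, C, g_size, g_size), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)

class GlimpseNetwork(nn.Module):
    # Fuses the "what" (glimpse phi) with the "where" (location l) into g_t.
    def __init__(self, g_size=8, channels=1, h_g=128, h_l=128):
        super().__init__()
        self.fc_phi = nn.Linear(g_size * g_size * channels, h_g)
        self.fc_l = nn.Linear(2, h_l)

    def forward(self, phi, l):
        what = F.relu(self.fc_phi(phi.flatten(1)))
        where = F.relu(self.fc_l(l))
        return torch.cat([what, where], dim=1)  # glimpse feature vector g_t

class RAM(nn.Module):
    def __init__(self, hidden=256, n_classes=10, std=0.17):
        super().__init__()
        self.glimpse = GlimpseNetwork()
        self.core = nn.RNNCell(256, hidden)             # core network -> h_t
        self.locator = nn.Linear(hidden, 2)             # location network -> l_t
        self.classifier = nn.Linear(hidden, n_classes)  # action network -> y
        self.std = std  # std of the stochastic location policy

    def forward(self, x, n_glimpses=6):
        B = x.size(0)
        h = x.new_zeros(B, 256)
        l = x.new_zeros(B, 2)                 # first fixation: image centre
        for _ in range(n_glimpses):
            phi = extract_glimpse(x, l)       # glimpse sensor
            g = self.glimpse(phi, l)          # glimpse network
            h = self.core(g, h)               # integrate information over time
            mu = torch.tanh(self.locator(h))  # mean of the next location
            l = torch.clamp(mu + self.std * torch.randn_like(mu), -1.0, 1.0)
        return self.classifier(h)             # final classification y
```

Calling `RAM()(torch.randn(32, 1, 28, 28))` would then return a `(32, 10)` tensor of class scores after six fixations.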
## Results
I tackled the `28x28` MNIST task with a RAM using 6 glimpses of size `8x8` and a scale factor of `1`.
| Model | Validation Error (%) | Test Error (%) |
|-------|----------------------|----------------|
| 6 glimpses, 8x8, scale 1 | 1.1 | 1.21 |

I haven't tuned the policy standard deviation with random search, so I expect the test error can be pushed below `1%`. I'll update the table above with results for `60x60` Translated MNIST, `60x60` Cluttered Translated MNIST, and the newer Fashion-MNIST dataset when I get the time.
Finally, here's an animation showing the glimpses extracted by the network on a random batch at epoch 23.
With the Adam optimizer, paper accuracy can be reached in ~160 epochs.
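Because sampling fixation locations is non-differentiable, the paper trains the location network with REINFORCE while the action network gets an ordinary cross-entropy loss. Below is a hedged sketch of that hybrid objective, assuming the final `logits`, per-step location log-probabilities `log_pi`, and learned per-step `baselines` were collected during the forward pass; all tensor names and the commented learning rate are assumptions, not values from this repository:

```python
# Hedged sketch of the paper's hybrid objective; names and the learning
# rate are illustrative assumptions, not this repository's actual code.
import torch
import torch.nn.functional as F

def hybrid_loss(logits, y, log_pi, baselines):
    # Reward: 1 if the final classification is correct, else 0,
    # broadcast to every glimpse step.
    reward = (logits.argmax(dim=1) == y).float()
    reward = reward.unsqueeze(1).expand_as(baselines)
    advantage = reward - baselines.detach()            # variance reduction
    loss_action = F.cross_entropy(logits, y)           # supervised part
    loss_reinforce = -(log_pi * advantage).sum(dim=1).mean()  # REINFORCE
    loss_baseline = F.mse_loss(baselines, reward)      # train the baseline
    return loss_action + loss_reinforce + loss_baseline

# optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # lr is a guess
```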
## Usage
The easiest way to start training your RAM variant is to edit the parameters in `config.py` and run the following command:
```
python main.py
```
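For reference, the parameters above map onto `config.py`-style hyperparameters. The sketch below is a guess at what such a file might expose; only `--resume` and `--is_train` are confirmed by the commands in this README, and every other flag name and default is an assumption:

```python
# Hypothetical config.py sketch; flag names and defaults (other than
# --resume and --is_train, used below) are illustrative assumptions.
import argparse

def str2bool(v):
    # argparse treats any non-empty string as truthy, so parse explicitly.
    return str(v).lower() in ("true", "1", "yes")

def get_config():
    p = argparse.ArgumentParser(description="RAM hyperparameters")
    p.add_argument("--num_glimpses", type=int, default=6)    # fixations per image
    p.add_argument("--patch_size", type=int, default=8)      # 8x8 glimpses
    p.add_argument("--glimpse_scale", type=int, default=1)   # retina scale factor
    p.add_argument("--std", type=float, default=0.17)        # location policy std
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--resume", type=str2bool, default=False)
    p.add_argument("--is_train", type=str2bool, default=True)
    return p.parse_args()
```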
To resume training, run:

```
python main.py --resume=True
```

Finally, to test the checkpoint of your model that achieved the best validation accuracy, run:
```
python main.py --is_train=False
```
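The test command assumes a "best so far" checkpoint was written during training. A minimal sketch of that pattern, with a made-up file path and no claim about the repository's actual checkpoint format:

```python
# Illustrative best-checkpoint pattern; the path and keys are assumptions.
import torch

CKPT = "./ckpt/ram_6_8x8_1_best.pt"  # hypothetical path

def maybe_save_best(model, val_acc, best_acc):
    # Overwrite the checkpoint only when validation accuracy improves.
    if val_acc > best_acc:
        torch.save({"state_dict": model.state_dict(), "val_acc": val_acc}, CKPT)
        return val_acc
    return best_acc

def load_best(model):
    # Restore the weights that scored highest on the validation set.
    ckpt = torch.load(CKPT)
    model.load_state_dict(ckpt["state_dict"])
    return model, ckpt["val_acc"]
```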
## References
- [Torch Blog Post on RAM](http://torch.ch/blog/2015/09/21/rmva.html)