https://github.com/emalagoli92/pondernet-tensorflow
TensorFlow 2.X reimplementation of PonderNet: Learning to Ponder, Andrea Banino, Jan Balaguer, Charles Blundell.
- Host: GitHub
- URL: https://github.com/emalagoli92/pondernet-tensorflow
- Owner: EMalagoli92
- License: mit
- Created: 2021-11-02T21:09:39.000Z
- Default Branch: main
- Last Pushed: 2023-01-13T18:20:19.000Z
- Last Synced: 2025-01-15T14:58:15.345Z
- Topics: adaptive-computation-time, artficial-intelligence, deep-learning, pytorch, recurrent-neural-network, tensorflow
- Language: Python
- Homepage:
- Size: 288 KB
- Stars: 3
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE



# PonderNet - TensorFlow
TensorFlow 2.X reimplementation of [PonderNet: Learning to Ponder](https://arxiv.org/abs/2107.05407), Andrea Banino, Jan Balaguer, Charles Blundell.
## Table of contents
- [Abstract](#abstract)
- [Experiment on Parity Task](#experiment-on-parity-task)
- [Installation](#installation)
- [Usage](#usage)
- [Acknowledgement](#acknowledgement)
- [Citations](#citations)
- [License](#license)
## Abstract
In standard neural networks the amount of computation used grows with the size of the inputs, but not with the complexity of the problem being learnt. To overcome this limitation we introduce PonderNet, a new algorithm that learns to adapt the amount of computation based on the complexity of the problem at hand. PonderNet learns end-to-end the number of computational steps to achieve an effective compromise between training prediction accuracy, computational cost and generalization. On a complex synthetic problem, PonderNet dramatically improves performance over previous adaptive computation methods and additionally succeeds at extrapolation tests where traditional neural networks fail. Also, our method matched the current state of the art results on a real world question and answering dataset, but using less compute. Finally, PonderNet reached state of the art results on a complex task designed to test the reasoning capabilities of neural networks.
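
At the core of PonderNet, as described in the paper, is a halting distribution. At each step n the network emits a conditional halting probability λ_n. The unconditional probability of halting at step n is then p_n = λ_n ∏_{j<n}(1 − λ_j). The sketch below computes this in plain Python for illustration; it is not taken from this repository's TensorFlow code, and the function name `halting_distribution` is hypothetical. Forcing the last λ to 1, so that the probabilities over a fixed maximum number of steps sum to one, is an assumption of this sketch.

```python
from typing import List


def halting_distribution(lambdas: List[float]) -> List[float]:
    """Turn per-step conditional halting probabilities lambda_n into the
    unconditional halting distribution p_n = lambda_n * prod_{j<n}(1 - lambda_j).

    Assumption of this sketch: the last lambda is overridden with 1.0 so the
    network always halts by the final step and the returned values sum to 1.
    """
    if not lambdas:
        raise ValueError("need at least one step")
    lambdas = list(lambdas)
    lambdas[-1] = 1.0  # force halting at the maximum step (assumption)
    p = []
    not_halted = 1.0  # probability of not having halted before step n
    for lam in lambdas:
        if not 0.0 <= lam <= 1.0:
            raise ValueError("lambda values must lie in [0, 1]")
        p.append(not_halted * lam)
        not_halted *= 1.0 - lam
    return p


# Usage: a constant lambda of 0.5 over 4 steps.
print(halting_distribution([0.5, 0.5, 0.5, 0.5]))  # → [0.5, 0.25, 0.125, 0.125]
```

In the real model the λ_n values come from a learned layer applied to the hidden state at each step. Here they are passed in directly to keep the example self-contained.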
## Experiment on Parity Task
The input of the parity task is a vector whose elements are 0, 1, or −1. The output is the parity of the number of 1s: one if there is an odd number of 1s and zero otherwise. Each input is generated by setting a random number of the vector's elements to either 1 or −1; the remaining elements are 0.
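
The following is a minimal sketch of how such parity examples could be generated. It assumes a default vector length of 64, an illustrative value not read from this repository. It also assumes the number of non-zero entries is drawn uniformly from 1 to the vector length. The function name `make_parity_example` is hypothetical.

```python
import random
from typing import List, Optional, Tuple


def make_parity_example(
    n_elems: int = 64, rng: Optional[random.Random] = None
) -> Tuple[List[float], int]:
    """Return (x, y) for the parity task.

    x has n_elems entries in {-1, 0, 1}: a random number of positions
    (between 1 and n_elems, an assumption of this sketch) are set to 1 or -1
    and the rest stay 0. y is 1 if x contains an odd number of 1s, else 0.
    """
    rng = rng or random.Random()
    x = [0.0] * n_elems
    n_nonzero = rng.randint(1, n_elems)              # how many entries to fill
    for i in rng.sample(range(n_elems), n_nonzero):  # distinct positions
        x[i] = rng.choice([-1.0, 1.0])
    y = sum(1 for v in x if v == 1.0) % 2            # parity of the 1s
    return x, y


# Usage: one seeded example with a short vector.
x, y = make_parity_example(8, random.Random(0))
print(x, y)
```

The −1 entries add noise without affecting the label, so the model has to count only the 1s.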

Performance on the parity task. (a) Interpolation. Top: accuracy of PonderNet (blue) and ACT (orange). Bottom: number of ponder steps at evaluation time. (b) Extrapolation. Top: accuracy of PonderNet (blue) and ACT (orange). Bottom: number of ponder steps at evaluation time. In (a) and (b), error bars are calculated over 10 random seeds. (c) Total number of compute steps, measured as the number of forward passes actually performed by each network: PonderNet (blue), ACT (green), and an RNN without adaptive computation (orange).
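
In the PonderNet paper, halting at evaluation time is sampled: at each step n the network stops with probability λ_n, up to a maximum number of steps. The sketch below simulates how the number of ponder steps arises under that rule. Here `sample_ponder_steps` and `lambda_fn` are hypothetical names, and `lambda_fn` stands in for the network's learned halting output. Plain Python replaces the repository's TensorFlow code.

```python
import random
from typing import Callable, Optional


def sample_ponder_steps(
    lambda_fn: Callable[[int], float],
    max_steps: int,
    rng: Optional[random.Random] = None,
) -> int:
    """Run steps n = 1..max_steps, halting at step n with probability lambda_fn(n).

    Returns the number of steps taken; the final step always halts.
    """
    rng = rng or random.Random()
    for n in range(1, max_steps + 1):
        # Stop at the cap, or when the Bernoulli(lambda_n) draw says "halt".
        if n == max_steps or rng.random() < lambda_fn(n):
            return n
    raise ValueError("max_steps must be >= 1")


# Always halting (lambda = 1) stops after one step; never halting runs to the cap.
print(sample_ponder_steps(lambda n: 1.0, 20))  # → 1
print(sample_ponder_steps(lambda n: 0.0, 20))  # → 20
```

With a constant λ the step count follows a (truncated) geometric distribution. A network that raises λ_n earlier on easy inputs therefore spends fewer forward passes on them.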
## Installation
Clone the repository and install the required packages:
```
git clone https://github.com/EMalagoli92/PonderNet-TensorFlow.git
cd PonderNet-TensorFlow
pip install -r requirements.txt
```
Tested on *Ubuntu 20.04.4 LTS x86_64*, *Python 3.9.7*.
## Usage
Train a PonderNet on the parity task:
```
python __main__.py
```
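
For reference, the paper trains PonderNet with L = L_Rec + β·L_Reg. The reconstruction term L_Rec = Σ_n p_n·L(y, ŷ_n) is the task loss of each step's prediction weighted by the halting distribution. The regularizer L_Reg = KL(p ‖ p_G(λ_p)) pulls the halting distribution towards a geometric prior. The sketch below writes this objective in plain Python. It is not the code in this repository's `__main__.py`, which may differ in detail (batching, TensorFlow ops, how the prior is truncated). The name `ponder_loss` is hypothetical, and truncating the prior to N steps without renormalizing it is an assumption.

```python
import math
from typing import List


def ponder_loss(
    p: List[float],
    step_losses: List[float],
    lambda_p: float,
    beta: float,
) -> float:
    """PonderNet-style objective: L = L_rec + beta * L_reg.

    p           : halting distribution p_1..p_N
    step_losses : task loss L(y, y_hat_n) of the prediction made at each step
    lambda_p    : parameter of the geometric prior p_G(n) = lambda_p * (1 - lambda_p)**(n - 1)
    beta        : weight of the KL regularizer
    """
    if len(p) != len(step_losses):
        raise ValueError("p and step_losses must have the same length")
    if not 0.0 < lambda_p < 1.0:
        raise ValueError("lambda_p must lie in (0, 1)")
    # Reconstruction term: expected task loss under the halting distribution.
    l_rec = sum(pn * ln for pn, ln in zip(p, step_losses))
    # Geometric prior truncated to N steps (not renormalized: an assumption here).
    prior = [lambda_p * (1.0 - lambda_p) ** k for k in range(len(p))]
    # KL(p || prior); steps with p_n == 0 contribute nothing.
    l_reg = sum(pn * math.log(pn / qn) for pn, qn in zip(p, prior) if pn > 0)
    return l_rec + beta * l_reg


# Usage: 4 steps, unit task loss at each step, no regularization.
print(ponder_loss([0.5, 0.25, 0.125, 0.125], [1.0, 1.0, 1.0, 1.0], lambda_p=0.5, beta=0.0))  # → 1.0
```

The prior's λ_p sets the expected number of steps (1/λ_p for an untruncated geometric). β controls how strongly the learned halting behavior is pushed towards it.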
## Acknowledgement
[PonderNet](https://nn.labml.ai/adaptive_computation/ponder_net/index.html) (annotated PyTorch implementation by labml.ai)
## Citations
```bibtex
@misc{banino2021pondernet,
title={PonderNet: Learning to Ponder},
author={Andrea Banino and Jan Balaguer and Charles Blundell},
year={2021},
eprint={2107.05407},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
## License
This work is made available under the [MIT License](https://github.com/EMalagoli92/PonderNet-TensorFlow/blob/main/LICENSE).