https://github.com/lucidrains/ponder-transformer
Implementation of a Transformer that Ponders, using the scheme from the PonderNet paper
- Host: GitHub
- URL: https://github.com/lucidrains/ponder-transformer
- Owner: lucidrains
- License: mit
- Created: 2021-08-25T12:06:59.000Z (about 3 years ago)
- Default Branch: main
- Last Pushed: 2021-10-30T03:31:24.000Z (about 3 years ago)
- Last Synced: 2024-10-23T11:47:06.051Z (28 days ago)
- Topics: adaptive-computation-time, artificial-intelligence, deep-learning, transformers
- Language: Python
- Homepage:
- Size: 16.6 KB
- Stars: 78
- Watchers: 6
- Forks: 8
- Open Issues: 2
Metadata Files:
- Readme: README.md
- License: LICENSE
README
## Ponder(ing) Transformer
Implementation of a Transformer that learns to adapt the number of computational steps it takes to the difficulty of the input sequence, using the scheme from the PonderNet paper. I will also try to abstract out a pondering module that can be used with any block that returns its output along with a halting probability.
This repository would not have been possible without repeated viewings of Yannic Kilcher's educational video.
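For readers new to PonderNet, here is a rough, framework-free sketch of the scheme as the paper formulates it (not necessarily how this repository implements it). Each pondering step n emits a conditional halting probability λ_n; the probability of halting exactly at step n is p_n = λ_n ∏_{j<n}(1 − λ_j); training minimises the expected per-step task loss under p plus β·KL(p ‖ p_G), where p_G is a geometric prior with parameter λ_p. The function names `halting_distribution`, `geometric_prior` and `ponder_loss` are hypothetical, and so are three choices made here: the final step absorbs any leftover probability, the truncated prior is renormalised, and the default `lambda_p`/`beta` values are illustrative only.

```python
import math
from typing import List, Sequence


def halting_distribution(lambdas: Sequence[float]) -> List[float]:
    """Unconditional halting probabilities p_n = lambda_n * prod_{j<n}(1 - lambda_j).

    Assumption: the last step absorbs whatever probability is left, so the
    result sums to 1 (the final lambda is therefore not used).
    """
    probs: List[float] = []
    still_running = 1.0  # probability of not having halted before this step
    for lam in lambdas[:-1]:
        probs.append(still_running * lam)
        still_running *= 1.0 - lam
    probs.append(still_running)  # forced halt at the final step
    return probs


def geometric_prior(lambda_p: float, n_steps: int) -> List[float]:
    """Geometric prior p_G(n) = lambda_p * (1 - lambda_p)^(n - 1),
    truncated to n_steps and renormalised (an assumption of this sketch)."""
    prior = [lambda_p * (1.0 - lambda_p) ** n for n in range(n_steps)]
    total = sum(prior)
    return [q / total for q in prior]


def ponder_loss(
    step_losses: Sequence[float],
    lambdas: Sequence[float],
    lambda_p: float = 0.2,  # illustrative prior parameter
    beta: float = 0.01,     # illustrative regularisation weight
) -> float:
    """Expected per-step task loss under p, plus beta * KL(p || p_G)."""
    p = halting_distribution(lambdas)
    reconstruction = sum(pn * ln for pn, ln in zip(p, step_losses))
    prior = geometric_prior(lambda_p, len(p))
    # KL divergence; terms with p_n == 0 contribute nothing
    kl = sum(pn * math.log(pn / qn) for pn, qn in zip(p, prior) if pn > 0.0)
    return reconstruction + beta * kl
```

For example, `halting_distribution([0.5, 0.5, 0.5])` gives `[0.5, 0.25, 0.25]`, and with `beta=0.0` the loss for per-step losses `[1.0, 2.0, 4.0]` reduces to the expected loss 2.0. The KL term pulls the halting distribution toward the prior, which is what discourages the model from always pondering for the maximum number of steps.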
## Install
```bash
$ pip install ponder-transformer
```

## Usage
```python
import torch
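from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,   # vocabulary size
    dim = 512,            # model dimension
    max_seq_len = 512     # maximum sequence length
)

mask = torch.ones(1, 512).bool()          # every position is valid
x = torch.randint(0, 20000, (1, 512))     # input token ids
y = torch.randint(0, 20000, (1, 512))     # target token ids

# passing labels makes the forward pass return the training loss
loss = model(x, labels = y, mask = mask)
loss.backward()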
```

Now you can set the model to `.eval()` mode, and it will terminate early once every sample in the batch has emitted a halting signal.
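The batch-level early exit can be sketched without PyTorch. This sketch follows the PonderNet paper's inference procedure, in which each still-running sample halts at a layer by drawing from a Bernoulli with its halting probability; this README does not say whether the repository samples or thresholds. `ponder_until_halt` and its `halt_prob` callback are hypothetical names, not part of `ponder_transformer`.

```python
import random
from typing import Callable, List, Tuple


def ponder_until_halt(
    halt_prob: Callable[[int, int], float],
    batch_size: int,
    num_layers: int,
    rng: random.Random,
) -> Tuple[List[int], int]:
    """Simulate eval-time early exit for a batch.

    halt_prob(layer, sample) is a hypothetical stand-in for the lambda the
    model would emit for `sample` at `layer`. Each still-running sample halts
    at a layer with that probability (a Bernoulli draw). Returns the 0-based
    exit layer per sample and the number of layers actually evaluated.
    """
    exit_layer: List[int] = [-1] * batch_size  # -1 means "not halted yet"
    layers_run = 0
    for layer in range(num_layers):
        layers_run += 1
        for b in range(batch_size):
            if exit_layer[b] == -1 and rng.random() < halt_prob(layer, b):
                exit_layer[b] = layer
        if all(e != -1 for e in exit_layer):
            break  # every sample has halted, so deeper layers are skipped
    # samples that never emitted a halt signal exit at the final layer
    exits = [num_layers - 1 if e == -1 else e for e in exit_layer]
    return exits, layers_run
```

The per-sample exit layers play a role analogous to the `layer_indices` returned in eval mode. Note that the whole batch keeps running until its slowest sample halts, so the compute saved depends on the hardest sequence in the batch.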
```python
import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 512,
    causal = True
)

x = torch.randint(0, 20000, (2, 512))
mask = torch.ones(2, 512).bool()

model.eval() # setting to eval makes it return the logits as well as the halting indices
logits, layer_indices = model(x, mask = mask) # (2, 512, 20000), (2)
# layer indices will contain, for each batch element, which layer they exited
```

## Citations
```bibtex
@misc{banino2021pondernet,
title = {PonderNet: Learning to Ponder},
author = {Andrea Banino and Jan Balaguer and Charles Blundell},
year = {2021},
eprint = {2107.05407},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```