https://github.com/rish-16/aft-pytorch
Unofficial PyTorch implementation of Attention Free Transformer (AFT) layers by Apple Inc.
Last synced: about 1 year ago
- Host: GitHub
- URL: https://github.com/rish-16/aft-pytorch
- Owner: rish-16
- License: mit
- Created: 2021-06-01T08:42:27.000Z (about 5 years ago)
- Default Branch: main
- Last Pushed: 2022-04-10T05:43:09.000Z (about 4 years ago)
- Last Synced: 2025-05-20T14:45:07.014Z (about 1 year ago)
- Language: Python
- Homepage:
- Size: 83 KB
- Stars: 237
- Watchers: 8
- Forks: 23
- Open Issues: 3
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# aft-pytorch
Unofficial PyTorch implementation of the layers from **An Attention Free Transformer** by [Zhai](https://twitter.com/zhaisf?lang=en) et al. [[abs](https://openreview.net/forum?id=pW--cu2FCHY), [pdf](https://arxiv.org/pdf/2105.14103.pdf)] from Apple Inc.
> I'd like to thank the primary author, Dr. Shuangfei Zhai, for his informal guidance and feedback as I built this package!

## Installation
You can install `aft-pytorch` via `pip`:
```bash
pip install aft-pytorch
```
## Usage
You can import the **AFT-Full**, **AFT-Simple**, or **AFT-Local** layer (as described in the paper) from the package like so:
### `AFTFull`
```python
import torch
from aft_pytorch import AFTFull

layer = AFTFull(
    max_seqlen=20,
    dim=512,
    hidden_dim=64
)
# a batch of 32 sequences, each with 10 timesteps of dimension 512
x = torch.rand(32, 10, 512)
y = layer(x) # [32, 10, 512]
```
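For intuition, here is a minimal NumPy sketch of the AFT-Full operation for a single sequence, written from the paper's definition rather than from this package's source. It assumes the queries, keys and values have already been projected to width `d` (presumably the role of `hidden_dim`) and takes the learned pairwise position biases `w` (a `T × T` matrix) as a plain array. The names `aft_full` and `sigmoid` are illustrative only.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def aft_full(q, k, v, w):
    """AFT-Full for a single sequence (sketch based on the paper's formula).

    Y_t = sigmoid(Q_t) * sum_t' exp(K_t' + w[t, t']) * V_t' / sum_t' exp(K_t' + w[t, t'])

    q, k, v: arrays of shape (T, d), already projected queries, keys, values
    w:       array of shape (T, T), learned pairwise position biases
    returns: array of shape (T, d)
    """
    # logits[t, t', j] = K[t', j] + w[t, t']  -> shape (T, T, d)
    logits = k[None, :, :] + w[:, :, None]
    # Subtracting the max over t' leaves the ratio unchanged but avoids overflow in exp
    logits = logits - logits.max(axis=1, keepdims=True)
    weights = np.exp(logits)
    numerator = (weights * v[None, :, :]).sum(axis=1)   # (T, d)
    denominator = weights.sum(axis=1)                   # (T, d)
    return sigmoid(q) * numerator / denominator


rng = np.random.default_rng(0)
T, d = 10, 64
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
w = rng.standard_normal((T, T))
print(aft_full(q, k, v, w).shape)  # (10, 64)
```

Since exp(K_t' + w_tt') = exp(w_tt') · exp(K_t'), the same result can also be computed with two matrix products, `exp(w) @ (exp(K) * V)` and `exp(w) @ exp(K)`. The broadcasted form above is easier to read but materialises a `(T, T, d)` array.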
### `AFTSimple`
```python
import torch
from aft_pytorch import AFTSimple

layer = AFTSimple(
    max_seqlen=20,
    dim=512,
    hidden_dim=64
)
# a batch of 32 sequences, each with 10 timesteps of dimension 512
x = torch.rand(32, 10, 512)
y = layer(x) # [32, 10, 512]
```
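AFT-Simple is the special case of AFT-Full in which every position bias `w[t, t']` is zero. The softmax over keys then no longer depends on `t`, so a single context vector can be shared by all timesteps and the cost grows linearly in `T`. A minimal NumPy sketch of that case for one sequence, again following the paper rather than this package's code (the function names are illustrative):

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def aft_simple(q, k, v):
    """AFT-Simple for a single sequence: AFT-Full with every w[t, t'] = 0.

    q, k, v: arrays of shape (T, d)
    returns: array of shape (T, d)
    """
    # Softmax of the keys over the time axis, separately for each feature
    k_exp = np.exp(k - k.max(axis=0, keepdims=True))
    k_soft = k_exp / k_exp.sum(axis=0, keepdims=True)      # (T, d)
    # One context vector shared by every timestep, so the cost is O(T * d)
    context = (k_soft * v).sum(axis=0, keepdims=True)      # (1, d)
    return sigmoid(q) * context                            # broadcasts to (T, d)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 64)) for _ in range(3))
print(aft_simple(q, k, v).shape)  # (10, 64)
```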
### `AFTLocal`
```python
import torch
from aft_pytorch import AFTLocal

layer = AFTLocal(
    max_seqlen=20,
    dim=512,
    hidden_dim=64
)
# a batch of 32 sequences, each with 10 timesteps of dimension 512
x = torch.rand(32, 10, 512)
y = layer(x) # [32, 10, 512]
```
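AFT-Local keeps the learned biases `w[t, t']` only inside a local window, `|t - t'| < s`, and sets them to 0 elsewhere. Because positions outside the window get a bias of 0 rather than −∞, they still contribute, so every output keeps a global receptive field. The NumPy sketch below illustrates that masking for one sequence. The window-size argument `s` follows the paper's notation and is not necessarily how `AFTLocal` exposes it (check the package source); `local_bias` and `aft_local` are hypothetical helpers.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def local_bias(w, s):
    """Keep w[t, t'] where |t - t'| < s and set it to 0 everywhere else."""
    idx = np.arange(w.shape[0])
    inside = np.abs(idx[:, None] - idx[None, :]) < s
    return np.where(inside, w, 0.0)


def aft_local(q, k, v, w, s):
    """AFT-Local for a single sequence: AFT-Full with windowed position biases.

    Positions outside the window still contribute (their bias is 0, not -inf).
    q, k, v: arrays of shape (T, d); w: array of shape (T, T); s: window size
    """
    logits = k[None, :, :] + local_bias(w, s)[:, :, None]   # (T, T, d)
    logits = logits - logits.max(axis=1, keepdims=True)     # numerical stability
    weights = np.exp(logits)
    return sigmoid(q) * (weights * v[None, :, :]).sum(axis=1) / weights.sum(axis=1)


rng = np.random.default_rng(0)
T, d = 10, 64
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
w = rng.standard_normal((T, T))
print(aft_local(q, k, v, w, s=3).shape)  # (10, 64)
print(local_bias(np.ones((4, 4)), s=2))
```

Two limits make useful sanity checks: with `s >= T` nothing is masked and the result matches AFT-Full, and with `s = 0` every bias is zeroed and it matches AFT-Simple.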
> These layers are 'plug-and-play' with your existing networks and Transformers: they take and return tensors of shape `[batch, seq_len, dim]`, so you can swap out a self-attention layer for any layer in this package with minimal changes.
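As a sketch of what that swap can look like, the PyTorch example below defines a generic pre-norm Transformer block whose token-mixing module is either a `nn.MultiheadAttention` wrapper or an AFT-style layer; the block only relies on the `[batch, seq_len, dim]` in/out shape. `TinyAFTSimple`, `SelfAttention` and `Block` are hypothetical names, and `TinyAFTSimple` is a small stand-in written from the paper's AFT-Simple formula so the example runs without the package. It is not this package's implementation.

```python
import torch
import torch.nn as nn


class TinyAFTSimple(nn.Module):
    """Hypothetical AFT-Simple stand-in: maps [B, T, dim] -> [B, T, dim]."""

    def __init__(self, dim, hidden_dim=64):
        super().__init__()
        self.to_qkv = nn.Linear(dim, 3 * hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)          # each [B, T, hidden_dim]
        weights = torch.softmax(k, dim=1)                   # softmax over time, per feature
        context = (weights * v).sum(dim=1, keepdim=True)    # [B, 1, hidden_dim]
        return self.project(torch.sigmoid(q) * context)     # [B, T, dim]


class SelfAttention(nn.Module):
    """Wraps nn.MultiheadAttention so it exposes the same x -> y interface."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x, need_weights=False)
        return out


class Block(nn.Module):
    """Pre-norm Transformer block with a pluggable [B, T, dim] -> [B, T, dim] mixer."""

    def __init__(self, dim, mixer):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.mixer = mixer
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        x = x + self.mixer(self.norm1(x))
        return x + self.mlp(self.norm2(x))


x = torch.rand(32, 10, 512)
attn_block = Block(512, SelfAttention(512))
aft_block = Block(512, TinyAFTSimple(512, hidden_dim=64))
print(attn_block(x).shape, aft_block(x).shape)
```

With `aft-pytorch` installed, you would pass one of its layers, e.g. `AFTSimple(max_seqlen=20, dim=512, hidden_dim=64)`, as `mixer` instead; nothing else in the block changes.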
## TODO
- [ ] Add full AFT architecture
- [ ] Add variants like `AFTConv`
- [ ] Benchmark using Karpathy's [minGPT](https://github.com/karpathy/minGPT)
## Contributing
If you like this repo, please leave a star! If you have any fixes or suggestions, feel free to open a PR or issue.
## Credits
```bibtex
@misc{attention-free-transformer,
title = {An Attention Free Transformer},
author = {Shuangfei Zhai and Walter Talbott and Nitish Srivastava and Chen Huang and Hanlin Goh and Ruixiang Zhang and Josh Susskind},
year = {2021},
URL = {https://arxiv.org/pdf/2105.14103.pdf}
}
```
## License
[MIT](https://github.com/rish-16/aft-pytorch/blob/main/LICENSE)