https://github.com/tripplyons/retentive-network
A minimal PyTorch implementation of Retentive Network
- Host: GitHub
- URL: https://github.com/tripplyons/retentive-network
- Owner: tripplyons
- License: mit
- Created: 2023-09-17T22:41:09.000Z (almost 3 years ago)
- Default Branch: main
- Last Pushed: 2023-09-17T23:25:22.000Z (almost 3 years ago)
- Last Synced: 2025-02-19T13:18:11.091Z (over 1 year ago)
- Language: Python
- Size: 4.88 KB
- Stars: 1
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Retentive Network (RetNet)
A minimal PyTorch implementation of [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/abs/2307.08621)
## Notes
- This repository exists mostly for educational purposes, for both me and anyone else who wants to learn about RetNet.
- It is essentially a direct translation of the math in the paper, complex numbers and all. Other implementations claim to avoid complex numbers, but I haven't looked into them.
- It makes heavy use of `torch.einsum`, so make sure you understand it before trying to understand this code.
- I haven't implemented the chunkwise recurrent mode yet; this repo only has the parallel and recurrent modes.
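
To illustrate the `torch.einsum` style and the relationship between the parallel and recurrent modes mentioned above, here is a minimal sketch. It is not the code from this repository: it assumes a single head with a real-valued scalar decay `gamma`, and it leaves out the complex-valued rotation, scaling, and normalization described in the paper. The function names `parallel_retention` and `recurrent_retention` are hypothetical and chosen only for this example.

```python
import torch


def parallel_retention(q, k, v, gamma):
    """Simplified parallel form: (Q K^T * D) V.

    D[n, m] = gamma ** (n - m) for n >= m and 0 otherwise (causal decay mask).
    Shapes: q, k are (batch, seq, d_k); v is (batch, seq, d_v).
    """
    seq_len = q.shape[1]
    idx = torch.arange(seq_len)
    diff = idx[:, None] - idx[None, :]  # n - m for every (n, m) pair
    # Clamp negative exponents before the power, then zero the upper triangle.
    decay = torch.tril(gamma ** diff.clamp(min=0).to(q.dtype))
    # "bnd,bmd->bnm": dot product of every query with every key.
    scores = torch.einsum("bnd,bmd->bnm", q, k) * decay
    # "bnm,bme->bne": weighted sum of the values.
    return torch.einsum("bnm,bme->bne", scores, v)


def recurrent_retention(q, k, v, gamma):
    """Simplified recurrent form: S_n = gamma * S_{n-1} + k_n^T v_n, o_n = q_n S_n."""
    batch, seq_len, d_k = k.shape
    state = torch.zeros(batch, d_k, v.shape[-1], dtype=q.dtype)
    outputs = []
    for n in range(seq_len):
        # "bd,be->bde": outer product of the current key and value.
        state = gamma * state + torch.einsum("bd,be->bde", k[:, n], v[:, n])
        # "bd,bde->be": read the state with the current query.
        outputs.append(torch.einsum("bd,bde->be", q[:, n], state))
    return torch.stack(outputs, dim=1)


if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(2, 6, 4, dtype=torch.float64)
    k = torch.randn(2, 6, 4, dtype=torch.float64)
    v = torch.randn(2, 6, 8, dtype=torch.float64)
    # Both forms should agree up to floating-point error.
    print(torch.allclose(parallel_retention(q, k, v, 0.9),
                         recurrent_retention(q, k, v, 0.9)))
```

The parallel form processes the whole sequence at once, which suits training, while the recurrent form carries a fixed-size state from token to token, which suits step-by-step inference. Under the simplifications above, the two compute the same output.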
## Usage
For more examples, see [test.py](test.py).
```python
import torch
from retnet import RetNet
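
model = RetNet(256, 64, 4, 4)

# Random integer token IDs in [0, 256) with shape (1, 64)
x = torch.randint(0, 256, (1, 64), dtype=torch.long)

# Compute and print the loss for this sequence
print(model.loss(x))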
```