Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/lucidrains/memformer
Implementation of Memformer, a Memory-augmented Transformer, in Pytorch
- Host: GitHub
- URL: https://github.com/lucidrains/memformer
- Owner: lucidrains
- License: MIT
- Created: 2020-10-10T04:44:16.000Z (about 4 years ago)
- Default Branch: main
- Last Pushed: 2020-11-13T20:16:41.000Z (about 4 years ago)
- Last Synced: 2024-11-07T23:44:24.240Z (13 days ago)
- Topics: artificial-intelligence, attention-mechanism, memory, transformers
- Language: Python
- Homepage:
- Size: 126 KB
- Stars: 106
- Watchers: 7
- Forks: 10
- Open Issues: 3
- Metadata Files:
  - Readme: README.md
  - License: LICENSE
## Awesome Lists containing this project

## README
## Memformer - Pytorch
Implementation of Memformer, a memory-augmented Transformer, in PyTorch. The model maintains a set of memory slots that are updated with attention and learned efficiently through Memory-Replay Back-Propagation (MRBP) through time.
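One way to picture the memory update is as a cross-attention step in which the memory slots act as queries over a segment's hidden states. Below is a minimal sketch of that idea in plain PyTorch; it is not the library's internal code, and the shapes simply mirror the usage examples further down.

```python
import torch
from torch import nn

# Simplified illustration of a memory-slot update via attention
# (not Memformer's actual implementation): the memory slots attend
# over a segment's hidden states, and the attention output becomes
# the next memory state.

dim, num_memory_slots, seq_len = 512, 128, 1024

memory = torch.randn(1, num_memory_slots, dim)   # current memory slots
hidden = torch.randn(1, seq_len, dim)            # encoder outputs for one segment

attn = nn.MultiheadAttention(embed_dim = dim, num_heads = 8, batch_first = True)

# memory slots as queries, segment hidden states as keys / values
updated_memory, _ = attn(query = memory, key = hidden, value = hidden)
print(updated_memory.shape)  # torch.Size([1, 128, 512])
```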
## Install
```bash
$ pip install memformer
```

## Usage
Full encoder / decoder, as in the paper
```python
import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    num_memory_slots = 128
)

src_seg_1 = torch.randint(0, 256, (1, 1024))
src_seg_2 = torch.randint(0, 256, (1, 1024))
src_seg_3 = torch.randint(0, 256, (1, 1024))

tgt = torch.randint(0, 256, (1, 1024))

enc_out1, mems1, _ = model(src_seg_1) # (1, 1024, 512), (1, 128, 512), _
enc_out2, mems2, _ = model(src_seg_2, mems = mems1)
enc_out3, mems3, loss = model(src_seg_3, tgt, mems = mems2)

loss.backward()
```

Encoder only
```python
import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_heads = 8,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    num_memory_slots = 128,
    num_mem_updates = 2,
    encoder_only = True       # use only the encoder; the output is the encoded sequence
)

src1 = torch.randint(0, 256, (1, 1024))
src2 = torch.randint(0, 256, (1, 1024))

enc1, mems1 = model(src1) # (1, 1024, 512), (1, 128, 512)
enc2, mems2 = model(src2, mems = mems1)
```

Memory Replay Back-Propagation
```python
import torch
from memformer import Memformer, memory_replay_backprop

model = Memformer(
    dim = 512,
    num_memory_slots = 128,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_max_seq_len = 1024
).cuda()

seq = torch.randint(0, 256, (1, 8192)).cuda()
seq_mask = torch.ones_like(seq).bool().cuda()

tgt = torch.randint(0, 256, (1, 512)).cuda()
tgt_mask = torch.ones_like(tgt).bool().cuda()

# will automatically split the source sequence into 8 segments
memory_replay_backprop(
    model,
    src = seq,
    tgt = tgt,
    src_mask = seq_mask,
    tgt_mask = tgt_mask
)
```
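The point of MRBP is that the full 8192-token sequence never has to be back-propagated through in one piece: the memory is first rolled forward across the segments without keeping activations, and each segment is then replayed and back-propagated on its own, handing the memory gradient to the segment before it. The sketch below illustrates only that general pattern; it assumes a hypothetical `step(model, segment, mems)` wrapper and is not the code behind `memory_replay_backprop`.

```python
import torch

# Conceptual sketch of memory-replay back-propagation (MRBP).
# `step(model, segment, mems) -> (next_mems, loss)` is a hypothetical wrapper
# around the model's forward pass that always returns a differentiable loss;
# the real `memory_replay_backprop` helper handles segmenting, masking and
# loss placement itself.

def mrbp_sketch(model, step, segments):
    # forward pass with no graph: keep only the memory entering each segment
    entering_mems = []
    mems = None
    with torch.no_grad():
        for seg in segments:
            entering_mems.append(mems)
            mems, _ = step(model, seg, mems)

    # backward pass: replay segments in reverse, rebuilding each segment's
    # graph locally and passing the memory gradient to the previous segment
    mem_grad = None
    for seg, mems in zip(reversed(segments), reversed(entering_mems)):
        if mems is not None:
            mems = mems.detach().requires_grad_()
        next_mems, loss = step(model, seg, mems)
        tensors, grads = [loss], [torch.ones_like(loss)]
        if mem_grad is not None:
            # inject the gradient flowing back from the following segment
            tensors.append(next_mems)
            grads.append(mem_grad)
        torch.autograd.backward(tensors, grads)
        mem_grad = mems.grad if mems is not None else None
```

In practice you would simply call `memory_replay_backprop` as in the example above; the sketch is only meant to show why the activation memory stays roughly constant in the number of segments.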
## Citations

```bibtex
@inproceedings{
anonymous2021memformer,
title={Memformer: The Memory-Augmented Transformer},
author={Anonymous},
booktitle={Submitted to International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=_adSMszz_g9},
note={under review}
}
```