https://github.com/lucidrains/memory-compressed-attention
Implementation of Memory-Compressed Attention, from the paper "Generating Wikipedia by Summarizing Long Sequences"
artificial-intelligence attention-mechanism deep-learning
Last synced: 10 months ago
- Host: GitHub
- URL: https://github.com/lucidrains/memory-compressed-attention
- Owner: lucidrains
- License: mit
- Created: 2020-07-25T20:23:12.000Z (almost 6 years ago)
- Default Branch: master
- Last Pushed: 2023-04-10T03:39:39.000Z (about 3 years ago)
- Last Synced: 2025-08-25T20:13:12.628Z (10 months ago)
- Topics: artificial-intelligence, attention-mechanism, deep-learning
- Language: Python
- Homepage:
- Size: 48.8 KB
- Stars: 70
- Watchers: 2
- Forks: 12
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README

## Memory Compressed Attention
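A PyTorch implementation of the self-attention layer proposed as Memory-Compressed Attention, in which keys and values are compressed along the sequence dimension (the paper uses a strided convolution) before attention is computed. This repository offers both the causal and non-causal variants, and it handles the required padding automatically when the sequence length is not divisible by the compression ratio.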
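To make the padding and compression step concrete, here is a minimal, dependency-free sketch. It front-pads the sequence with zero vectors up to a multiple of the compression factor and then reduces each block of `factor` vectors to one vector. Two assumptions to flag: the padding side (front) is a choice made for this sketch, and mean pooling stands in for the learned strided convolution used in the paper. The helpers `pad_to_multiple` and `compress` are hypothetical and not part of the package.

```python
from typing import List

Vector = List[float]


def pad_to_multiple(seq: List[Vector], factor: int) -> List[Vector]:
    """Prepend zero vectors so that len(seq) becomes a multiple of `factor`."""
    padding = (-len(seq)) % factor  # 0 when the length is already divisible
    dim = len(seq[0])
    return [[0.0] * dim for _ in range(padding)] + [list(v) for v in seq]


def compress(seq: List[Vector], factor: int) -> List[Vector]:
    """Reduce each non-overlapping block of `factor` vectors to a single vector.

    Mean pooling is a hypothetical stand-in for the learned strided
    convolution from the paper (kernel size and stride of 3 there).
    """
    padded = pad_to_multiple(seq, factor)
    dim = len(padded[0])
    compressed = []
    for start in range(0, len(padded), factor):
        block = padded[start:start + factor]
        compressed.append([sum(v[i] for v in block) / factor for i in range(dim)])
    return compressed


# 1024 positions with factor 3: 2 padding vectors, then 1026 / 3 blocks
seq = [[float(i)] for i in range(1024)]
print(len(compress(seq, 3)))  # 342
```

The compressed length is `ceil(seq_len / factor)`, which is why the padding is needed: without it the last block would be incomplete.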
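This repository also resolves an edge case in the autoregressive setting where the very first query has no keys to attend to: a compressed key may only be attended to once every position it summarizes lies in the past, so the earliest queries can precede every compressed key. The solution is to add null key/value pairs to the compressed set, so that every query always has at least one key to attend to.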
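The edge case can be reproduced with a small visibility table. The sketch below assumes that a compressed key becomes visible to a query only once every position it summarizes is at or before that query, and that the sequence is front-padded to a multiple of the compression factor. `causal_visibility` and `queries_without_keys` are hypothetical helpers, not part of the package; when `use_null_key=True`, column 0 is the null key.

```python
from typing import List


def causal_visibility(seq_len: int, factor: int, use_null_key: bool = True) -> List[List[bool]]:
    """Return rows where rows[i][j] is True if query i may attend to key column j.

    Assumes front padding, so compressed block j covers original
    positions up to (j + 1) * factor - 1 - padding.
    """
    padding = (-seq_len) % factor
    num_blocks = (seq_len + padding) // factor
    rows = []
    for i in range(seq_len):
        row = [True] if use_null_key else []  # the null key is always visible
        for j in range(num_blocks):
            last_position = (j + 1) * factor - 1 - padding
            # a block is visible only if it lies entirely at or before query i
            row.append(last_position <= i)
        rows.append(row)
    return rows


def queries_without_keys(rows: List[List[bool]]) -> int:
    """Count queries whose row is fully masked, leaving softmax nothing valid to normalize."""
    return sum(1 for row in rows if not any(row))


print(queries_without_keys(causal_visibility(6, 3, use_null_key=False)))  # 2
print(queries_without_keys(causal_visibility(6, 3, use_null_key=True)))   # 0
```

With 6 positions and a factor of 3, the first two queries precede the end of the first block, so they see no compressed key at all until the always-visible null key is added.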
## Install
```bash
$ pip install memory_compressed_attention
```
## Usage
```python
import torch
from memory_compressed_attention import MemoryCompressedAttention
attn = MemoryCompressedAttention(
    dim = 512,               # model dimension
    heads = 8,               # number of heads
    causal = False,          # auto-regressive or not
    compression_factor = 3,  # compression ratio
    dropout = 0.1            # dropout post-attention
)
x = torch.randn(1, 1024, 512)       # (batch, seq_len, dim)
mask = torch.ones(1, 1024).bool()   # (batch, seq_len)
out = attn(x, input_mask = mask)    # (1, 1024, 512)
```
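As a rough illustration of what the compression saves, the arithmetic below counts query-key scores per head for the settings in the usage example (1024 positions, `compression_factor = 3`). It assumes a single null key is added on top of the compressed keys; `score_counts` is a hypothetical helper, not part of the package.

```python
import math
from typing import Tuple


def score_counts(seq_len: int, compression_factor: int, null_keys: int = 1) -> Tuple[int, int]:
    """Return (full, compressed) query-key score counts for one attention head."""
    full = seq_len * seq_len
    # every query still attends, but only to ceil(n / factor) compressed keys plus null keys
    compressed_keys = math.ceil(seq_len / compression_factor) + null_keys
    return full, seq_len * compressed_keys


full, compressed = score_counts(1024, 3)
print(full, compressed)              # 1048576 351232
print(round(full / compressed, 2))   # 2.99
```

The saving approaches the compression factor as the sequence grows; the cost is still quadratic in sequence length, only divided by roughly that factor.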
## Citations
```bibtex
@misc{liu2018generating,
title={Generating Wikipedia by Summarizing Long Sequences},
author={Peter J. Liu and Mohammad Saleh and Etienne Pot and Ben Goodrich and Ryan Sepassi and Lukasz Kaiser and Noam Shazeer},
year={2018},
eprint={1801.10198},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```