https://github.com/lucidrains/FLASH-pytorch
Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"
- Host: GitHub
- URL: https://github.com/lucidrains/FLASH-pytorch
- Owner: lucidrains
- License: mit
- Created: 2022-03-28T19:28:39.000Z (almost 3 years ago)
- Default Branch: main
- Last Pushed: 2023-09-26T00:14:09.000Z (over 1 year ago)
- Last Synced: 2024-11-13T04:03:29.554Z (2 months ago)
- Topics: artificial-intelligence, attention-mechanism, deep-learning, efficient-transformers, transformers
- Language: Python
- Homepage:
- Size: 34.2 MB
- Stars: 347
- Watchers: 9
- Forks: 24
- Open Issues: 7
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- StarryDivineSky - lucidrains/FLASH-pytorch
README
## FLASH - Pytorch
Implementation of the Transformer variant proposed in the paper "Transformer Quality in Linear Time".
## Install
```bash
$ pip install FLASH-pytorch
```

## Usage
The main novel circuit in this paper is the "Gated Attention Unit" (GAU), which the authors claim can replace multi-headed attention while using just one head.
It uses a squared ReLU activation in place of the softmax. Squared ReLU was first seen in the Primer paper, and plain ReLU attention was used in the ReLA Transformer. The gating style seems mostly inspired by gMLP.
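To make the description above concrete, here is a minimal single-sequence sketch of a gated attention unit in NumPy. It is not the `flash_pytorch` implementation. The SiLU activations, the per-dimension scale and offset that derive queries and keys from one shared projection, and the 1/n scaling of the scores are assumptions based on the paper's description; the function names, parameter names, shapes, and initialisation are hypothetical. `laplace_attn` sketches the Laplace attention function that the `laplace_attn_fn` flag refers to, using the constants μ = √(1/2) and σ = √(1/(4π)) attributed to the Mega paper; treat those values as an assumption to check against that paper.

```python
import math
import numpy as np

def silu(x):
    # SiLU (swish): x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def relu_squared(x):
    # squared ReLU, used in place of softmax
    return np.maximum(x, 0.0) ** 2

_erf = np.vectorize(math.erf)  # elementwise erf from the standard library

def laplace_attn(x, mu=math.sqrt(0.5), std=math.sqrt(1.0 / (4.0 * math.pi))):
    # Laplace attention function (assumed constants, see lead-in): 0.5 * (1 + erf((x - mu) / (std * sqrt(2))))
    return 0.5 * (1.0 + _erf((x - mu) / (std * math.sqrt(2.0))))

def init_gau_params(dim, query_key_dim=128, expansion_factor=2, seed=0):
    # hypothetical random initialisation, for illustration only
    rng = np.random.default_rng(seed)
    hidden = dim * expansion_factor
    return {
        "w_u": rng.normal(0.0, dim ** -0.5, (dim, hidden)),
        "w_v": rng.normal(0.0, dim ** -0.5, (dim, hidden)),
        "w_z": rng.normal(0.0, dim ** -0.5, (dim, query_key_dim)),
        "gamma": rng.normal(1.0, 0.02, (2, query_key_dim)),  # per-dim scales for q and k
        "beta": np.zeros((2, query_key_dim)),                # per-dim offsets for q and k
        "w_o": rng.normal(0.0, hidden ** -0.5, (hidden, dim)),
    }

def gau_forward(x, p, causal=True, attn_fn=relu_squared):
    # x: (n, dim) for a single sequence; returns (n, dim)
    n = x.shape[0]
    u = silu(x @ p["w_u"])  # gate branch, (n, hidden)
    v = silu(x @ p["w_v"])  # value branch, (n, hidden)
    z = silu(x @ p["w_z"])  # shared low-dimensional base, (n, query_key_dim)
    q = z * p["gamma"][0] + p["beta"][0]  # cheap per-dim transforms of the same z
    k = z * p["gamma"][1] + p["beta"][1]
    attn = attn_fn((q @ k.T) / n)  # one head, no softmax
    if causal:
        attn = np.tril(attn)  # zero out weights on future positions
    return (u * (attn @ v)) @ p["w_o"]  # gate the attended values, project back to dim

x = np.random.default_rng(1).normal(size=(16, 64))
params = init_gau_params(64, query_key_dim=32)
out = gau_forward(x, params)                                # (16, 64)
out_laplace = gau_forward(x, params, attn_fn=laplace_attn)  # same shape, Laplace weights
```

Swapping `attn_fn` is the only change between squared ReLU and the Laplace function. Because the gate `u` multiplies the attended values elementwise, the paper argues that a single, cheap attention head is enough.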
```python
import torch
from flash_pytorch import GAU

gau = GAU(
    dim = 512,
    query_key_dim = 128,     # query / key dimension
    causal = True,           # autoregressive or not
    expansion_factor = 2,    # hidden dimension = dim * expansion_factor
    laplace_attn_fn = True   # new Mega paper claims this is more stable than relu squared as attention function
)

x = torch.randn(1, 1024, 512)
out = gau(x) # (1, 1024, 512)
```

The authors then combine `GAU` with Katharopoulos linear attention, splitting the sequence into groups to overcome a known issue with autoregressive linear attention: the token-by-token cumulative sum it requires is slow on accelerators.
They named this combination of the quadratic gated attention unit with grouped linear attention FLASH.
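A rough NumPy sketch of the grouped ("mixed chunk") idea follows, for one causal sequence whose length is already a multiple of the group size. Within each group it applies quadratic squared-ReLU attention with a causal mask; across groups it applies linear attention by summarizing each group as `lin_k.T @ v` and taking an exclusive cumulative sum, so each group sees only earlier groups. The cumulative sum therefore runs over n / group_size group summaries instead of n per-token states. In the paper the four query/key tensors come from one shared projection, as in GAU; here they are passed in directly. The 1/group_size and 1/n scalings, the absence of a feature map on the linear branch, and all function names are assumptions of this sketch. `dense_reference` (hypothetical) computes the same output with explicit masks, as a check.

```python
import numpy as np

def relu_squared(x):
    # squared ReLU attention weights, as in GAU
    return np.maximum(x, 0.0) ** 2

def grouped_causal_attention(q, k, lin_q, lin_k, v, group_size):
    # q, k, lin_q, lin_k: (n, d); v: (n, e); n must be a multiple of group_size
    n = q.shape[0]
    assert n % group_size == 0, "pad the sequence to a multiple of group_size first"
    g = n // group_size
    qg, kg, lqg, lkg, vg = (t.reshape(g, group_size, -1) for t in (q, k, lin_q, lin_k, v))

    # 1) quadratic attention inside each group, causal within the group
    scores = np.einsum("gid,gjd->gij", qg, kg) / group_size
    mask = np.tril(np.ones((group_size, group_size)))
    local = np.einsum("gij,gje->gie", relu_squared(scores) * mask, vg)

    # 2) linear attention across groups: summarize each group as lin_k.T @ v
    kv = np.einsum("gjd,gje->gde", lkg, vg) / n
    # exclusive prefix sum over groups: group t sees only groups 0 .. t-1
    kv_prev = np.concatenate([np.zeros_like(kv[:1]), np.cumsum(kv, axis=0)[:-1]], axis=0)
    global_part = np.einsum("gid,gde->gie", lqg, kv_prev)

    return (local + global_part).reshape(n, -1)

def dense_reference(q, k, lin_q, lin_k, v, group_size):
    # same computation with explicit (n, n) masks; O(n^2) memory, for checking only
    n = q.shape[0]
    pos = np.arange(n)
    grp = pos // group_size
    same_group_causal = (grp[:, None] == grp[None, :]) & (pos[:, None] >= pos[None, :])
    earlier_group = grp[:, None] > grp[None, :]
    local = (relu_squared(q @ k.T / group_size) * same_group_causal) @ v
    global_part = ((lin_q @ lin_k.T / n) * earlier_group) @ v
    return local + global_part

rng = np.random.default_rng(0)
q, k, lin_q, lin_k = (rng.normal(size=(12, 4)) for _ in range(4))
v = rng.normal(size=(12, 5))
out = grouped_causal_attention(q, k, lin_q, lin_k, v, group_size=4)
max_diff = np.abs(out - dense_reference(q, k, lin_q, lin_k, v, group_size=4)).max()
```

The grouped version never builds an (n, n) matrix: its largest intermediates are (groups, group_size, group_size) and (groups, d, e), which is where the linear scaling in sequence length comes from.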
You can also use `FLASH` quite easily:
```python
import torch
from flash_pytorch import FLASH

flash = FLASH(
    dim = 512,
    group_size = 256,         # group size
    causal = True,            # autoregressive or not
    query_key_dim = 128,      # query / key dimension
    expansion_factor = 2.,    # hidden dimension = dim * expansion_factor
    laplace_attn_fn = True    # new Mega paper claims this is more stable than relu squared as attention function
)

x = torch.randn(1, 1111, 512)  # sequence will be auto-padded to the nearest multiple of group_size
out = flash(x)                 # (1, 1111, 512)
```

Finally, you can use the full FLASH transformer as described in the paper. It includes all the positional embeddings the paper mentions: the absolute positional embedding is a scaled sinusoidal one, the GAU quadratic attention gets a single-head T5 relative positional bias, and on top of that both the GAU attention and the linear attention are rotary embedded (RoPE).
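The rotary embedding (RoPE) and scaled sinusoidal embedding mentioned above can be sketched in NumPy as follows. This is an illustrative sketch, not the library's code: the base of 10000, the interleaved pairing of feature dimensions in `rotary`, the sin-then-cos layout in `scaled_sinusoidal`, and the constant `scale` standing in for a learned scalar are all assumptions, and the T5 relative bias is not shown. The check at the end shows why RoPE counts as a relative encoding: after rotation, the query-key dot product depends only on the offset between positions.

```python
import numpy as np

def rotary(x, base=10000.0):
    # rotate each interleaved feature pair (2i, 2i+1) of row m by angle m * base**(-2i/d)
    n, d = x.shape
    assert d % 2 == 0, "feature dimension must be even"
    inv_freq = base ** (-np.arange(0, d, 2) / d)        # (d/2,)
    angles = np.arange(n)[:, None] * inv_freq[None, :]  # (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

def scaled_sinusoidal(n, d, scale=1.0, base=10000.0):
    # fixed sin/cos table times one scalar; in a model the scalar would be learned
    inv_freq = base ** (-np.arange(0, d, 2) / d)
    angles = np.arange(n)[:, None] * inv_freq[None, :]
    return scale * np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

# Relative-position check: repeat one query and one key vector at every position.
rng = np.random.default_rng(0)
q = np.tile(rng.normal(size=8), (6, 1))
k = np.tile(rng.normal(size=8), (6, 1))
scores = rotary(q) @ rotary(k).T
# scores[m, p] depends only on m - p, so every diagonal of the matrix is constant
is_relative = np.allclose(scores[1:, 1:], scores[:-1, :-1])
```

Since rotation preserves vector norms, RoPE changes only the angles between queries and keys, never their magnitudes. The `FLASHTransformer` example below applies its own versions of these embeddings for you.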
```python
import torch
from flash_pytorch import FLASHTransformer

model = FLASHTransformer(
    num_tokens = 20000,        # number of tokens
    dim = 512,                 # model dimension
    depth = 12,                # depth
    causal = True,             # autoregressive or not
    group_size = 256,          # size of the groups
    query_key_dim = 128,       # dimension of queries / keys
    expansion_factor = 2.,     # hidden dimension = dim * expansion_factor
    norm_type = 'scalenorm',   # the paper claims scalenorm led to faster training at no performance cost; the other option is 'layernorm' (the default)
    shift_tokens = True        # discovered by independent researcher @BlinkDL in Shenzhen: shifts half of the feature dimensions forward one step along the sequence; greatly improved convergence even further in my local experiments
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x) # (1, 1024, 20000)
```

## Test on Autoregressive Enwik8
```bash
$ python train.py
```

## Citations
```bibtex
@article{Hua2022TransformerQI,
title = {Transformer Quality in Linear Time},
author = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.10447}
}
```

```bibtex
@software{peng_bo_2021_5196578,
author = {PENG Bo},
title = {BlinkDL/RWKV-LM: 0.01},
month = {aug},
year = {2021},
publisher = {Zenodo},
version = {0.01},
doi = {10.5281/zenodo.5196578},
url = {https://doi.org/10.5281/zenodo.5196578}
}
```

```bibtex
@inproceedings{Ma2022MegaMA,
title = {Mega: Moving Average Equipped Gated Attention},
author = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer},
year = {2022}
}
```