Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/berlino/gated_linear_attention
Last synced: 2 months ago
- Host: GitHub
- URL: https://github.com/berlino/gated_linear_attention
- Owner: berlino
- License: MIT
- Created: 2023-12-11T18:13:44.000Z (about 1 year ago)
- Default Branch: main
- Last Pushed: 2024-03-09T23:00:36.000Z (10 months ago)
- Last Synced: 2024-10-28T08:41:29.135Z (3 months ago)
- Language: Python
- Size: 32.2 KB
- Stars: 97
- Watchers: 6
- Forks: 2
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- Awesome-state-space-models - ICML2024
README
# Gated Linear Attention Layer
Standalone module of Gated Linear Attention (GLA) from [Gated Linear Attention Transformers with
Hardware-Efficient Training](https://arxiv.org/pdf/2312.06635.pdf).

## Installation

```
pip install -U git+https://github.com/sustcsonglin/flash-linear-attention
```

Warning: ```fused_chunk``` mode needs Triton 2.2 and CUDA 12 (see [issue](https://github.com/berlino/gated_linear_attention/issues/8)). You can run [this test](https://github.com/sustcsonglin/flash-linear-attention/blob/main/tests/test_fused_chunk.py) to quickly check whether ```fused_chunk``` mode works in your environment. If it does not, please refer to [this link](https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/layers/gla.py#L44C1-L45C1) and use ```chunk``` mode instead.
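Below is a minimal sketch that prints the installed Triton and CUDA versions and then constructs the standalone layer in ```chunk``` mode as a fallback. The ```GatedLinearAttention``` import and its ```mode```/```hidden_size```/```num_heads``` arguments are assumptions read off the flash-linear-attention source linked above, not part of this repo's snippet, and may differ between library versions.

```python
import torch
import triton

# fused_chunk needs Triton 2.2 and CUDA 12; print what this environment has.
print("Triton version:", triton.__version__)
print("CUDA version:", torch.version.cuda)

# Fallback sketch: construct the GLA layer in `chunk` mode instead.
# Class name and constructor arguments are assumed from the
# flash-linear-attention source linked above.
from fla.layers import GatedLinearAttention

gla = GatedLinearAttention(mode="chunk", hidden_size=1024, num_heads=4)
gla = gla.cuda().to(torch.bfloat16)  # Triton kernels need a GPU

x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.bfloat16)
out = gla(x)
# Depending on the library version, forward returns the output tensor
# directly or a tuple whose first element is the output.
o = out[0] if isinstance(out, tuple) else out
print(o.shape)  # expected: (2, 128, 1024)
```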
## Usage
Load the checkpoint from Hugging Face.
```python
import torch

from gla_model import GLAForCausalLM

# Load the pretrained GLA checkpoint from the Hugging Face Hub.
model = GLAForCausalLM.from_pretrained("bailin28/gla-1B-100B")
vocab_size = model.config.vocab_size

# Forward pass on a batch of random token ids.
bsz, seq_len = 32, 2048
x = torch.randint(high=vocab_size, size=(bsz, seq_len))
model_output = model(x)
loss = model_output.loss
logits = model_output.logits
```
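Note that in most Hugging Face-style causal-LM heads, ```loss``` is only populated when labels are supplied. If ```model_output.loss``` comes back as ```None```, here is a hedged variant assuming ```GLAForCausalLM``` follows the standard ```labels``` convention (which the repo's snippet does not show):

```python
# Assumption: the model shifts labels internally for next-token prediction,
# as standard Hugging Face causal-LM implementations do.
model_output = model(x, labels=x)
loss = model_output.loss  # cross-entropy over shifted tokens
loss.backward()           # ordinary PyTorch training step from here
```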