Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/mustafaaljadery/gemma-2B-10M
Gemma 2B with 10M context length using Infini-attention.
- Host: GitHub
- URL: https://github.com/mustafaaljadery/gemma-2B-10M
- Owner: mustafaaljadery
- Created: 2024-05-07T00:52:21.000Z (8 months ago)
- Default Branch: main
- Last Pushed: 2024-05-12T21:16:56.000Z (7 months ago)
- Last Synced: 2024-08-08T21:16:03.892Z (4 months ago)
- Language: Python
- Homepage:
- Size: 14 MB
- Stars: 918
- Watchers: 11
- Forks: 57
- Open Issues: 9
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
- AiTreasureBox: [mustafaaljadery/gemma-2B-10M](https://github.com/mustafaaljadery/gemma-2B-10M) - Gemma 2B with 10M context length using Infini-attention. (Repos)
README
# Gemma 2B - 10M Context
Gemma 2B with recurrent local attention, supporting context lengths of up to 10M tokens. Our implementation uses **<32GB** of memory!
![Graphic of our implementation context](./images/graphic.png)
**Features:**
- 10M sequence length on Gemma 2B.
- Runs on less than 32GB of memory.
- Native inference optimized for CUDA.
- Recurrent local attention for O(N) memory.

## Quick Start
> **Note:** This is a very early checkpoint of the model, trained for only 200 steps. We plan to train on many more tokens!
Install the requirements:
```bash
pip install -r requirements.txt
```

Download the model from Hugging Face: [Huggingface Model](https://huggingface.co/mustafaaljadery/gemma-10M-safetensor).

Then run the model:
```bash
python main.py
```

Change the `main.py` inference code to the specific prompt you desire:
```python
import torch
from transformers import AutoTokenizer

# GemmaForCausalLM and generate are defined in this repo's source (see main.py).
model_path = "./models/gemma-2b-10m"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = GemmaForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
)

prompt_text = "Summarize this harry potter book..."

with torch.no_grad():
    generated_text = generate(
        model, tokenizer, prompt_text, max_length=512, temperature=0.8
    )

print(generated_text)
```

## How does this work?
The largest bottleneck (in terms of memory) for LLMs is attention: the KV cache grows linearly with sequence length, and the attention matrix in vanilla multi-head attention grows quadratically, which limits the sequence length you can handle.
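To make the scale concrete, here is a rough back-of-envelope estimate (illustrative only; the layer count, head dimension, and block size below are assumptions, not values read from this repo's config):

```python
# Rough memory estimate for a 10M-token context (illustrative assumptions).
seq_len = 10_000_000           # 10M tokens
n_layers = 18                  # Gemma-2B-like depth (assumed)
n_kv_heads, head_dim = 1, 256  # multi-query attention (assumed)
bytes_per_val = 2              # bfloat16

# KV cache grows linearly with sequence length.
kv_cache = seq_len * n_layers * n_kv_heads * head_dim * 2 * bytes_per_val
print(f"Full KV cache: {kv_cache / 1e9:.0f} GB")                    # ~184 GB

# A dense attention matrix for a single head grows quadratically.
attn_matrix = seq_len ** 2 * bytes_per_val
print(f"One dense attention matrix: {attn_matrix / 1e12:.0f} TB")   # ~200 TB

# With fixed-size local attention blocks, the working set per step is
# bounded by the block size rather than the full sequence.
block_size = 2048
block_kv = block_size * n_layers * n_kv_heads * head_dim * 2 * bytes_per_val
print(f"KV held for one block: {block_kv / 1e6:.0f} MB")             # ~38 MB
```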
Our approach splits the attention into local attention blocks, as outlined by [InfiniAttention](https://arxiv.org/abs/2404.07143). We then apply recurrence across those local attention blocks to arrive at the final result: 10M-token global attention.
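Below is a minimal, single-head sketch of that idea (a simplified Infini-attention-style recurrence written for illustration; it is not this repo's actual implementation, and the fixed 50/50 mix stands in for the learned gate in the paper):

```python
import torch
import torch.nn.functional as F

def recurrent_local_attention(q, k, v, block_size=2048):
    """Single-head sketch: local attention per block plus a recurrent
    compressive memory that summarizes all earlier blocks, so peak
    memory depends on block_size rather than the full sequence length.

    q, k, v: (seq_len, head_dim) tensors.
    """
    seq_len, d = q.shape
    memory = torch.zeros(d, d)   # compressive memory over past blocks
    norm = torch.zeros(d)        # normalization term for memory readout
    outputs = []

    for start in range(0, seq_len, block_size):
        qb = q[start:start + block_size]
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]

        # 1) Standard softmax attention within the local block.
        scores = (qb @ kb.T) / d ** 0.5
        local_out = torch.softmax(scores, dim=-1) @ vb

        # 2) Linear-attention-style readout from the recurrent memory.
        sigma_q = F.elu(qb) + 1
        mem_out = (sigma_q @ memory) / (sigma_q @ norm + 1e-6).unsqueeze(-1)

        # 3) Mix the local and memory paths (a learned gate in the paper).
        outputs.append(0.5 * local_out + 0.5 * mem_out)

        # 4) Fold this block's keys/values into the memory.
        sigma_k = F.elu(kb) + 1
        memory = memory + sigma_k.T @ vb
        norm = norm + sigma_k.sum(dim=0)

    return torch.cat(outputs, dim=0)
```

Because each step only materializes a `block_size x block_size` attention matrix and a fixed-size memory, the footprint stays roughly constant as the context grows.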
A lot of the inspiration for our ideas comes from the [Transformer-XL](https://arxiv.org/abs/1901.02860) paper.
## More details
For more context about our motivations, implementation details, and the theory behind the work, check out our technical overview on [Medium](https://medium.com/@akshgarg_36829/gemma-10m-technical-overview-900adc4fbeeb).

## Credits
This was built by:
- [Mustafa Aljadery](https://www.maxaljadery.com/)
- [Siddharth Sharma](https://stanford.edu/~sidshr/)
- [Aksh Garg](https://www.linkedin.com/in/aksh-garg/)