https://github.com/kyegomez/gats
Implementation of GATS from the paper "GATS: Gather-Attend-Scatter" in PyTorch and Zeta
- Host: GitHub
- URL: https://github.com/kyegomez/gats
- Owner: kyegomez
- License: MIT
- Created: 2024-01-18T05:31:34.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2024-11-11T09:31:31.000Z (11 months ago)
- Last Synced: 2025-01-31T09:22:36.290Z (8 months ago)
- Topics: ai, attention, attention-is-all-you-need, attention-mechanism, gpt4, llama, ml, multi-modal, multi-modality, multimodal, open-source
- Language: Python
- Homepage: https://discord.gg/GYbXvDGevY
- Size: 2.17 MB
- Stars: 8
- Watchers: 4
- Forks: 0
- Open Issues: 1
- Metadata Files:
  - Readme: README.md
  - Funding: .github/FUNDING.yml
  - License: LICENSE
README
[Discord](https://discord.gg/qUtxnK2NMf)
# GATS
Implementation of GATS from the paper "GATS: Gather-Attend-Scatter" in PyTorch and Zeta.

## Install
`pip install gats-torch`

## Usage
```python
import torch
from gats_torch import GATSBlock

# Create a GATSBlock model with the following configuration:
# - Input dimension: 512
# - Number of attention heads: 8
# - Dimension of each attention head: 64
# - Dropout rate: 0.1
# - Window size: 512
# - Causal attention: True
# - Number of tokens to look backward: 1
# - Number of tokens to look forward: 0
# - Sequence length: 512 * 2
model = GATSBlock(
    dim=512,
    heads=8,
    dim_head=64,
    dropout=0.1,
    window_size=512,
    causal=True,
    look_backward=1,
    look_forward=0,
    seqlen=512 * 2,
)

# Create input tensors for different modalities
text = torch.randn(1, 1024, 512) # Text input tensor
img = torch.randn(1, 3, 224, 224) # Image input tensor
audio = torch.randn(1, 100) # Audio input tensor
video = torch.randn(1, 3, 16, 224, 224) # Video input tensor
mask = torch.ones(1, 2057).bool()  # Mask tensor for attention

# Pass the input tensors through the GATSBlock model
out = model(text, img, audio, video, mask=mask)

# Print the output
print(out)
```
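For intuition, the sketch below illustrates the gather-attend-scatter pattern the name refers to: token embeddings from several modality streams are gathered into one sequence, jointly attended, and then scattered back to their source streams. This is a minimal toy in plain PyTorch, not the repository's implementation; the `GatherAttendScatter` class and its interface here are hypothetical.

```python
import torch
import torch.nn as nn


class GatherAttendScatter(nn.Module):
    """Toy gather-attend-scatter layer (illustrative only)."""

    def __init__(self, dim: int, heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, *streams: torch.Tensor):
        # Gather: concatenate all modality streams into one token sequence
        lengths = [s.shape[1] for s in streams]
        x = torch.cat(streams, dim=1)  # (batch, sum of lengths, dim)

        # Attend: self-attention over the joint sequence, with residual + norm
        attended, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm(x + attended)

        # Scatter: split the attended sequence back into per-modality streams
        return torch.split(x, lengths, dim=1)


# Two pre-tokenized modality streams sharing the same embedding dimension
block = GatherAttendScatter(dim=512)
text_tokens = torch.randn(1, 1024, 512)
image_tokens = torch.randn(1, 196, 512)
text_out, image_out = block(text_tokens, image_tokens)
print(text_out.shape, image_out.shape)  # (1, 1024, 512) (1, 196, 512)
```

Judging by the constructor arguments above (`window_size`, `look_backward`, `look_forward`), the real `GATSBlock` also handles per-modality tokenization of raw image, audio, and video tensors and uses windowed causal attention, which this sketch omits.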
# Citation
```bibtex
@misc{zolna2024gats,
    title={GATS: Gather-Attend-Scatter},
    author={Konrad Zolna and Serkan Cabi and Yutian Chen and Eric Lau and Claudio Fantacci and Jurgis Pasukonis and Jost Tobias Springenberg and Sergio Gomez Colmenarejo},
    year={2024},
    eprint={2401.08525},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}
```

# License
MIT