Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/lucidrains/omninet-pytorch
Implementation of OmniNet, Omnidirectional Representations from Transformers, in Pytorch
- Host: GitHub
- URL: https://github.com/lucidrains/omninet-pytorch
- Owner: lucidrains
- License: MIT
- Created: 2021-03-02T06:07:58.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2021-03-19T18:07:30.000Z (over 3 years ago)
- Last Synced: 2024-10-15T00:16:48.881Z (about 1 month ago)
- Topics: artificial-intelligence, attention-mechanism, deep-learning, transformers
- Language: Python
- Homepage:
- Size: 130 KB
- Stars: 53
- Watchers: 4
- Forks: 5
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
## Omninet - Pytorch
Implementation of OmniNet, Omnidirectional Representations from Transformers, in Pytorch. The authors propose attending to all the tokens of all previous layers, leveraging recent advances in efficient attention to make this tractable.
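As a rough illustration of that idea (not the library's actual implementation — the class and parameter names below are hypothetical, and plain full attention is used over the pooled tokens where OmniNet would use an efficient variant), each position in the final layer can attend over the concatenated hidden states of every preceding layer:

```python
import torch
import torch.nn as nn

class OmniAttendSketch(nn.Module):
    """Toy sketch: run a stack of transformer layers, then let the final
    hidden states attend over the tokens of *all* layers, not just the last."""

    def __init__(self, dim, heads = 8, num_layers = 6):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(dim, heads, batch_first = True)
            for _ in range(num_layers)
        ])
        # quadratic full attention here; OmniNet swaps in efficient attention
        self.omni_attn = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, x):
        hiddens = [x]
        for layer in self.layers:
            x = layer(x)
            hiddens.append(x)

        # pool every layer's tokens into one long "omni" sequence
        all_tokens = torch.cat(hiddens, dim = 1)  # (batch, (layers + 1) * seq, dim)

        # each final-layer position attends over every token of every layer
        out, _ = self.omni_attn(x, all_tokens, all_tokens)
        return out

model = OmniAttendSketch(dim = 64, heads = 4, num_layers = 2)
x = torch.randn(1, 16, 64)
out = model(x)
print(out.shape)  # torch.Size([1, 16, 64])
```

The pooled sequence grows linearly with depth, which is why the paper relies on efficient attention (e.g. Performer) rather than the full attention used in this sketch.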
## Install
```bash
$ pip install omninet-pytorch
```

## Usage
```python
import torch
from omninet_pytorch import Omninet

omninet = Omninet(
dim = 512, # model dimension
depth = 6, # depth
dim_head = 64, # dimension per head
heads = 8, # number of heads
pool_layer_tokens_every = 3, # key to this paper - every N layers, omni attend to all tokens of all layers
attn_dropout = 0.1, # attention dropout
ff_dropout = 0.1, # feedforward dropout
feature_redraw_interval = 1000 # how often to redraw the projection matrix for omni attention net - Performer
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

omninet(x, mask = mask) # (1, 1024, 512)
```

For the causal case, just use the class `OmninetCausal`. At the moment, it isn't faithful to the paper (I am using layer axial attention with layer positional embeddings to draw up information across layers), but I will fix this once I rework the linear attention CUDA kernel.
```python
import torch
from omninet_pytorch import OmninetCausal

omninet = OmninetCausal(
dim = 512, # model dimension
depth = 6, # depth
dim_head = 64, # dimension per head
heads = 8, # number of heads
pool_layer_tokens_every = 3, # key to this paper - every N layers, omni attend to all tokens of all layers
attn_dropout = 0.1, # attention dropout
ff_dropout = 0.1 # feedforward dropout
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

omninet(x, mask = mask) # (1, 1024, 512)
```

## Citations
```bibtex
@misc{tay2021omninet,
    title         = {OmniNet: Omnidirectional Representations from Transformers},
    author        = {Yi Tay and Mostafa Dehghani and Vamsi Aribandi and Jai Gupta and Philip Pham and Zhen Qin and Dara Bahri and Da-Cheng Juan and Donald Metzler},
    year          = {2021},
    eprint        = {2103.01075},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}
```