https://github.com/lucidrains/isab-pytorch
An implementation of (Induced) Set Attention Block, from the Set Transformers paper
- Host: GitHub
- URL: https://github.com/lucidrains/isab-pytorch
- Owner: lucidrains
- License: mit
- Created: 2020-10-26T20:00:39.000Z (over 5 years ago)
- Default Branch: main
- Last Pushed: 2023-01-10T21:02:50.000Z (over 3 years ago)
- Last Synced: 2024-05-02T01:14:21.599Z (about 2 years ago)
- Topics: artificial-intelligence, attention, attention-mechanism, deep-learning
- Language: Python
- Homepage:
- Size: 54.7 KB
- Stars: 53
- Watchers: 6
- Forks: 5
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE

## Induced Set Attention Block (ISAB) - Pytorch
A concise implementation of the (Induced) Set Attention Block from the Set Transformer paper. The block reduces the cost of attention from O(n²) to O(mn), where n is the number of set elements and m is the number of inducing points (learned latents).
Update: Interestingly enough, a new paper has used the ISAB block successfully in the domain of denoising diffusion, for efficient generation of images and video.
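To make the O(mn) claim concrete, here is a minimal sketch of the two-step idea: the m latents first attend to the n set elements, then the set attends back to those m summarized latents, so no n × n attention matrix is ever formed. This is not the code in this repository. `TinyISAB` is a hypothetical name, it is built on `torch.nn.MultiheadAttention`, and it omits the residual, normalization and feedforward parts of the block described in the paper.

```python
import torch
from torch import nn

class TinyISAB(nn.Module):
    """Hypothetical, simplified sketch of induced set attention (not this repo's code)."""

    def __init__(self, dim, heads, num_latents):
        super().__init__()
        # m learned inducing points, shared across the batch
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        # step 1: latents are the queries, the set provides keys and values
        self.latents_from_set = nn.MultiheadAttention(dim, heads, batch_first=True)
        # step 2: the set is the query, the induced latents provide keys and values
        self.set_from_latents = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        batch = x.shape[0]
        latents = self.latents.unsqueeze(0).expand(batch, -1, -1)   # (b, m, d)
        induced, w1 = self.latents_from_set(latents, x, x)          # w1: (b, m, n)
        out, w2 = self.set_from_latents(x, induced, induced)        # w2: (b, n, m)
        return out, induced, w1, w2

x = torch.randn(2, 64, 32)                        # (batch, n, dim)
block = TinyISAB(dim=32, heads=4, num_latents=8)
out, induced, w1, w2 = block(x)
print(out.shape, induced.shape, w1.shape, w2.shape)
```

Both attention matrices have n · m entries per batch element, which is where the O(mn) cost comes from.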
## Install
```bash
$ pip install isab-pytorch
```
## Usage
You can either set the number of latents, in which case the latent parameters will be instantiated within the module and the latents returned on completion of cross attention.
```python
import torch
from isab_pytorch import ISAB

attn = ISAB(
    dim = 512,
    heads = 8,
    num_latents = 128,
    latent_self_attend = True
)

seq = torch.randn(1, 16384, 512) # (batch, seq, dim)
mask = torch.ones((1, 16384)).bool()

out, latents = attn(seq, mask = mask) # (1, 16384, 512), (1, 128, 512)
```
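The mask in the example above is all ones, so nothing is actually masked out. For a padded batch of variable-size sets, a boolean mask of shape (batch, seq) could be built from per-set lengths, as sketched below. The sketch assumes that `True` marks real elements and `False` marks padding, which this README does not state explicitly. The `lengths` tensor is hypothetical.

```python
import torch

# hypothetical number of real (non-padding) elements in each set of the batch
lengths = torch.tensor([5, 3, 1])
max_len = int(lengths.max())

# mask[b, i] is True when position i of set b holds a real element
mask = torch.arange(max_len)[None, :] < lengths[:, None]   # (batch, seq), dtype bool

print(mask)
```

A mask built this way would be passed as `attn(seq, mask = mask)`, with `seq` padded to `max_len` along the sequence dimension.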
Or you can choose not to set the number of latents and pass in the latents yourself (for example, a persistent latent that propagates down the transformer).
```python
import torch
from isab_pytorch import ISAB

attn = ISAB(
    dim = 512,
    heads = 8
)

seq = torch.randn(1, 16384, 512) # (batch, seq, dim)
latents = torch.nn.Parameter(torch.randn(128, 512)) # some memory, passed through multiple ISABs

out, new_latents = attn(seq, latents) # (1, 16384, 512), (1, 128, 512)
```
## Citations
```bibtex
@misc{lee2019set,
    title         = {Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks},
    author        = {Juho Lee and Yoonho Lee and Jungtaek Kim and Adam R. Kosiorek and Seungjin Choi and Yee Whye Teh},
    year          = {2019},
    eprint        = {1810.00825},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}
```
```bibtex
@article{Alayrac2022Flamingo,
    title  = {Flamingo: a Visual Language Model for Few-Shot Learning},
    author = {Jean-Baptiste Alayrac et al},
    year   = {2022}
}
```