https://github.com/lucidrains/rq-transformer

Implementation of RQ Transformer, proposed in the paper "Autoregressive Image Generation using Residual Quantization"
artificial-intelligence attention-mechanism deep-learning image-generation transformers

## RQ-Transformer

Implementation of RQ-Transformer, which proposes a more efficient way of training multi-dimensional sequences autoregressively. For now, this repository contains only the transformer. For the residual VQ, you can use a separate vector quantization library such as `vector-quantize-pytorch`.

This type of axial autoregressive transformer should be compatible with memcodes, proposed in NWT. It would likely also work well with multi-headed VQ.
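To make the input format concrete, here is a minimal sketch of how residual quantization could produce the `(batch, space, depth)` code tensor that the transformer consumes. It is not the API of any VQ library: `residual_quantize`, the random codebooks, the feature dimension of 32 and the codebook size of 256 are all assumptions made for illustration, and a separate codebook per depth is used only to keep the sketch simple.

```python
import torch

def residual_quantize(x, codebooks):
    """Hypothetical helper: quantize x in `depth` rounds, each round coding the leftover residual.

    x:         (n, dim) feature vectors
    codebooks: (depth, codebook_size, dim), one illustrative codebook per depth
    """
    residual = x
    quantized = torch.zeros_like(x)
    codes = []
    for codebook in codebooks:
        # squared distance from every residual to every codebook entry: (n, codebook_size)
        dists = (residual[:, None, :] - codebook[None, :, :]).pow(2).sum(dim=-1)
        idx = dists.argmin(dim=-1)        # nearest entry per position: (n,)
        chosen = codebook[idx]            # (n, dim)
        quantized = quantized + chosen    # running sum of the chosen entries
        residual = residual - chosen      # what is still unexplained
        codes.append(idx)
    return torch.stack(codes, dim=-1), quantized   # codes: (n, depth)

torch.manual_seed(0)
features = torch.randn(1024, 32)      # 1024 spatial positions, assumed feature dim of 32
codebooks = torch.randn(4, 256, 32)   # depth of 4, assumed codebook size of 256

codes, approx = residual_quantize(features, codebooks)
x = codes.unsqueeze(0)                # (1, 1024, 4), same layout as the usage example
```

Each spatial position ends up with one code per quantization round, which is the depth axis the transformer models autoregressively after the spatial axis.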

## Install

```bash
$ pip install RQ-transformer
```

## Usage

```python
import torch
from rq_transformer import RQTransformer

model = RQTransformer(
    num_tokens = 16000,             # number of tokens, in the paper they had a codebook size of 16k
    dim = 512,                      # transformer model dimension
    max_spatial_seq_len = 1024,     # maximum positions along space
    depth_seq_len = 4,              # number of positions along depth (residual quantizations in paper)
    spatial_layers = 8,             # number of layers for space
    depth_layers = 4,               # number of layers for depth
    dim_head = 64,                  # dimension per head
    heads = 8,                      # number of attention heads
)

x = torch.randint(0, 16000, (1, 1024, 4))

loss = model(x, return_loss = True)
loss.backward()

# then after much training

logits = model(x)

# and sample from the logits accordingly
# or you can use the generate function

sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
```
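If you want to sample from logits yourself, the following is a rough sketch of temperature sampling with top-k filtering. It does not call the model: a random tensor stands in for the logits, with a small vocabulary of 100 to keep it light. `top_k_filter` and `sample_from_logits` are hypothetical helpers, and reading the threshold as "keep the top `1 - thres` fraction of tokens" is an assumption, not a confirmed description of what `filter_thres` does inside `generate`.

```python
import math
import torch

def top_k_filter(logits, thres = 0.9):
    # assumption: keep the top (1 - thres) fraction of the vocabulary, mask the rest to -inf
    num_tokens = logits.shape[-1]
    k = max(1, math.ceil((1 - thres) * num_tokens))
    values, indices = logits.topk(k, dim = -1)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, indices, values)
    return filtered

def sample_from_logits(logits, temperature = 0.9, thres = 0.9):
    filtered = top_k_filter(logits, thres)
    probs = torch.softmax(filtered / temperature, dim = -1)
    flat = probs.reshape(-1, probs.shape[-1])           # multinomial expects a 2D input
    sampled = torch.multinomial(flat, num_samples = 1)  # one token id per position
    return sampled.reshape(logits.shape[:-1])

torch.manual_seed(0)
logits = torch.randn(1, 8, 4, 100)    # stand-in for logits of shape (batch, space, depth, num_tokens)
tokens = sample_from_logits(logits)   # (1, 8, 4)
```

Note that sampling every position at once from teacher-forced logits is not the same as autoregressive generation, where each token is fed back before the next is predicted. For that, `model.generate` is the simpler route.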

I also think there is something deeper going on, and have generalized this to any number of dimensions. You can use it by importing the `HierarchicalCausalTransformer`.

```python
import torch
from rq_transformer import HierarchicalCausalTransformer

model = HierarchicalCausalTransformer(
    num_tokens = 16000,         # number of tokens
    dim = 512,                  # feature dimension
    dim_head = 64,              # dimension of attention heads
    heads = 8,                  # number of attention heads
    depth = (4, 4, 2),          # 3 stages (but can be any number) - transformers of depths 4, 4, 2
    max_seq_len = (16, 4, 5)    # the maximum sequence length of the first stage, then the fixed sequence lengths of all subsequent stages
).cuda()

x = torch.randint(0, 16000, (1, 10, 4, 5)).cuda()

loss = model(x, return_loss = True)
loss.backward()

# after a lot of training

sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 16, 4, 5)
```
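As a rough illustration of the shapes involved, a flat token stream could be folded into the hierarchical layout like this. `pack_tokens` is a hypothetical helper, not part of the library, and treating row-major order as the autoregressive order is an assumption on my part.

```python
import math
import torch

def pack_tokens(flat, inner_shape = (4, 5), max_outer = 16):
    # hypothetical helper: fold (batch, n) tokens into (batch, outer, *inner_shape)
    batch, n = flat.shape
    per_outer = math.prod(inner_shape)    # tokens per first-stage position, 4 * 5 = 20
    if n % per_outer != 0:
        raise ValueError(f'sequence length {n} is not a multiple of {per_outer}')
    outer = n // per_outer
    if outer > max_outer:
        raise ValueError(f'{outer} first-stage positions exceed the maximum of {max_outer}')
    # row-major reshape: the last axis varies fastest
    return flat.reshape(batch, outer, *inner_shape)

flat = torch.randint(0, 16000, (1, 200))   # 200 tokens = 10 * 4 * 5
x = pack_tokens(flat)                      # (1, 10, 4, 5), same shape as the example above
```

With `max_seq_len = (16, 4, 5)`, the longest sequence this layout can hold is 16 * 4 * 5 = 320 tokens.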

## Todo

- [ ] move hierarchical causal transformer to separate repository, seems to be working

## Citations

```bibtex
@misc{lee2022autoregressive,
    author  = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
    year    = {2022},
    month   = {03},
    title   = {Autoregressive Image Generation using Residual Quantization}
}
```

```bibtex
@misc{press2021ALiBi,
    title   = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
    author  = {Ofir Press and Noah A. Smith and Mike Lewis},
    year    = {2021},
    url     = {https://ofir.io/train_short_test_long.pdf}
}
```