https://github.com/lucidrains/deformable-attention
Implementation of Deformable Attention in Pytorch from the paper "Vision Transformer with Deformable Attention"
- Host: GitHub
- URL: https://github.com/lucidrains/deformable-attention
- Owner: lucidrains
- License: MIT
- Created: 2022-03-17T04:35:19.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2025-02-03T21:34:57.000Z (8 months ago)
- Last Synced: 2025-04-14T20:57:48.463Z (6 months ago)
- Topics: artificial-intelligence, attention-mechanism, deep-learning
- Language: Python
- Size: 146 KB
- Stars: 334
- Watchers: 8
- Forks: 33
- Open Issues: 4
Metadata Files:
- Readme: README.md
- License: LICENSE
README
## Deformable Attention
Implementation of Deformable Attention in Pytorch, from the paper "Vision Transformer with Deformable Attention", which appears to be an improvement over the deformable attention proposed in Deformable DETR. The relative positional embedding has also been modified for better extrapolation, using the continuous positional bias proposed in SwinV2.
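To make the mechanism concrete, here is a minimal self-contained sketch of the core idea, in a deliberately simplified 2d setting: a small convolutional network predicts an offset for each position of a downsampled reference grid, the feature map is bilinearly resampled at the deformed locations with `F.grid_sample`, and ordinary multi-head attention attends from every query position to the sampled keys and values. This is a toy illustration only; `SimpleDeformableAttention` and all of its internals below are my own, not this library's implementation (which also handles offset groups and the continuous positional bias, omitted here).

```python
import torch
import torch.nn.functional as F
from torch import nn

class SimpleDeformableAttention(nn.Module):
    # toy sketch, not the library's implementation:
    # predict offsets -> grid_sample -> vanilla multi-head attention
    def __init__(self, dim, heads = 8, downsample_factor = 4, offset_scale = 0.25):
        super().__init__()
        assert dim % heads == 0
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.offset_scale = offset_scale   # max offset, in normalized [-1, 1] coords
        self.to_q = nn.Conv2d(dim, dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, dim * 2, 1, bias = False)
        # small network predicting one (dx, dy) offset per downsampled position
        self.to_offsets = nn.Sequential(
            nn.Conv2d(dim, dim, downsample_factor, stride = downsample_factor),
            nn.GELU(),
            nn.Conv2d(dim, 2, 1),
            nn.Tanh()                      # bound raw offsets to [-1, 1] before scaling
        )
        self.to_out = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.to_q(x)
        # offsets on the downsampled grid: (b, h/r, w/r, 2), in grid_sample's (x, y) order
        offsets = self.to_offsets(q).permute(0, 2, 3, 1) * self.offset_scale
        hr, wr = offsets.shape[1:3]
        ys, xs = torch.meshgrid(
            torch.linspace(-1, 1, hr, device = x.device),
            torch.linspace(-1, 1, wr, device = x.device),
            indexing = 'ij')
        grid = torch.stack((xs, ys), dim = -1)        # reference grid, (hr, wr, 2)
        grid = (grid + offsets).clamp(-1., 1.)        # deformed sampling locations
        # bilinearly sample features at the deformed locations, then project to k, v
        sampled = F.grid_sample(x, grid, mode = 'bilinear', align_corners = False)
        k, v = self.to_kv(sampled).chunk(2, dim = 1)

        # flatten to (b, heads, positions, dim_head) and run standard attention
        def split_heads(t):
            n = t.shape[-2] * t.shape[-1]
            return t.reshape(b, self.heads, -1, n).transpose(-1, -2)

        q, k, v = map(split_heads, (q, k, v))
        attn = (q @ k.transpose(-1, -2) * self.scale).softmax(dim = -1)
        out = (attn @ v).transpose(-1, -2).reshape(b, c, h, w)
        return self.to_out(out)

attn = SimpleDeformableAttention(dim = 512)
x = torch.randn(1, 512, 64, 64)
assert attn(x).shape == (1, 512, 64, 64)
```

The offset network is what makes the attention "deformable": where a plain downsampled attention would always read keys and values from a fixed grid, here the sampling grid shifts per input.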
## Install
```bash
$ pip install deformable-attention
```

## Usage
```python
import torch
from deformable_attention import DeformableAttention

attn = DeformableAttention(
dim = 512, # feature dimensions
dim_head = 64, # dimension per head
heads = 8, # attention heads
dropout = 0., # dropout
downsample_factor = 4, # downsample factor (r in paper)
offset_scale = 4, # scale of offset, maximum offset
offset_groups = None, # number of offset groups, should be multiple of heads
offset_kernel_size = 6, # offset kernel size
)

x = torch.randn(1, 512, 64, 64)
attn(x) # (1, 512, 64, 64)
```
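Since the output shape matches the input, the module can slot straight into an existing convolutional backbone. A sketch of one way to do that (the wrapper class below is my own, not part of the library), using it as a pre-norm residual layer with the same parameters as the example above:

```python
import torch
from torch import nn
from deformable_attention import DeformableAttention

# Hypothetical wrapper, not part of the library: a pre-norm residual block,
# possible because DeformableAttention maps (b, c, h, w) -> (b, c, h, w)
class ResidualDeformableBlock(nn.Module):
    def __init__(self, dim = 512):
        super().__init__()
        self.norm = nn.GroupNorm(1, dim)   # normalizes channels-first feature maps directly
        self.attn = DeformableAttention(
            dim = dim,
            dim_head = 64,
            heads = 8,
            downsample_factor = 4,
            offset_scale = 4,
            offset_kernel_size = 6
        )

    def forward(self, x):
        return x + self.attn(self.norm(x))

block = ResidualDeformableBlock(512)
x = torch.randn(1, 512, 64, 64)
assert block(x).shape == x.shape
```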
3d deformable attention

```python
import torch
from deformable_attention import DeformableAttention3D

attn = DeformableAttention3D(
dim = 512, # feature dimensions
dim_head = 64, # dimension per head
heads = 8, # attention heads
dropout = 0., # dropout
downsample_factor = (2, 8, 8), # downsample factor (r in paper)
offset_scale = (2, 8, 8), # scale of offset, maximum offset
offset_kernel_size = (4, 10, 10), # offset kernel size
)

x = torch.randn(1, 512, 10, 32, 32) # (batch, dimension, frames, height, width)
attn(x) # (1, 512, 10, 32, 32)
```

1d deformable attention for good measure
```python
import torch
from deformable_attention import DeformableAttention1D

attn = DeformableAttention1D(
dim = 128,
downsample_factor = 4,
offset_scale = 2,
offset_kernel_size = 6
)

x = torch.randn(1, 128, 512)
attn(x) # (1, 128, 512)
```
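Note that, as the example shows, the 1d module follows the same channels-first convention as the other variants, taking `(batch, dim, length)`. If the rest of your model passes sequences around as `(batch, length, dim)`, a transpose on the way in and out does the trick; a small hypothetical adapter (my own, not part of the library):

```python
import torch
from torch import nn
from deformable_attention import DeformableAttention1D

# Hypothetical adapter: lets the channels-first (b, d, n) module sit in a
# model that otherwise passes sequences around as (b, n, d)
class SeqDeformableAttention(nn.Module):
    def __init__(self, dim = 128):
        super().__init__()
        self.attn = DeformableAttention1D(
            dim = dim,
            downsample_factor = 4,
            offset_scale = 2,
            offset_kernel_size = 6
        )

    def forward(self, x):          # x: (batch, length, dim)
        x = x.transpose(1, 2)      # -> (batch, dim, length)
        x = self.attn(x)
        return x.transpose(1, 2)   # -> (batch, length, dim)

x = torch.randn(1, 512, 128)       # (batch, length, dim)
assert SeqDeformableAttention(128)(x).shape == x.shape
```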
## Citation

```bibtex
@misc{xia2022vision,
title = {Vision Transformer with Deformable Attention},
author = {Zhuofan Xia and Xuran Pan and Shiji Song and Li Erran Li and Gao Huang},
year = {2022},
eprint = {2201.00520},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

```bibtex
@misc{liu2021swin,
title = {Swin Transformer V2: Scaling Up Capacity and Resolution},
author = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
year = {2021},
eprint = {2111.09883},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```