Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/lucidrains/uformer-pytorch
Implementation of Uformer, Attention-based Unet, in Pytorch
- Host: GitHub
- URL: https://github.com/lucidrains/uformer-pytorch
- Owner: lucidrains
- License: mit
- Created: 2021-06-17T00:56:03.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2021-10-26T15:28:20.000Z (almost 3 years ago)
- Last Synced: 2024-09-07T03:56:40.893Z (26 days ago)
- Topics: artificial-intelligence, deep-learning, image-segmentation, transformer, unet
- Language: Python
- Size: 86.9 KB
- Stars: 92
- Watchers: 2
- Forks: 16
- Open Issues: 1
- Metadata Files:
- Readme: README.md
- License: LICENSE
README
## Uformer - Pytorch
Implementation of Uformer, Attention-based Unet, in Pytorch. It will only offer the concat-cross-skip connection.
This repository will be geared towards use in a project for learning protein structures. Specifically, it will include the ability to condition on time steps (needed for DDPM), as well as 2d relative positional encoding using rotary embeddings (instead of the bias on the attention matrix in the paper).
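The paper injects relative position as a learned bias on the attention matrix; this port swaps in 2D rotary embeddings instead. As a rough illustration of the axial rotary idea, here is a minimal numpy sketch: half the feature channels are rotated by angles derived from the row coordinate, the other half by the column coordinate. This is not the repository's implementation; the function names, frequency schedule, and channel split below are all assumptions for illustration.

```python
import numpy as np

def rotate_half(x):
    # treat the last dimension as (dim/2) pairs and rotate each pair by 90 degrees
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = -x2
    out[..., 1::2] = x1
    return out

def axial_rotary(q, height, width):
    # q: (height * width, dim) features for a row-major flattened feature map;
    # dim is assumed divisible by 4 so both axes get an even number of channels
    dim = q.shape[-1]
    half = dim // 2  # first half of channels encodes rows, second half columns
    freqs = 1.0 / (10000 ** (np.arange(0, half, 2) / half))
    row_ang = np.repeat(np.arange(height)[:, None] * freqs, 2, axis=-1)  # (height, half)
    col_ang = np.repeat(np.arange(width)[:, None] * freqs, 2, axis=-1)   # (width, half)
    ang = np.concatenate([
        np.repeat(row_ang, width, axis=0),   # row angle for each flattened position
        np.tile(col_ang, (height, 1)),       # column angle for each flattened position
    ], axis=-1)                              # (height * width, dim)
    # standard rotary form: rotate each channel pair by its position-dependent angle
    return q * np.cos(ang) + rotate_half(q) * np.sin(ang)
```

Because each channel pair undergoes a pure rotation, feature norms are preserved, and dot products between rotated queries and keys depend only on relative offsets along each axis, which is what makes the encoding relative.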
## Install
```bash
$ pip install uformer-pytorch
```

## Usage
```python
import torch
from uformer_pytorch import Uformer

model = Uformer(
    dim = 64,           # initial dimension after input projection; doubles at each stage
stages = 4, # number of stages
num_blocks = 2, # number of transformer blocks per stage
    window_size = 16,   # window size (along one side) within which attention is computed
dim_head = 64,
heads = 8,
ff_mult = 4
)

x = torch.randn(1, 3, 256, 256)
pred = model(x) # (1, 3, 256, 256)
```

To condition on time for DDPM training:
```python
import torch
from uformer_pytorch import Uformer

model = Uformer(
dim = 64,
stages = 4,
num_blocks = 2,
window_size = 16,
dim_head = 64,
heads = 8,
ff_mult = 4,
time_emb = True # set this to true
)

x = torch.randn(1, 3, 256, 256)
time = torch.arange(1)
pred = model(x, time = time) # (1, 3, 256, 256)
```

## Citations
```bibtex
@misc{wang2021uformer,
title = {Uformer: A General U-Shaped Transformer for Image Restoration},
author = {Zhendong Wang and Xiaodong Cun and Jianmin Bao and Jianzhuang Liu},
year = {2021},
eprint = {2106.03106},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```