Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/lucidrains/point-transformer-pytorch

Implementation of the Point Transformer layer, in Pytorch
- Host: GitHub
- URL: https://github.com/lucidrains/point-transformer-pytorch
- Owner: lucidrains
- License: MIT
- Created: 2020-12-18T18:27:00.000Z (about 4 years ago)
- Default Branch: main
- Last Pushed: 2022-02-12T15:40:46.000Z (almost 3 years ago)
- Last Synced: 2025-01-15T11:38:34.059Z (12 days ago)
- Topics: artificial-intelligence, attention-mechanism, deep-learning, point-cloud
- Language: Python
- Homepage:
- Size: 43.9 KB
- Stars: 595
- Watchers: 15
- Forks: 58
- Open Issues: 11
- Metadata Files:
  - Readme: README.md
  - License: LICENSE
README
## Point Transformer - Pytorch
Implementation of the Point Transformer self-attention layer, in Pytorch. This simple vector self-attention circuit appears to have allowed their group to outperform all previous methods in point cloud classification and segmentation.
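For context, the layer implements the paper's vector self-attention, which (in the paper's notation, as I read Eqs. 3-4 of Zhao et al.) computes, for each point $x_i$ with coordinates $p_i$:

```latex
y_i = \sum_{x_j \in \mathcal{X}(i)}
      \rho\big(\gamma(\varphi(x_i) - \psi(x_j) + \delta)\big)
      \odot \big(\alpha(x_j) + \delta\big),
\qquad
\delta = \theta(p_i - p_j)
```

Here $\varphi, \psi, \alpha$ are pointwise linear projections, $\gamma$ and $\theta$ are MLPs, $\rho$ is a softmax normalization, and $\mathcal{X}(i)$ is the set of points $x_i$ attends to. Mapping this onto the constructor arguments below, `pos_mlp_hidden_dim` plausibly sizes the positional MLP $\theta$ and `attn_mlp_hidden_mult` the attention MLP $\gamma$, though that correspondence is my reading of the parameter names, not something the README states.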
## Install
```bash
$ pip install point-transformer-pytorch
```

## Usage
```python
import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4
)

feats = torch.randn(1, 16, 128)
pos = torch.randn(1, 16, 3)
mask = torch.ones(1, 16).bool()

attn(feats, pos, mask = mask) # (1, 16, 128)
```

This type of vector attention is much more expensive than traditional scalar attention. In the paper, they used k-nearest neighbors on the points to exclude attention on faraway points. You can do the same with a single extra setting.
```python
import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
    num_neighbors = 16 # only the 16 nearest neighbors are attended to for each point
)

feats = torch.randn(1, 2048, 128)
pos = torch.randn(1, 2048, 3)
mask = torch.ones(1, 2048).bool()

attn(feats, pos, mask = mask) # (1, 2048, 128)
```
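The boolean mask appears to mark which points are real (True) versus padding, given the `torch.ones` usage above; if so, it also lets you batch clouds of different sizes by padding them to a common length. A minimal sketch under that assumption (the padded-batch setup is illustrative, not from the original README):

```python
import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
    num_neighbors = 16
)

# two clouds padded to a common length of 2048 points
feats = torch.randn(2, 2048, 128)
pos = torch.randn(2, 2048, 3)

mask = torch.zeros(2, 2048).bool()
mask[0, :1500] = True # assumed semantics: True marks a real point, False marks padding
mask[1, :] = True     # the second cloud uses all 2048 points

out = attn(feats, pos, mask = mask) # (2, 2048, 128)
```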
## Citations

```bibtex
@misc{zhao2020point,
    title={Point Transformer},
    author={Hengshuang Zhao and Li Jiang and Jiaya Jia and Philip Torr and Vladlen Koltun},
    year={2020},
    eprint={2012.09164},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```