https://github.com/jaketae/g-mlp
PyTorch implementation of Pay Attention to MLPs
- Host: GitHub
- URL: https://github.com/jaketae/g-mlp
- Owner: jaketae
- License: mit
- Created: 2021-05-18T18:11:11.000Z (over 4 years ago)
- Default Branch: master
- Last Pushed: 2021-06-28T15:37:40.000Z (over 4 years ago)
- Last Synced: 2025-05-13T02:05:25.296Z (5 months ago)
- Topics: attention, image-classification, mlp, natural-language-processing, pytorch
- Language: Python
- Homepage:
- Size: 185 KB
- Stars: 40
- Watchers: 2
- Forks: 6
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# gMLP
PyTorch implementation of [Pay Attention to MLPs](https://arxiv.org/abs/2105.08050).
## Quickstart
Clone this repository.
```
git clone https://github.com/jaketae/g-mlp.git
```

Navigate to the cloned directory. You can use the barebone gMLP model via
```python
>>> from g_mlp import gMLP
>>> model = gMLP()
```

By default, the model comes with the following parameters:
```python
gMLP(
d_model=256,
d_ffn=512,
seq_len=256,
num_layers=6,
)
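
# A forward pass then looks roughly like the following. Assumption (not
# stated in the README): the barebone model maps a (batch_size, seq_len,
# d_model) tensor to an output of the same shape.
#
#   import torch
#   x = torch.randn(8, 256, 256)
#   out = model(x)  # expected: torch.Size([8, 256, 256])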
```

## Usage
The repository also contains gMLP models specifically for language modeling and image classification.
### NLP
`gMLPForLanguageModeling` shares the same default parameters as `gMLP`, with `num_tokens=10000` as an added parameter that represents the size of the token embedding table.
```python
>>> import torch
>>> from g_mlp import gMLPForLanguageModeling
>>> model = gMLPForLanguageModeling()
>>> tokens = torch.randint(0, 10000, (8, 256))  # (batch_size, seq_len) token indices
>>> model(tokens).shape
torch.Size([8, 256, 256])
```

### Computer Vision
`gMLPForImageClassification` is a ViT-esque version of gMLP that includes a patch-creation layer and a final classification head.
```python
>>> import torch
>>> from g_mlp import gMLPForImageClassification
>>> model = gMLPForImageClassification()
>>> images = torch.randn(8, 3, 256, 256)  # (batch_size, channels, height, width)
>>> model(images).shape
torch.Size([8, 1000])
```

## Summary
The authors of the paper present gMLP, an attention-free, all-MLP architecture built around spatial gating units. gMLP achieves parity with transformer models such as ViT and BERT on downstream language and vision tasks. The authors also show that gMLP scales with increased data and model size, suggesting that self-attention is not a necessary component for designing performant models.
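
For intuition, here is a minimal sketch of a spatial gating unit along the lines described in the paper. This is an illustrative reimplementation, not necessarily identical to the code in this repository.

```python
import torch
from torch import nn


class SpatialGatingUnit(nn.Module):
    """Illustrative spatial gating unit: split channels, mix one half
    across the sequence dimension, and use it to gate the other half."""

    def __init__(self, d_ffn, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(d_ffn // 2)
        # Linear projection across token positions (the "spatial" dimension).
        self.spatial_proj = nn.Conv1d(seq_len, seq_len, kernel_size=1)
        # Near-identity initialization, as recommended in the paper.
        nn.init.normal_(self.spatial_proj.weight, std=1e-6)
        nn.init.constant_(self.spatial_proj.bias, 1.0)

    def forward(self, x):
        # x: (batch_size, seq_len, d_ffn)
        u, v = x.chunk(2, dim=-1)            # each: (batch_size, seq_len, d_ffn // 2)
        v = self.spatial_proj(self.norm(v))  # mix information across positions
        return u * v                         # element-wise gating


gate = SpatialGatingUnit(d_ffn=512, seq_len=256)
out = gate(torch.randn(8, 256, 512))  # -> torch.Size([8, 256, 256])
```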
## Resources
- [Original Paper](https://arxiv.org/abs/2105.08050)
- [Phil Wang's implementation](https://github.com/lucidrains/g-mlp-pytorch)