https://github.com/lucidrains/mlp-gpt-jax
A GPT, made only of MLPs, in Jax
- Host: GitHub
- URL: https://github.com/lucidrains/mlp-gpt-jax
- Owner: lucidrains
- License: mit
- Created: 2021-05-21T19:42:52.000Z (over 4 years ago)
- Default Branch: main
- Last Pushed: 2021-06-23T02:23:18.000Z (over 4 years ago)
- Last Synced: 2025-04-27T06:51:57.698Z (6 months ago)
- Topics: artificial-intelligence, deep-learning, jax, language-model, multilayer-perceptron
- Language: Python
- Homepage:
- Size: 34.1 MB
- Stars: 57
- Watchers: 4
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
## MLP GPT - Jax
A GPT, made only of MLPs, in Jax. The specific MLPs used are gMLPs with Spatial Gating Units.
A working Pytorch implementation is also available.
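For a rough sense of what the Spatial Gating Unit does, here is a minimal, self-contained causal SGU sketched directly in Haiku, following the description in the gMLP paper. It is illustrative only: the `causal_sgu` function, the parameter names, and the initializations are assumptions made for this example, not the implementation shipped in `mlp_gpt_jax`.

```python
import haiku as hk
import jax
import jax.numpy as jnp

def causal_sgu(x, seq_len):
    # illustrative causal Spatial Gating Unit (not this package's implementation)
    res, gate = jnp.split(x, 2, axis = -1)   # split channels into residual half and gate half
    gate = hk.LayerNorm(axis = -1, create_scale = True, create_offset = True)(gate)

    # learned token-mixing matrix, initialised near zero and masked to be causal
    w = hk.get_parameter('spatial_weight', (seq_len, seq_len),
                         init = hk.initializers.RandomNormal(stddev = 1e-3))
    b = hk.get_parameter('spatial_bias', (seq_len,), init = jnp.ones)
    mask = jnp.tril(jnp.ones((seq_len, seq_len)))

    # mix information across positions, then gate the residual half
    gate = jnp.einsum('mn,nd->md', w * mask, gate) + b[:, None]
    return res * gate

# the unit only makes sense inside hk.transform
forward = hk.transform(lambda x: causal_sgu(x, seq_len = 8))
x = jnp.ones((8, 16))                                  # (seq_len, dim), dim must be even
params = forward.init(jax.random.PRNGKey(0), x)
out = forward.apply(params, None, x)                   # (8, 8)
```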
## Install
```bash
$ pip install mlp-gpt-jax
```

## Usage
```python
from jax import random
from haiku import PRNGSequence
from mlp_gpt_jax import TransformedMLPGpt

model = TransformedMLPGpt(
num_tokens = 20000,
dim = 512,
depth = 6,
seq_len = 1024
)

rng = PRNGSequence(0)
seq = random.randint(next(rng), (1024,), 0, 20000)

params = model.init(next(rng), seq)
logits = model.apply(params, next(rng), seq) # (1024, 20000)
```

To use the tiny attention (also made autoregressive with a causal mask), just set `attn_dim` to the head dimension you'd like to use. `64` was the value recommended in the paper.
```python
from jax import random
from haiku import PRNGSequence
from mlp_gpt_jax import TransformedMLPGpt

model = TransformedMLPGpt(
num_tokens = 20000,
dim = 512,
depth = 6,
seq_len = 1024,
attn_dim = 64 # set this to 64
)

rng = PRNGSequence(0)
seq = random.randint(next(rng), (1024,), 0, 20000)

params = model.init(next(rng), seq)
logits = model.apply(params, next(rng), seq) # (1024, 20000)
```
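The examples above only run a forward pass. As a hedged sketch of how the returned logits might be used (the loss wiring, the `1e-3` learning rate, and the plain SGD update below are assumptions for illustration, not part of this repository), a single next-token language-modeling gradient step could look like this:

```python
import jax
import jax.numpy as jnp
from jax import random
from haiku import PRNGSequence

from mlp_gpt_jax import TransformedMLPGpt

model = TransformedMLPGpt(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024
)

rng = PRNGSequence(0)

# one extra token so inputs and next-token targets can be offset by one
seq = random.randint(next(rng), (1025,), 0, 20000)
inputs, targets = seq[:-1], seq[1:]

params = model.init(next(rng), inputs)

def loss_fn(params, key, inputs, targets):
    logits = model.apply(params, key, inputs)           # (1024, 20000)
    log_probs = jax.nn.log_softmax(logits, axis = -1)
    # average next-token cross entropy over the sequence
    return -jnp.mean(jnp.take_along_axis(log_probs, targets[:, None], axis = -1))

loss, grads = jax.value_and_grad(loss_fn)(params, next(rng), inputs, targets)

# plain SGD update, purely for illustration
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
```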
## Citations

```bibtex
@misc{liu2021pay,
title = {Pay Attention to MLPs},
author = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
year = {2021},
eprint = {2105.08050},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```