https://github.com/lucidrains/st-moe-pytorch
Implementation of ST-MoE, the latest incarnation of MoE after years of research at Brain, in Pytorch
artificial-intelligence conditional-computation deep-learning mixture-of-experts
Last synced: about 1 year ago
- Host: GitHub
- URL: https://github.com/lucidrains/st-moe-pytorch
- Owner: lucidrains
- License: mit
- Created: 2023-03-26T15:00:16.000Z (about 3 years ago)
- Default Branch: main
- Last Pushed: 2024-06-17T00:48:47.000Z (almost 2 years ago)
- Last Synced: 2025-03-28T19:05:39.378Z (about 1 year ago)
- Topics: artificial-intelligence, conditional-computation, deep-learning, mixture-of-experts
- Language: Python
- Homepage:
- Size: 178 KB
- Stars: 323
- Watchers: 5
- Forks: 28
- Open Issues: 4
Metadata Files:
- Readme: README.md
- License: LICENSE
README

## ST-MoE - Pytorch
Implementation of ST-MoE, the latest incarnation of mixture of experts after years of research at Google Brain, in PyTorch. It will largely be a transcription of the official Mesh TensorFlow implementation. If there are any papers you think should be added while I have my attention on mixture of experts, please open an issue.
This should be SOTA for mixture-of-experts in autoregressive transformers. It is rumored that GPT-4 uses 16 experts with top-2 gating.
For non-autoregressive models, I would recommend going with the simpler and better Soft MoE.
## Install
```bash
$ pip install st-moe-pytorch
```
## Appreciation
- StabilityAI for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
- Aran Komatsuzaki for consultation on mixture-of-experts, for removal of 2-level MoE and simplifications to code
## Usage
```python
import torch
from st_moe_pytorch import MoE

moe = MoE(
    dim = 512,
    num_experts = 16,               # increase the experts (# parameters) of your model without increasing computation
    gating_top_n = 2,               # default to top 2 gating, but can also be more (3 was tested in the paper with a lower threshold)
    threshold_train = 0.2,          # at what threshold to accept a token to be routed to second expert and beyond - 0.2 was optimal for 2 expert routing, and apparently should be lower for 3
    threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >=1
    balance_loss_coef = 1e-2,       # multiplier on the auxiliary expert balancing loss
    router_z_loss_coef = 1e-3,      # loss weight for router z-loss
)

inputs = torch.randn(4, 1024, 512)
out, total_aux_loss, balance_loss, router_z_loss = moe(inputs) # (4, 1024, 512), (1,), (1,), (1,)

# for the entire mixture of experts block, in context of transformer

from st_moe_pytorch import SparseMoEBlock

moe_block = SparseMoEBlock(
    moe,
    add_ff_before = True,
    add_ff_after = True
)

out, total_aux_loss, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,), (1,)

# the total auxiliary loss will need to be summed and then added to the main loss
# the other two losses are the unweighted breakdown for logging purposes
```
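To make the `capacity_factor_*`, `balance_loss_coef` and `router_z_loss_coef` arguments more concrete, here is a rough plain-Python sketch of the quantities they control. It is not the library's code. The function names are hypothetical, the capacity formula is the usual Switch Transformer style definition (the library may round or clamp differently), the balance loss counts only each token's top-1 expert, and the weighted sum at the end is an assumption about how the total auxiliary loss is formed.

```python
import math

def expert_capacity(num_tokens, num_experts, capacity_factor):
    # each expert's even share of the tokens, scaled up by the capacity factor
    return math.ceil(num_tokens / num_experts * capacity_factor)

def logsumexp(row):
    # numerically stable log(sum(exp(x)))
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def router_z_loss(router_logits):
    # mean over tokens of the squared log-sum-exp of the router logits
    return sum(logsumexp(row) ** 2 for row in router_logits) / len(router_logits)

def balance_loss(router_logits):
    num_tokens = len(router_logits)
    num_experts = len(router_logits[0])
    probs = [softmax(row) for row in router_logits]

    # f_i: fraction of tokens whose top-1 expert is expert i
    counts = [0] * num_experts
    for p in probs:
        counts[p.index(max(p))] += 1
    frac_tokens = [c / num_tokens for c in counts]

    # P_i: mean router probability assigned to expert i
    mean_prob = [sum(p[i] for p in probs) / num_tokens for i in range(num_experts)]

    # equals 1.0 when routing is perfectly balanced, and grows as it collapses
    return num_experts * sum(f * p for f, p in zip(frac_tokens, mean_prob))

def total_aux_loss(router_logits, balance_loss_coef = 1e-2, router_z_loss_coef = 1e-3):
    # assumed weighting: each unweighted loss times its coefficient, then summed
    return (
        balance_loss_coef * balance_loss(router_logits)
        + router_z_loss_coef * router_z_loss(router_logits)
    )

router_logits = [[5.0, 0.0], [0.0, 5.0]]    # 2 tokens, 2 experts, balanced routing
print(expert_capacity(1024, 16, 1.25))      # slots per expert for 1024 tokens
print(balance_loss(router_logits), router_z_loss(router_logits))
print(total_aux_loss(router_logits))
```

In a training loop the weighted total is what gets added to the main loss, while the two unweighted terms are only useful for logging.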
## Todo
- [x] add the router z-loss proposed in paper
- [x] add the geglu expert with multiplicative gating
- [x] add an entire sparse moe block, complete with rmsnorm + residual as well as the ability to specify a feedforward before or after for stability
- [x] double check equation for router z-loss for experts inner in hierarchical moe
- [x] redo all the transcribed code from google with einops, as it is not very clear
- [x] consult some MoE experts in the open source community; question why hierarchical MoE is needed, in light of results from soft-MoE
- [x] offer top-n gating generalization, as it seems top3 (with smaller threshold) can work even better
- [x] figure out if there was an error in a previous transcription - no there was not an error
- [x] allow for different thresholds for second vs third routed expert
- [x] add coordinate descent based routing
- [x] make first naive non-optimized attempt at distributed code for mixture of experts
- [ ] distributed
    - [x] handle any world size less than number of experts
    - [x] handle any world size greater than number of experts - for now, just have remainder machines do nothing
    - [x] support variable batch sizes
    - [x] support variable seq lengths
    - [ ] figure out how to move assert.py to pytests
    - [ ] simplify the variable sequence length test code from another folder and move in so other researchers gain confidence
    - [ ] optimize
        - [ ] figure out what is faster, all gather, or broadcast with async followed by barrier
    - [ ] make all distributed code pluggable, for different strategies
    - [ ] figure out why there is tiny error in gradients
- [ ] improvise a `Top2GatingWithCoordinateDescent` for `MoE` without `importance`
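The two world-size cases in the distributed items above can be pictured with a small sketch. This is only an illustration, not the repository's distributed code: `experts_per_rank` is a hypothetical helper, and the contiguous, as-even-as-possible layout is an assumption. The only behavior taken from the list is that, when there are more machines than experts, the remainder machines do nothing.

```python
def experts_per_rank(num_experts, world_size):
    """Return a list whose entry r holds the expert indices hosted by rank r."""
    if world_size <= num_experts:
        # spread experts as evenly as possible, earlier ranks take the remainder
        base, extra = divmod(num_experts, world_size)
        sizes = [base + (1 if rank < extra else 0) for rank in range(world_size)]
    else:
        # one expert per rank until experts run out, remaining ranks sit idle
        sizes = [1] * num_experts + [0] * (world_size - num_experts)

    assignment, start = [], 0
    for size in sizes:
        assignment.append(list(range(start, start + size)))
        start += size
    return assignment

print(experts_per_rank(16, 4))  # world size smaller than number of experts
print(experts_per_rank(2, 4))   # world size larger than number of experts
```

Whatever the actual layout, every expert should end up on exactly one rank, which is an easy property to check when experimenting with alternative strategies.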
## Citations
```bibtex
@inproceedings{Zoph2022STMoEDS,
    title   = {ST-MoE: Designing Stable and Transferable Sparse Expert Models},
    author  = {Barret Zoph and Irwan Bello and Sameer Kumar and Nan Du and Yanping Huang and Jeff Dean and Noam M. Shazeer and William Fedus},
    year    = {2022}
}
```