Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
Maximal Update Parametrization (μP) with Flax & Optax.
- Host: GitHub
- URL: https://github.com/jessefarebro/flax-mup
- Owner: JesseFarebro
- License: mit
- Created: 2023-01-16T20:12:07.000Z (almost 2 years ago)
- Default Branch: main
- Last Pushed: 2023-12-27T21:12:54.000Z (11 months ago)
- Last Synced: 2023-12-28T00:13:30.006Z (11 months ago)
- Topics: flax, jax, optax
- Language: Python
- Homepage:
- Size: 400 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
- Metadata Files:
  - Readme: README.md
  - License: LICENSE.md
Awesome Lists containing this project
README
# Flax Maximal Update Parametrization (μP)
This is an implementation of the [Maximal Update Parametrization (μP)](https://arxiv.org/abs/2203.03466)
in Flax. To train with μP there are two steps:
1. Create a Flax module that subclasses `mup.Module` instead of the regular linen `nn.Module`.
2. Chain the μP Optax transformation `scale_adam_by_mup` or `scale_sgd_by_mup` with your current optimizer.

For example:
```py
import jax
import jax.numpy as jnp
from flax import linen as nn
import flax_mup as mup
import optax


class Model(mup.Module):
    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        # `width` is exposed from `mup.Module`
        x = nn.Dense(128 * self.width)(x)
        x = nn.relu(x)
        x = nn.Dense(128 * self.width)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        return x


# Create your model, specifying the target width with the `width` field.
model = Model(width=16)

# Initialize the model and optionally specify the base model width.
# It defaults to 1, but you can specify any integer value less than the
# target width above.
params = model.init(
    jax.random.PRNGKey(0),
    jnp.zeros((16,)),
    base_width=1,  # `base_width` defaults to 1
)

optim = optax.chain(
    optax.adam(learning_rate=1e-4),
    mup.scale_adam_by_mup(),
)
```
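Continuing the example above, a training step then follows the usual Optax pattern. This is a hypothetical sketch rather than part of flax-mup itself: `loss_fn`, `x`, and `y` are placeholders, and it assumes the model is applied with the standard Flax `apply` signature.

```py
# Hypothetical usage sketch reusing `model`, `params`, and `optim` from above.
opt_state = optim.init(params)

def loss_fn(p, x, y):
    logits = model.apply(p, x)  # assumes the standard Flax `apply` call
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

@jax.jit
def train_step(p, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(p, x, y)
    updates, opt_state = optim.update(grads, opt_state, p)
    p = optax.apply_updates(p, updates)
    return p, opt_state, loss
```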
## Getting Started

The library is not published on PyPI; you can instead install it from GitHub:
```sh
# with pip
pip install git+https://github.com/jessefarebro/flax-mup

# with pdm
pdm add git+https://github.com/jessefarebro/flax-mup
```

## Assumptions
1. The weights of the model are initialized from a distribution that is scale invariant,
   i.e., let $X$ be a random variable and define $Y = X / c$ where $c$ is a non-zero constant.
   The property $\text{Var}[Y] = \text{Var}[X] / c^2$ must hold. This is the case for most initialization
   schemes that draw from either a uniform or Gaussian distribution (see the short check after this list).
2. The learning-rate scaling is the last transformation in the optimizer chain. Both the `scale_adam_by_mup`
   and `scale_sgd_by_mup` transformations assume this. For any standard Optax optimizer this should be the case.
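As a quick illustration of the first assumption, the sketch below (not part of the library) samples weights from a common JAX initializer and checks that rescaling by $c$ divides the variance by $c^2$:

```py
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
c = 4.0

# Sample a weight matrix from a standard Gaussian-based initializer.
x = jax.nn.initializers.lecun_normal()(key, (1024, 1024))
y = x / c

# The variance ratio should be close to c**2 (= 16 here).
print(jnp.var(x) / jnp.var(y))
```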
## Coordinate Checks

Provided are coordinate checks to verify the implementation for various loss functions using MLPs & CNNs.
Transformers are currently untested, but support should be fairly straightforward.
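As a rough, hypothetical sketch of what a coordinate check does (reusing the `Model` class from the example above and assuming the standard Flax `apply` signature), one can run a few μP-scaled Adam steps at several widths and verify that the typical logit magnitude stays roughly flat as the width grows:

```py
import jax
import jax.numpy as jnp
import optax
import flax_mup as mup

def coord_check(widths, steps=5):
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (32, 16))
    y = jax.random.randint(key, (32,), 0, 10)
    for width in widths:
        model = Model(width=width)  # the `Model` defined earlier
        params = model.init(key, x, base_width=1)
        optim = optax.chain(optax.adam(1e-2), mup.scale_adam_by_mup())
        opt_state = optim.init(params)

        def loss_fn(p):
            logits = model.apply(p, x)
            return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

        for _ in range(steps):
            grads = jax.grad(loss_fn)(params)
            updates, opt_state = optim.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)

        # Under μP the mean absolute logit should not grow systematically
        # with `width`.
        print(width, float(jnp.abs(model.apply(params, x)).mean()))

coord_check(widths=[1, 2, 4, 8])
```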
### Cross-Entropy MLP

![Coordinate Checks](.github/assets/coord-checks-mlp-ce.png)
### Cross-Entropy CNN
![Coordinate Checks](.github/assets/coord-checks-cnn-ce.png)
### L2 Regression MLP
![Coordinate Checks](.github/assets/coord-checks-mlp-mse.png)
### L2 Regression CNN
![Coordinate Checks](.github/assets/coord-checks-cnn-mse.png)