https://github.com/eleutherai/optax-galore
Adds GaLore style projection wrappers to optax optimizers
https://github.com/eleutherai/optax-galore
Last synced: about 1 year ago
JSON representation
Adds GaLore style projection wrappers to optax optimizers
- Host: GitHub
- URL: https://github.com/eleutherai/optax-galore
- Owner: EleutherAI
- License: mit
- Created: 2024-08-19T19:43:03.000Z (almost 2 years ago)
- Default Branch: main
- Last Pushed: 2024-10-03T13:58:59.000Z (over 1 year ago)
- Last Synced: 2024-11-10T14:16:57.092Z (over 1 year ago)
- Language: Python
- Homepage:
- Size: 21.5 KB
- Stars: 3
- Watchers: 1
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
README
# Optax-GaLore
Optax-GaLore is an implementation of the Gradient Low-Rank Projection (GaLore) algorithm for memory-efficient training of Large Language Models (LLMs). This project extends the Optax optimization library with GaLore functionality.
## Features
- Memory-efficient optimization for large-scale models
- Compatible with existing Optax optimizers
- Flexible projection specifications for different layer types
- Support for convolutional layers and other multi-dimensional tensors
## Installation
To install Optax-GaLore, clone this repository and install the required dependencies:
```bash
git clone https://github.com/EleutherAI/optax-galore.git
cd optax-galore
pip install -r requirements.txt
```
## Usage
Here's a basic example of how to use Optax-GaLore in your project:
```python
import jax
import optax
import optax_galore.optax_galore as og
# Define your model and loss function
# ...
# Create a GaLore optimizer
learning_rate = 0.001
rank = 64
subspace_change_freq = 1000
optimizer = og.galore(
learning_rate=learning_rate,
rank=rank,
subspace_change_freq=subspace_change_freq
)
# Initialize optimizer state
opt_state = optimizer.init(params)
# Define the loss function
def loss_fn(params, batch):
# Your model's loss calculation
return loss
# Define the update function
@jax.jit
def update(params, opt_state, batch):
loss, grads = jax.value_and_grad(loss_fn)(params, batch)
updates, new_opt_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
return new_params, new_opt_state, loss
# In your training loop:
for batch in data_loader:
params, opt_state, loss = update(params, opt_state, batch)
```
This updated usage example demonstrates how to create a jitted update function that includes the loss calculation, gradient computation, optimizer update, and parameter update. Using a jitted update function enables the compiler to optimize out the unprojected gradients to save memory (probably).
### Advanced Usage
For more control over the projection dimensions, you can use the `dimension_pytree` parameter:
```python
import optax_galore.optax_galore as og
from optax_galore.optax_galore import ProjectionSpec
dimension_pytree = {
'conv1': {'w': ProjectionSpec(2, 3), 'b': None},
'conv2': {'w': ProjectionSpec(2, 3), 'b': None}
}
optimizer = optax_galore.galore(
learning_rate=learning_rate,
rank=rank,
subspace_change_freq=subspace_change_freq,
dimension_pytree=dimension_pytree
)
```
You can also wrap other Optax optimizers with GaLore:
```python
base_optimizer = optax.adam(learning_rate=0.001)
galore_optimizer = optax_galore.galore_wrapper(
base_optimizer,
rank=64,
subspace_change_freq=1000
)
```
## Testing
To run the tests, use pytest:
```bash
pytest tests/
```
## Contributing
Contributions to Optax-GaLore are welcome! Please feel free to submit a Pull Request.