Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/activatedgeek/optax-swag
Stochastic Weight Averaging (SWA) transforms for Optax with JAX
https://github.com/activatedgeek/optax-swag
jax optax python stochastic-weight-averaging swag
Last synced: 30 days ago
JSON representation
Stochastic Weight Averaging (SWA) transforms for Optax with JAX
- Host: GitHub
- URL: https://github.com/activatedgeek/optax-swag
- Owner: activatedgeek
- License: apache-2.0
- Created: 2023-03-08T21:12:14.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2023-03-29T00:02:17.000Z (over 1 year ago)
- Last Synced: 2024-08-31T11:50:50.009Z (2 months ago)
- Topics: jax, optax, python, stochastic-weight-averaging, swag
- Language: Python
- Homepage:
- Size: 19.5 KB
- Stars: 2
- Watchers: 3
- Forks: 1
- Open Issues: 2
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# SWAG in Optax
[![PyPI version](https://badge.fury.io/py/optax-swag.svg)](https://pypi.org/project/optax-swag/)
This package implements [SWAG](https://arxiv.org/abs/1902.02476)
as an [Optax](https://optax.readthedocs.io/) transform to allow
usage with [JAX](https://jax.readthedocs.io/).## Installation
Install from `pip` as:
```shell
pip install optax-swag
```To install the latest directly from source, run
```shell
pip install git+https://github.com/activatedgeek/optax-swag.git
```## Usage
To start updating the iterate statistics, use [chaining](https://optax.readthedocs.io/en/latest/api.html#chain) as
```python
import optax
from optax_swag import swagoptimizer = optax.chain(
... ## Other optimizer and transform config.
swag(freq, rank) ## Always add as the last transform.
)
```The [SWAGState](./optax_swag/state.py#L22) object can be accessed from
the optimizer state list for downstream usage.### Sampling
A reference code to generate samples from the collected statistics
is provided below.```python
import jax
import jax.numpy as jnpfrom optax_swag import sample_swag
swa_opt_state = # Reference to a SWAGState object from the optimizer.
n_samples = 10rng = jax.random.PRNGKey(42)
rng, *samples_rng = jax.random.split(rng, 1 + n_samples)swag_sample_params = jax.vmap(sample_swag, in_axes=(0, None))(
jnp.array(samples_rng), swa_opt_state)
```The resulting `swag_sample_params` can now be used for downstream evaluation.
**NOTE**: Make sure to update non-parameter variables (e.g. BatchNorm running statistics) for each generated sample.
# License
Apache 2.0