https://github.com/hr0nix/optax-adan
An implementation of the Adan optimizer for optax
deeplearning jax optax optimization-algorithms optimization-methods
- Host: GitHub
- URL: https://github.com/hr0nix/optax-adan
- Owner: hr0nix
- License: apache-2.0
- Created: 2022-08-23T11:15:26.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2022-09-18T15:46:01.000Z (over 3 years ago)
- Last Synced: 2025-10-30T07:52:51.257Z (3 months ago)
- Topics: deeplearning, jax, optax, optimization-algorithms, optimization-methods
- Language: Python
- Homepage:
- Size: 18.6 KB
- Stars: 7
- Watchers: 1
- Forks: 0
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# optax-adan
An implementation of the Adan optimizer for [optax](https://github.com/deepmind/optax/), based on the paper [Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models](https://arxiv.org/abs/2208.06677).
A Colab notebook with a usage example can be found [here](https://colab.research.google.com/drive/19--gju3ELQ9qPbDZbE4NmEnGBLJC901x?usp=sharing).
## How to use:
Install the package:
```bash
python3 -m pip install optax-adan
```
Import the optimizer:
```python3
from optax_adan import adan
```
Use it as you would use any other optimizer from optax:
```python3
import optax

# Initialize: create the optimizer and its state from the initial parameters.
optimizer = adan(learning_rate=0.01)
optimizer_state = optimizer.init(params)

# Step: compute gradients, turn them into updates, and apply them.
grad = grad_func(params)  # your function that computes gradients w.r.t. params
updates, optimizer_state = optimizer.update(grad, optimizer_state, params)
params = optax.apply_updates(params, updates)
```
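For a complete picture, here is a minimal end-to-end sketch (not part of the original README) that fits a toy quadratic objective with adan. The loss function, target vector, and step count are illustrative assumptions; the optimizer calls follow the snippet above.
```python3
import jax
import jax.numpy as jnp
import optax

from optax_adan import adan

# Toy problem (illustrative assumption): recover a target vector by
# minimizing the squared error between the parameters and the target.
target = jnp.array([1.0, -2.0, 3.0])
params = jnp.zeros_like(target)

def loss_fn(p):
    return jnp.sum((p - target) ** 2)

grad_fn = jax.grad(loss_fn)

optimizer = adan(learning_rate=0.01)
optimizer_state = optimizer.init(params)

for step in range(500):
    grad = grad_fn(params)
    updates, optimizer_state = optimizer.update(grad, optimizer_state, params)
    params = optax.apply_updates(params, updates)

print(params)  # should approach the target after enough steps
```
As with any optax optimizer, the gradient computation and update step can be wrapped in `jax.jit` for larger models.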