Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
Stochastic Second-Order Methods in JAX
- Host: GitHub
- URL: https://github.com/cor3bit/somax
- Owner: cor3bit
- License: apache-2.0
- Created: 2024-04-02T18:39:28.000Z (3 months ago)
- Default Branch: main
- Last Pushed: 2024-06-18T14:41:15.000Z (18 days ago)
- Last Synced: 2024-06-19T18:43:37.489Z (17 days ago)
- Topics: gauss-newton-method, hessian-free, jax, machine-learning, natural-gradient, quasi-newton, second-order-optimization, stochastic-optimization
- Language: Python
- Homepage:
- Size: 438 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
Lists
- awesome-soms - Somax - second-order stochastic solvers (Implementation in JAX / Other)
README
Somax
Somax is a library of Second-Order Methods for stochastic optimization
written in [JAX](https://github.com/google/jax).
Somax is based on the [JAXopt](https://github.com/google/jaxopt) StochasticSolver API,
and can be used as a drop-in replacement for JAXopt as well as
[Optax](https://github.com/google-deepmind/optax) solvers.

Currently supported methods:
- Diagonal Scaling:
- [AdaHessian](https://ojs.aaai.org/index.php/AAAI/article/view/17275);
- [Sophia](https://arxiv.org/abs/2305.14342);
- Hessian-free Optimization:
- [Newton-CG](https://epubs.siam.org/doi/10.1137/10079923X);
- Quasi-Newton:
- [Stochastic Quasi-Newton Framework (SQN)](https://arxiv.org/abs/1606.04838);
- Gauss-Newton:
- [Exact Gauss-Newton (EGN)](https://arxiv.org/abs/2405.14402);
- [Stochastic Gauss-Newton (SGN)](https://arxiv.org/abs/2006.02409);
- Natural Gradient:
- [Natural Gradient with Sherman-Morrison-Woodbury formula (SMW-NG)](https://arxiv.org/abs/1906.02353).

Future releases:
- Add support for separate "gradient batches"
and "curvature batches" for all solvers;
- Add support for Optax rate schedules.

⚠️ Since JAXopt is currently being merged into Optax,
Somax will at some point switch to the Optax API as well.

*The catfish in the logo is a nod to "сом",
the Belarusian word for "catfish", also pronounced as "som".*

## Installation
```bash
pip install python-somax
```

Requires [JAXopt](https://github.com/google/jaxopt) 0.8.2+.
## Quick example
```py
from somax import EGN

# initialize the solver
solver = EGN(
    predict_fun=model.apply,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)

# initialize the solver state
opt_state = solver.init_state(params)

# run the optimization loop
for i in range(10):
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)
```

See more in the [examples](examples) folder.
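A self-contained sketch of the same loop, using a hand-rolled linear model in place of `model.apply`; the `predict_fn(params, x)` signature and the toy data shapes are assumptions made for illustration, not part of the Somax documentation:

```py
import jax.numpy as jnp

from somax import EGN

# toy linear-regression model (illustrative stand-in for model.apply)
def predict_fn(params, x):
    return x @ params['w'] + params['b']

# initial parameters and a synthetic mini-batch
params = {'w': jnp.zeros((3, 1)), 'b': jnp.zeros((1,))}
batch_x = jnp.ones((8, 3))
batch_y = jnp.ones((8, 1))

# same solver configuration as in the example above
solver = EGN(
    predict_fun=predict_fn,
    loss_type='mse',
    learning_rate=0.1,
    regularizer=1.0,
)
opt_state = solver.init_state(params)

for i in range(10):
    params, opt_state = solver.update(params, opt_state, batch_x, batch_y)
```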
## Citation
```bibtex
@misc{korbit2024somax,
author = {Nick Korbit},
title = {{SOMAX}: a library of second-order methods for stochastic optimization written in {JAX}},
year = {2024},
url = {https://github.com/cor3bit/somax},
}
```

## See also
**Optimization with JAX**
[Optax](https://github.com/google-deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
[JAXopt](https://github.com/google/jaxopt): deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt),
stochastic first-order methods (PolyakSGD, ArmijoSGD).

**Awesome Projects**
[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of various JAX projects.
[Awesome SOMs](https://github.com/cor3bit/awesome-soms): a list
of resources for second-order optimization methods in machine learning.

## Acknowledgements
Some of the implementation ideas are based on the following repositories:
- Line Search in JAXopt: https://github.com/google/jaxopt/blob/main/jaxopt/_src/armijo_sgd.py#L48
- L-BFGS Inverse Hessian-Gradient product in JAXopt: https://github.com/google/jaxopt/blob/main/jaxopt/_src/lbfgs.py#L44
- AdaHessian (official implementation): https://github.com/amirgholami/adahessian
- AdaHessian (Nestor Demeure's implementation): https://github.com/nestordemeure/AdaHessianJax
- Sophia (official implementation): https://github.com/Liuhong99/Sophia
- Sophia (levanter implementation): https://github.com/stanford-crfm/levanter/blob/main/src/levanter/optim/sophia.py