https://github.com/google/jaxopt
Hardware accelerated, batchable and differentiable optimizers in JAX.
- Host: GitHub
- URL: https://github.com/google/jaxopt
- Owner: google
- License: apache-2.0
- Created: 2021-07-12T17:16:53.000Z (over 4 years ago)
- Default Branch: main
- Last Pushed: 2025-03-18T13:34:44.000Z (8 months ago)
- Last Synced: 2025-03-18T14:31:21.034Z (8 months ago)
- Topics: bi-level, deep-learning, differentiable-programming, jax, optimization
- Language: Python
- Homepage: https://jaxopt.github.io
- Size: 3.29 MB
- Stars: 955
- Watchers: 17
- Forks: 68
- Open Issues: 141
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
Awesome Lists containing this project
- awesome-soms - JAXopt - deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt) and stochastic first-order methods (PolyakSGD, ArmijoSGD) (Implementation in JAX / Other)
- awesome-jax - JAXopt - Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX. (Libraries / New Libraries)
- trackawesomelist - JAXopt - Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX. (Recently Updated / [Feb 16, 2025](/content/2025/02/16/README.md))
README
# JAXopt
[**Status**](#status)
| [**Installation**](#installation)
| [**Documentation**](https://jaxopt.github.io)
| [**Examples**](https://github.com/google/jaxopt/tree/main/examples)
| [**Cite us**](#citeus)
Hardware accelerated, batchable and differentiable optimizers in
[JAX](https://github.com/google/jax).
- **Hardware accelerated:** our implementations run on GPU and TPU, in addition
to CPU.
- **Batchable:** multiple instances of the same optimization problem can be
automatically vectorized using JAX's vmap.
- **Differentiable:** optimization problem solutions can be differentiated with
respect to their inputs either implicitly or via autodiff of unrolled
algorithm iterations.
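As an illustration of these three properties, here is a minimal sketch using the library's `LBFGS` solver on a toy ridge-regression objective. The objective, data shapes, and hyperparameter values are made up for the example:
```python
import jax
import jax.numpy as jnp
import jaxopt

# Toy data; shapes and values are illustrative only.
X = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (20,))
w_init = jnp.zeros(X.shape[1])

def ridge_loss(w, l2reg, X, y):
    residuals = X @ w - y
    return 0.5 * jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(w ** 2)

solver = jaxopt.LBFGS(fun=ridge_loss, maxiter=100)

# Hardware accelerated: the solve runs on whatever backend JAX targets (CPU/GPU/TPU).
w_star = solver.run(w_init, 0.1, X, y).params

# Batchable: solve the same problem for several regularization strengths via vmap.
l2regs = jnp.array([0.01, 0.1, 1.0])
w_batch = jax.vmap(lambda r: solver.run(w_init, r, X, y).params)(l2regs)

# Differentiable: Jacobian of the solution w.r.t. the hyperparameter, computed
# via implicit differentiation (the default) rather than unrolling iterations.
dw_dreg = jax.jacobian(lambda r: solver.run(w_init, r, X, y).params)(0.1)
```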
## Status
JAXopt is no longer maintained or developed. Alternatives may be found on the
JAX [website](https://docs.jax.dev/en/latest/). Some of its features (such as
losses, projections, and the LBFGS optimizer) have been ported into
[optax](https://github.com/google-deepmind/optax). We are sincerely grateful for
all the community contributions the project has garnered over the years.
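As a rough migration pointer, a sketch only, assuming a recent optax release (the exact module layout may differ between versions):
```python
import jax.numpy as jnp
import optax

# Losses that used to live in jaxopt.loss are available under optax.losses.
loss = optax.losses.huber_loss(jnp.array([1.0, -2.0]), jnp.array([0.5, 0.0]))

# Projections that used to live in jaxopt.projection are under optax.projections.
p = optax.projections.projection_simplex(jnp.array([0.2, 1.5, -0.3]))
```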
## Installation
To install the latest release of JAXopt, use the following command:
```bash
$ pip install jaxopt
```
To install the **development** version, use the following command instead:
```bash
$ pip install git+https://github.com/google/jaxopt
```
Alternatively, it can be installed from source with the following command:
```bash
$ python setup.py install
```
## Cite us
Our implicit differentiation framework is described in this
[paper](https://arxiv.org/abs/2105.15183). To cite it:
```bibtex
@article{jaxopt_implicit_diff,
title={Efficient and Modular Implicit Differentiation},
author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy
and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian
and Vert, Jean-Philippe},
journal={arXiv preprint arXiv:2105.15183},
year={2021}
}
```
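For a concrete sense of that framework, below is a minimal sketch (with made-up data and a toy closed-form solver) of `jaxopt.implicit_diff.custom_root`, which makes a solver of an optimality condition differentiable with respect to its remaining arguments:
```python
import jax
import jax.numpy as jnp
from jaxopt import implicit_diff

def ridge_objective(params, l2reg, X, y):
    residuals = X @ params - y
    return 0.5 * jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

# Optimality condition: the gradient w.r.t. params vanishes at the solution.
optimality_fun = jax.grad(ridge_objective, argnums=0)

@implicit_diff.custom_root(optimality_fun)
def ridge_solver(init_params, l2reg, X, y):
    # Closed-form ridge solution; init_params is unused but kept so the
    # signature matches what custom_root expects: (init_params, *args).
    del init_params
    n = X.shape[0]
    A = X.T @ X / n + l2reg * jnp.eye(X.shape[1])
    b = X.T @ y / n
    return jnp.linalg.solve(A, b)

X = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (20,))

# Jacobian of the solution w.r.t. l2reg, derived from the optimality condition
# via the implicit function theorem, not by differentiating the solver's steps.
dw_dreg = jax.jacobian(lambda r: ridge_solver(None, r, X, y))(0.5)
```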
## Disclaimer
JAXopt was an open source project maintained by a dedicated team in Google
Research. It is not an official Google product.