Riemannian Optimization Using JAX
- Host: GitHub
- URL: https://github.com/saitejautpala/rieoptax
- Owner: SaitejaUtpala
- Created: 2022-08-11T12:32:12.000Z (almost 3 years ago)
- Default Branch: master
- Last Pushed: 2023-10-30T07:04:51.000Z (over 1 year ago)
- Last Synced: 2025-03-27T15:21:30.619Z (3 months ago)
- Topics: jax, manifolds, optimization-algorithms
- Language: Jupyter Notebook
- Size: 590 KB
- Stars: 48
- Watchers: 4
- Forks: 6
- Open Issues: 3
Metadata Files:
- Readme: readme.md
# Rieoptax
### The project is in the beta stage, under active development, and the API is subject to change.
## Introduction
Rieoptax is a library for Riemannian optimization in [JAX](https://github.com/google/jax). The library is driven by the need for efficient implementations of manifold-valued operations, optimization solvers, and neural network layers that are readily compatible with GPU and even TPU processors.
### Blitz Intro to Riemannian Optimization
Riemannian optimization considers the following problem
$$\min_{w \in \mathcal{M}} f(w)$$ where $f : \mathcal{M} \rightarrow \mathbb{R}$, and $\mathcal{M}$ denotes a Riemannian manifold.
Instead of treating this as a constrained problem, Riemannian optimization views it as an unconstrained problem on the manifold. Riemannian (stochastic) gradient descent generalizes Euclidean gradient descent with intrinsic updates on the manifold, i.e., $w_{t+1} = {\rm Exp}_{w_t}(- \eta_t \, {\rm grad} f(w_t))$, where ${\rm grad} f(w_t)$ is the Riemannian (stochastic) gradient, ${\rm Exp}_w(\cdot)$ is the Riemannian exponential map at $w$, and $\eta_t$ is the step size.
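To make the update concrete, here is a minimal self-contained sketch of Riemannian gradient descent on the unit sphere in plain JAX (not the Rieoptax API; the cost function, step size, and iteration count are purely illustrative):
```
import jax
import jax.numpy as jnp

def sphere_exp(w, v):
    # Exponential map on the unit sphere: geodesic step from w along tangent vector v.
    t = jnp.linalg.norm(v)
    return jnp.where(t > 1e-12, jnp.cos(t) * w + jnp.sin(t) * v / t, w)

def egrad_to_rgrad(w, egrad):
    # Riemannian gradient on the sphere: project the Euclidean gradient
    # onto the tangent space at w.
    return egrad - jnp.dot(egrad, w) * w

# Illustrative cost: f(w) = -w^T A w, minimized over the sphere by the
# leading eigenvector of A.
A = jnp.diag(jnp.array([3.0, 2.0, 1.0]))
f = lambda w: -w @ A @ w

w = jnp.array([1.0, 1.0, 1.0]) / jnp.sqrt(3.0)
eta = 0.1
for _ in range(100):
    egrad = jax.grad(f)(w)           # Euclidean gradient via autodiff
    rgrad = egrad_to_rgrad(w, egrad)
    w = sphere_exp(w, -eta * rgrad)  # intrinsic update Exp_w(-eta * grad f(w))
print(w)  # approaches the top eigenvector, roughly [1, 0, 0]
```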
### Quick start
The two main ingredients that distinguish Riemannian optimization from Euclidean optimization are the Riemannian gradient $\text{grad} f(w)$ and the Riemannian exponential map $\text{Exp}$. The main design goal of Rieoptax is to handle both behind the scenes, making usage as close as possible to standard optimization in [Optax](https://github.com/deepmind/optax).
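The intended workflow therefore mirrors Optax's `init`/`update` pattern. The sketch below is a hypothetical illustration of that pattern: the module paths, the class name `HypersphereCanonicalMetric`, the constructor `rsgd`, and the methods `egrad_to_rgrad` and `exp` are assumptions rather than confirmed API; see the notebooks for actual usage.
```
import jax
import jax.numpy as jnp
# Hypothetical imports -- module paths and names below are assumptions,
# not the confirmed Rieoptax API.
from rieoptax.geometry.hypersphere import HypersphereCanonicalMetric
from rieoptax.optimizers.first_order import rsgd

manifold = HypersphereCanonicalMetric(3)  # assumed class: unit sphere in R^3
opt = rsgd(learning_rate=1e-1)            # assumed constructor: Riemannian SGD

params = jnp.array([0.0, 0.0, 1.0])       # a point on the manifold
state = opt.init(params)                  # Optax-style state initialization

loss = lambda w: -w @ jnp.diag(jnp.array([3.0, 2.0, 1.0])) @ w

egrad = jax.grad(loss)(params)                   # Euclidean gradient via autodiff
rgrad = manifold.egrad_to_rgrad(params, egrad)   # assumed conversion method
updates, state = opt.update(rgrad, state, params)
params = manifold.exp(params, updates)           # assumed intrinsic update via Exp
```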
For a complete example, see the [notebooks](https://github.com/SaitejaUtpala/rieoptax/tree/master/notebooks) folder.
## Overview
The library consists of three modules:
1) [geometry](https://github.com/SaitejaUtpala/rieoptax/tree/master/rieoptax/geometry) : Implements several Riemannian manifolds of interest, along with useful operations such as the Riemannian exponential and logarithmic maps and Euclidean-to-Riemannian gradient conversion rules (a standalone sketch of one such rule follows this list)
2) [mechanism](https://github.com/SaitejaUtpala/rieoptax/tree/master/rieoptax/mechanism) : Noise calibration for differentially private mechanisms with manifold-valued outputs
3) [optimizers](https://github.com/SaitejaUtpala/rieoptax/tree/master/rieoptax/optimizers) : Riemannian Optimization algorithms
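As an illustration of one such conversion rule from the geometry module's scope, here is a standalone sketch on the manifold of symmetric positive definite (SPD) matrices under the affine-invariant metric, written in plain JAX independently of the library's classes (the cost function is illustrative):
```
import jax
import jax.numpy as jnp

def sym(M):
    # Symmetrize a matrix.
    return 0.5 * (M + M.T)

def spd_egrad_to_rgrad(X, egrad):
    # Affine-invariant metric on SPD matrices: grad f(X) = X sym(egrad) X.
    return X @ sym(egrad) @ X

# Illustrative cost on SPD matrices: f(X) = log det(X) + tr(A X^{-1}).
A = jnp.eye(3)
f = lambda X: jnp.linalg.slogdet(X)[1] + jnp.trace(A @ jnp.linalg.inv(X))

X = 2.0 * jnp.eye(3)                  # a point on the SPD manifold
egrad = jax.grad(f)(X)                # Euclidean gradient via autodiff
rgrad = spd_egrad_to_rgrad(X, egrad)  # Riemannian gradient under the metric
```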
## Installation
Currently, installation is done directly from GitHub; the package will soon be available on PyPI.
```
pip install git+https://github.com/SaitejaUtpala/rieoptax.git
```
## Citing Rieoptax
A preprint is available at https://arxiv.org/pdf/2210.04840.pdf.