https://github.com/shreyansh26/ml-optimizers-jax
Toy implementations of some popular ML optimizers using Python/JAX
- Host: GitHub
- URL: https://github.com/shreyansh26/ml-optimizers-jax
- Owner: shreyansh26
- Created: 2021-06-20T07:47:57.000Z (about 4 years ago)
- Default Branch: master
- Last Pushed: 2021-06-20T08:08:06.000Z (about 4 years ago)
- Last Synced: 2025-03-24T18:52:36.717Z (3 months ago)
- Topics: adam, adam-optimizer, gradient-descent, jax, machine-learning, momentum, optimization-algorithms, optimizers
- Language: Python
- Homepage:
- Size: 10.7 KB
- Stars: 44
- Watchers: 2
- Forks: 2
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
## ML Optimizers from scratch using JAX
Implementations of some popular optimizers from scratch for a simple model, i.e., linear regression on a dataset with 5 features. The goal of this project was to understand how these optimizers work under the hood and to write toy implementations myself. I also use a bit of JAX magic to differentiate the loss function w.r.t. the weights and the bias without explicitly writing their derivatives as separate functions. This also makes it easy to generalize the notebook to other loss functions.
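A minimal sketch of that idea (not the notebook's exact code; the toy data, shapes, and learning rate below are my own assumptions):

```python
import jax
import jax.numpy as jnp

# Toy data: 100 samples, 5 features (illustrative only).
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 5))
true_w = jnp.array([1.0, -2.0, 0.5, 3.0, -1.5])
y = X @ true_w + 4.0

def loss_fn(w, b, X, y):
    # Mean squared error of a linear model.
    preds = X @ w + b
    return jnp.mean((preds - y) ** 2)

# jax.grad differentiates the loss w.r.t. the weights (argnums=0)
# and the bias (argnums=1), so no hand-written derivatives are needed.
grad_fn = jax.grad(loss_fn, argnums=(0, 1))

# Plain batch gradient descent using the gradients JAX returns.
w, b, lr = jnp.zeros(5), 0.0, 0.1
for _ in range(500):
    dw, db = grad_fn(w, b, X, y)
    w, b = w - lr * dw, b - lr * db
```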
[Open in Kaggle](https://www.kaggle.com/shreyansh2626/ml-optimizers-jax)
[Open in Colab](https://colab.research.google.com/github/shreyansh26/ML-Optimizers-JAX/blob/master/ml_optimizers.ipynb)

The optimizers I have implemented are listed below (a toy sketch of the Adam update follows the list) -
* Batch Gradient Descent
* Batch Gradient Descent + Momentum
* Nesterov Accelerated Momentum
* Adagrad
* RMSprop
* Adam
* Adamax
* Nadam
* AdaBelief
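
For example, the Adam entry above corresponds to an update rule along these lines (a toy sketch with my own function name and default hyperparameters, not the repo's exact code):

```python
import jax.numpy as jnp

def adam_update(w, b, dw, db, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # One Adam step for a (weights, bias) pair; `state` holds the moment
    # estimates and the step counter: (m_w, v_w, m_b, v_b, t).
    m_w, v_w, m_b, v_b, t = state
    t = t + 1
    # Exponential moving averages of the gradient and its square.
    m_w = beta1 * m_w + (1 - beta1) * dw
    v_w = beta2 * v_w + (1 - beta2) * dw ** 2
    m_b = beta1 * m_b + (1 - beta1) * db
    v_b = beta2 * v_b + (1 - beta2) * db ** 2
    # Bias-corrected estimates, then the parameter update.
    m_w_hat, v_w_hat = m_w / (1 - beta1 ** t), v_w / (1 - beta2 ** t)
    m_b_hat, v_b_hat = m_b / (1 - beta1 ** t), v_b / (1 - beta2 ** t)
    w = w - lr * m_w_hat / (jnp.sqrt(v_w_hat) + eps)
    b = b - lr * m_b_hat / (jnp.sqrt(v_b_hat) + eps)
    return w, b, (m_w, v_w, m_b, v_b, t)
```

The other optimizers in the list differ mainly in how this per-step state is accumulated and applied.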
References -
* https://ruder.io/optimizing-gradient-descent/
* https://theaisummer.com/optimization/