https://github.com/evanatyourservice/flat-sophia
sophia optimizer further projected towards flat areas of loss landscape
jax optax optimization second-order-optimization sophia
Last synced: 4 months ago
- Host: GitHub
- URL: https://github.com/evanatyourservice/flat-sophia
- Owner: evanatyourservice
- License: mit
- Created: 2024-08-17T23:30:41.000Z (almost 2 years ago)
- Default Branch: main
- Last Pushed: 2024-12-19T00:23:01.000Z (over 1 year ago)
- Last Synced: 2025-04-03T00:11:15.559Z (about 1 year ago)
- Topics: jax, optax, optimization, second-order-optimization, sophia
- Language: Python
- Homepage:
- Size: 269 KB
- Stars: 2
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# flat-sophia
Sophia optimizer further projected towards flat areas of loss landscape
Ideas come mainly from [this paper by Wang et al.](https://arxiv.org/abs/2405.20763).
They projected Adam towards a flatter area using Hessian-vector products (Hvp). Here, since
Sophia already uses the Hvp, we keep a cheap int8 mask that further projects Sophia's update
towards flatter areas.
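
For readers unfamiliar with the Hvp, here is a minimal JAX sketch of the standard forward-over-reverse way to compute one. It is not taken from this repo: `hvp` and `toy_loss` are hypothetical names used only for illustration, and Sophia implementations may estimate curvature differently (for example, with random probe vectors).

```python
import jax
import jax.numpy as jnp


def hvp(loss_fn, params, vector):
    """Hessian-vector product H @ vector without materializing H."""
    # Forward-mode derivative of the gradient function along `vector`.
    return jax.jvp(jax.grad(loss_fn), (params,), (vector,))[1]


def toy_loss(x):
    # Hypothetical loss whose Hessian is diag(6 * x).
    return jnp.sum(x ** 3)


x = jnp.array([1.0, 2.0, 3.0])
v = jnp.ones_like(x)
hv = hvp(toy_loss, x, v)  # equals 6 * x for this toy loss
```

Large entries of the Hvp indicate directions of high curvature, which is the signal used to decide which parts of the update are "sharp".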
## A small experiment
`run_experiment.py` is a worst-case scenario of sorts: a ViT that is too wide and
shallow, and therefore prone to overfitting.
The baseline is the orange line and flat-sophia is the green line. Projecting updates towards
flatter areas helped prevent overfitting and the rise in loss.


## How it works
There are two pertinent values, `sharp_fraction` and `dampening_factor`. `sharp_fraction`
is the fraction of sharpest updates that will be dampened, and `dampening_factor` is the
factor by which they'll be scaled down. The example uses `sharp_fraction=0.2` and
`dampening_factor=10`.
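Whenever the preconditioner is updated, we also update the sharpness mask: entries
corresponding to the largest `sharp_fraction` of Hvp values are set to `dampening_factor`,
and the rest are set to 1. The final update is then divided by this mask, so the sharpest
entries are scaled down by `dampening_factor` while the others pass through unchanged.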
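
The masking step described above can be sketched in JAX roughly as follows. This is an illustration under stated assumptions, not the repo's actual code: the function names `build_sharpness_mask` and `apply_sharpness_mask` are hypothetical, the mask is built per array, and "largest Hvp values" is interpreted here as largest by magnitude.

```python
import jax
import jax.numpy as jnp


def build_sharpness_mask(hvp, sharp_fraction=0.2, dampening_factor=10):
    """int8 mask: `dampening_factor` at the sharpest entries, 1 elsewhere."""
    # Assumption: sharpness is ranked by Hvp magnitude.
    magnitudes = jnp.abs(hvp)
    flat = magnitudes.ravel()
    # Number of entries treated as sharp (rounding choice is illustrative).
    k = max(1, int(sharp_fraction * flat.size))
    # Smallest magnitude that still belongs to the top-k.
    threshold = jax.lax.top_k(flat, k)[0][-1]
    # Ties at the threshold are all dampened, so slightly more than k
    # entries can be selected.
    mask = jnp.where(magnitudes >= threshold, dampening_factor, 1)
    return mask.astype(jnp.int8)


def apply_sharpness_mask(update, mask):
    """Divide the update by the mask, dampening only the sharp entries."""
    return update / mask.astype(update.dtype)


hvp_values = jnp.array([0.1, -5.0, 0.3, 2.0, 0.2, 0.05, 0.4, 0.01, 0.6, 0.02])
mask = build_sharpness_mask(hvp_values)
dampened = apply_sharpness_mask(jnp.ones(10), mask)
```

With `sharp_fraction=0.2` and ten entries, the two largest-magnitude Hvp entries get a mask value of 10 and their updates shrink tenfold. Storing the mask as int8 keeps the extra optimizer state at one byte per parameter, which only works while `dampening_factor` is an integer that fits in int8.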