Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/moskomule/sam.pytorch
A PyTorch implementation of Sharpness-Aware Minimization for Efficiently Improving Generalization
- Host: GitHub
- URL: https://github.com/moskomule/sam.pytorch
- Owner: moskomule
- License: mit
- Created: 2020-12-30T01:09:23.000Z (almost 4 years ago)
- Default Branch: main
- Last Pushed: 2021-03-16T03:05:13.000Z (over 3 years ago)
- Last Synced: 2024-08-02T15:36:48.069Z (3 months ago)
- Topics: optimizer, pytorch, sam
- Language: Python
- Homepage:
- Size: 11.7 KB
- Stars: 132
- Watchers: 2
- Forks: 9
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# sam.pytorch
A PyTorch implementation of *Sharpness-Aware Minimization for Efficiently Improving Generalization* (Foret+2020). [Paper](https://arxiv.org/abs/2010.01412), [Official implementation](https://github.com/google-research/sam)

## Requirements

* Python>=3.8
* PyTorch>=1.7.1

To run the example, you also need:

* `homura` by `pip install -U homura-core==2020.12.0`
* `chika` by `pip install -U chika`

## Example
```commandline
python cifar10.py [--optim.name {sam,sgd}] [--model {resnet20,wrn28_2}] [--optim.rho 0.05]
```

### Results: Test Accuracy (CIFAR-10)
Model | SAM | SGD |
--- | --- | --- |
ResNet-20 | 93.5 | 93.2 |
WRN28-2 | 95.8 | 95.4 |
ResNeXt29 | 96.4 | 95.8 |

SAM needs two forward passes per update, so training with SAM is slower than training with SGD: for ResNet-20, about 80 minutes vs. 50 minutes in my environment. The additional options `--use_amp --jit_model` may slightly accelerate training.

## Usage
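For intuition: a SAM step first perturbs the weights along the gradient direction by radius `rho` (the ascent step), then applies a plain SGD update using the gradient taken at that perturbed point — hence the two forward passes per update. A minimal, framework-free sketch (NumPy; `grad_fn`, the toy loss, and all constants are illustrative, not this repository's code):

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One SAM update: ascend to a nearby 'sharp' point, then
    descend using the gradient evaluated there."""
    g = grad_fn(w)                                # 1st pass: gradient at the current weights
    eps = rho * g / (np.linalg.norm(g) + 1e-12)   # ascent step, scaled to radius rho
    g_sharp = grad_fn(w + eps)                    # 2nd pass: gradient at the perturbed weights
    return w - lr * g_sharp                       # SGD step with the sharpness-aware gradient

# Toy quadratic loss L(w) = 0.5 * ||w||^2, whose gradient is w itself.
w = np.array([1.0, -2.0])
for _ in range(200):
    w = sam_step(w, grad_fn=lambda v: v)
print(np.linalg.norm(w))  # small: the iterates settle near the minimum at the origin
```

The actual `SAMSGD` optimizer wraps this two-pass logic behind the standard PyTorch closure interface shown below.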
`SAMSGD` can be used as a drop-in replacement for PyTorch optimizers by passing a closure to `step`, as follows. It is also compatible with `lr_scheduler` and implements `state_dict` and `load_state_dict`. Currently, this implementation does not support multiple parameter groups.

```python
from sam import SAMSGD

optimizer = SAMSGD(model.parameters(), lr=1e-1, rho=0.05)

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_f(output, target)
        loss.backward()
        return loss

    loss = optimizer.step(closure)
```

## Citation
```bibtex
@ARTICLE{2020arXiv201001412F,
    author = {{Foret}, Pierre and {Kleiner}, Ariel and {Mobahi}, Hossein and {Neyshabur}, Behnam},
    title = "{Sharpness-Aware Minimization for Efficiently Improving Generalization}",
    year = 2020,
    eid = {arXiv:2010.01412},
    eprint = {2010.01412},
}

@software{sampytorch,
    author = {Ryuichiro Hataya},
    title = {sam.pytorch},
    url = {https://github.com/moskomule/sam.pytorch},
    year = {2020}
}
```