Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/dydjw9/Efficient_SAM
- Host: GitHub
- URL: https://github.com/dydjw9/Efficient_SAM
- Owner: dydjw9
- Created: 2021-10-06T03:56:43.000Z (about 3 years ago)
- Default Branch: main
- Last Pushed: 2023-02-13T08:28:12.000Z (over 1 year ago)
- Last Synced: 2024-07-10T09:53:19.953Z (4 months ago)
- Language: Python
- Size: 29.3 KB
- Stars: 57
- Watchers: 1
- Forks: 4
- Open Issues: 6
Metadata Files:
- Readme: README.md
README
# Efficient Sharpness-aware Minimization for Improved Training of Neural Networks
Code for [“Efficient Sharpness-aware Minimization for Improved Training of Neural Networks”](https://openreview.net/forum?id=n0OeTdNRG0Q), accepted at ICLR 2022.
## Requirements
This code is implemented in PyTorch, and we have tested it under the following environment settings:
- python = 3.8.8
- torch = 1.8.0
- torchvision = 0.9.0
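
As a quick sanity check that a local setup matches these tested versions, the versions can be printed from Python (the expected values in the comments come from the list above):
```
import sys
import torch
import torchvision

# Tested versions from the list above.
print(sys.version)              # expect 3.8.x
print(torch.__version__)        # expect 1.8.0
print(torchvision.__version__)  # expect 0.9.0
```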
## What is in this repository
Code for our ESAM on the CIFAR10/CIFAR100 datasets.
## How to use it
```
from utils.layer_dp_sam import ESAM

base_optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                                 momentum=0.9, weight_decay=args.weight_decay)
optimizer = ESAM(paras, base_optimizer, rho=args.rho, weight_dropout=args.weight_dropout,
                 adaptive=args.isASAM, nograd_cutoff=args.nograd_cutoff,
                 opt_dropout=args.opt_dropout, temperature=args.temperature)
```

Command-line hyperparameters:
- `--beta`: the SWP (Stochastic Weight Perturbation) hyperparameter
- `--gamma`: the SDS (Sharpness-sensitive Data Selection) hyperparameter
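
The snippets above read hyperparameters from an `args` namespace. The repository defines its own flags in its training script; purely as an illustration of how such a namespace could be assembled (flag names taken from the snippets above, default values invented here as placeholders), an argparse setup might look like this:
```
import argparse

parser = argparse.ArgumentParser(description="ESAM training (illustrative flags only)")
parser.add_argument("--learning_rate", type=float, default=0.1)   # placeholder default
parser.add_argument("--weight_decay", type=float, default=5e-4)   # placeholder default
parser.add_argument("--rho", type=float, default=0.05)            # perturbation radius used by SAM/ESAM
parser.add_argument("--weight_dropout", type=float, default=0.5)  # placeholder default
parser.add_argument("--opt_dropout", type=float, default=0.5)     # placeholder default
parser.add_argument("--nograd_cutoff", type=float, default=0.05)  # placeholder default
parser.add_argument("--temperature", type=float, default=1.0)     # placeholder default
parser.add_argument("--beta", type=float, default=0.5)            # SWP hyperparameter (placeholder default)
parser.add_argument("--gamma", type=float, default=0.5)           # SDS hyperparameter (placeholder default)
parser.add_argument("--isASAM", action="store_true")              # use the adaptive (ASAM) variant
parser.add_argument("--fp16", action="store_true")                # mixed-precision training
args = parser.parse_args()
```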
### During training
`loss_fct` should be constructed with `reduction="none"` so that it returns instance-wise losses. `defined_backward` is the function used for DDP and mixed-precision backward passes.
```
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")  # instance-wise losses

def defined_backward(loss):
    # Backward pass that handles both the mixed-precision (apex amp) and plain FP32 cases.
    if args.fp16:
        with amp.scale_loss(loss, base_optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

paras = [inputs, targets, loss_fct, model, defined_backward]
optimizer.paras = paras
optimizer.step()
predictions_logits, loss = optimizer.returnthings
```
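
Putting the pieces above together, a minimal per-batch training step with this interface might look like the sketch below; the data loader and device handling are assumptions for illustration, while the ESAM calls follow the snippets above:
```
model.train()
for inputs, targets in train_loader:  # train_loader is assumed to exist
    inputs, targets = inputs.cuda(), targets.cuda()

    # Hand the batch, loss function, model, and backward routine to ESAM;
    # optimizer.step() then performs the sharpness-aware update internally.
    optimizer.paras = [inputs, targets, loss_fct, model, defined_backward]
    optimizer.step()

    predictions_logits, loss = optimizer.returnthings
```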
## Example
```
bash run.sh
```
## Reference Code
[1] [SAM](https://github.com/davda54/sam)