{"id":13653386,"url":"https://github.com/davda54/sam","last_synced_at":"2025-05-15T16:03:59.266Z","repository":{"id":37623256,"uuid":"307189275","full_name":"davda54/sam","owner":"davda54","description":"SAM: Sharpness-Aware Minimization (PyTorch)","archived":false,"fork":false,"pushed_at":"2024-02-21T12:34:27.000Z","size":657,"stargazers_count":1852,"open_issues_count":3,"forks_count":203,"subscribers_count":11,"default_branch":"main","last_synced_at":"2025-03-31T20:07:28.188Z","etag":null,"topics":["optimizer","pytorch","sam","sharpness-aware"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/davda54.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2020-10-25T20:46:29.000Z","updated_at":"2025-03-31T07:52:55.000Z","dependencies_parsed_at":"2024-11-30T07:01:06.877Z","dependency_job_id":"43f38401-fadf-42fa-88bf-10a6c5a02d4f","html_url":"https://github.com/davda54/sam","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/davda54%2Fsam","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/davda54%2Fsam/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/davda54%2Fsam/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/davda54%2Fsam/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/davda54","download_url":"https://codeload.github.com/davda54/sam/tar.gz/r
efs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247730069,"owners_count":20986404,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["optimizer","pytorch","sam","sharpness-aware"],"created_at":"2024-08-02T02:01:09.689Z","updated_at":"2025-04-07T21:12:18.953Z","avatar_url":"https://github.com/davda54.png","language":"Python","readme":"\u003ch1 align=\"center\"\u003e\u003cb\u003e(Adaptive) SAM Optimizer\u003c/b\u003e\u003c/h1\u003e\n\u003ch3 align=\"center\"\u003e\u003cb\u003eSharpness-Aware Minimization for Efficiently Improving Generalization\u003c/b\u003e\u003c/h3\u003e\n\u003cp align=\"center\"\u003e\n  \u003ci\u003e~ in Pytorch ~\u003c/i\u003e\n\u003c/p\u003e \n \n--------------\n\n\u003cbr\u003e\n\nSAM simultaneously minimizes loss value and loss sharpness. In particular, it seeks parameters that lie in **neighborhoods having uniformly low loss**. SAM improves model generalization and yields [SoTA performance for several datasets](https://paperswithcode.com/paper/sharpness-aware-minimization-for-efficiently-1). Additionally, it provides robustness to label noise on par with that provided by SoTA procedures that specifically target learning with noisy labels.\n\nThis is an **unofficial** repository for [Sharpness-Aware Minimization for Efficiently Improving Generalization](https://arxiv.org/abs/2010.01412) and [ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks](https://arxiv.org/abs/2102.11600). 
Implementation-wise, the SAM class is a light wrapper that computes the regularized \"sharpness-aware\" gradient, which is then used by the underlying optimizer (such as SGD with momentum). This repository also includes a simple [WRN for CIFAR10](example); as a proof of concept, it outperforms SGD with momentum on this dataset.\n\n\u003cp align=\"center\"\u003e\n  \u003cimg src=\"img/loss_landscape.png\" alt=\"Loss landscape with and without SAM\" width=\"512\"/\u003e  \n\u003c/p\u003e\n\n\u003cp align=\"center\"\u003e\n  \u003csub\u003e\u003cem\u003eResNet loss landscape at the end of training with and without SAM. Sharpness-aware updates lead to a significantly wider minimum, which then leads to better generalization properties.\u003c/em\u003e\u003c/sub\u003e\n\u003c/p\u003e\n\n\u003cbr\u003e\n\n## Usage\n\nIt should be straightforward to use SAM in your training pipeline. Just keep in mind that each training step will take roughly twice as long, because SAM needs two forward-backward passes to estimate the \"sharpness-aware\" gradient. If you're using gradient clipping, make sure to change only the magnitude of the gradients, not their direction.\n\n```python\nimport torch\nfrom sam import SAM\n...\n\nmodel = YourModel()\nbase_optimizer = torch.optim.SGD  # define an optimizer for the \"sharpness-aware\" update\noptimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)\n...\n\nfor inputs, targets in data:\n\n  # first forward-backward pass: gradient at the current weights\n  loss = loss_function(model(inputs), targets)  # use this loss for any training statistics\n  loss.backward()\n  optimizer.first_step(zero_grad=True)  # perturb the weights towards the (approximate) worst case in the rho-neighborhood\n  \n  # second forward-backward pass: gradient at the perturbed weights\n  loss_function(model(inputs), targets).backward()  # make sure to do a full forward pass\n  optimizer.second_step(zero_grad=True)  # restore the original weights and apply the base optimizer update\n...\n```\n\n\u003cbr\u003e\n\n**Alternative usage with a single closure-based `step` function**. 
This alternative offers an API similar to that of native PyTorch optimizers such as LBFGS (kindly suggested by [@rmcavoy](https://github.com/rmcavoy)):\n\n```python\nimport torch\nfrom sam import SAM\n...\n\nmodel = YourModel()\nbase_optimizer = torch.optim.SGD  # define an optimizer for the \"sharpness-aware\" update\noptimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)\n...\n\nfor inputs, targets in data:\n  # the closure performs the second full forward-backward pass (at the perturbed weights)\n  def closure():\n    loss = loss_function(model(inputs), targets)\n    loss.backward()\n    return loss\n\n  # first forward-backward pass at the current weights\n  loss = loss_function(model(inputs), targets)\n  loss.backward()\n  optimizer.step(closure)  # runs first_step, then the closure, then second_step\n  optimizer.zero_grad()  # step() does not clear the gradients of the second pass\n...\n```\n\n### Training tips\n- [@hjq133](https://github.com/hjq133): The suggested usage can cause problems if you use batch normalization. The running statistics are updated in both forward passes, but they should be updated only in the first one. A possible solution is to set the BN momentum to zero (kindly suggested by [@ahmdtaha](https://github.com/ahmdtaha)) to bypass the running statistics during the second pass. 
An example usage is on lines [51](https://github.com/davda54/sam/blob/cdcbdc1574022d3a3c3240da136378c38562d51d/example/train.py#L51) and [58](https://github.com/davda54/sam/blob/cdcbdc1574022d3a3c3240da136378c38562d51d/example/train.py#L58) in [example/train.py](https://github.com/davda54/sam/blob/cdcbdc1574022d3a3c3240da136378c38562d51d/example/train.py):\n```python\nfor batch in dataset.train:\n  inputs, targets = (b.to(device) for b in batch)\n\n  # first forward-backward step\n  enable_running_stats(model)  # \u003c- this is the important line\n  predictions = model(inputs)\n  loss = smooth_crossentropy(predictions, targets)\n  loss.mean().backward()\n  optimizer.first_step(zero_grad=True)\n\n  # second forward-backward step\n  disable_running_stats(model)  # \u003c- this is the important line\n  smooth_crossentropy(model(inputs), targets).mean().backward()\n  optimizer.second_step(zero_grad=True)\n```\n\n- [@evanatyourservice](https://github.com/evanatyourservice): If you plan to train on multiple GPUs, the paper states that *\"To compute the SAM update when parallelizing across multiple accelerators, we divide each data batch evenly among the accelerators, independently compute the SAM gradient on each accelerator, and average the resulting sub-batch SAM gradients to obtain the final SAM update.\"* With `model` wrapped in `torch.nn.parallel.DistributedDataParallel`, this can be achieved as follows:\n```python\nfor inputs, targets in data:\n  # first forward-backward pass without gradient synchronization,\n  # so each GPU perturbs its weights using only its own sub-batch gradient\n  with model.no_sync():  # \u003c- this is the important line (the forward pass must be inside it too)\n    loss = loss_function(model(inputs), targets)\n    loss.backward()\n  optimizer.first_step(zero_grad=True)\n  \n  # second forward-backward pass; DDP averages the resulting SAM gradients across GPUs\n  loss_function(model(inputs), targets).backward()\n  optimizer.second_step(zero_grad=True)\n```\n- [@evanatyourservice](https://github.com/evanatyourservice): Adaptive SAM reportedly performs better than the original SAM. 
The ASAM paper suggests using a higher `rho` for the adaptive updates (~10x larger).\n\n- [@mlaves](https://github.com/mlaves): LR scheduling should either be applied to the base optimizer, or you should use SAM with a single `step` call (with a closure):\n```python\n# attach the scheduler to the wrapped base optimizer, not to the SAM wrapper\nscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer.base_optimizer, T_max=200)\n```\n- [@AlbertoSabater](https://github.com/AlbertoSabater): To integrate with PyTorch Lightning, use manual optimization (`self.automatic_optimization = False` in `__init__`) and write the `training_step` function as:\n```python\ndef training_step(self, batch, batch_idx):\n    optimizer = self.optimizers()\n\n    # first forward-backward pass\n    loss_1 = self.compute_loss(batch)\n    self.manual_backward(loss_1)\n    optimizer.first_step(zero_grad=True)\n\n    # second forward-backward pass\n    loss_2 = self.compute_loss(batch)\n    self.manual_backward(loss_2)\n    optimizer.second_step(zero_grad=True)\n\n    return loss_1\n```\n\u003cbr\u003e\n\n\n## Documentation\n\n#### `SAM.__init__`\n\n| **Argument**    | **Description** |\n| :-------------- | :-------------- |\n| `params` (iterable) | iterable of parameters to optimize or dicts defining parameter groups |\n| `base_optimizer` (torch.optim.Optimizer class) | class of the underlying optimizer that does the \"sharpness-aware\" update; SAM instantiates it with `**kwargs` |\n| `rho` (float, optional)           | size of the neighborhood for computing the max loss *(default: 0.05)* |\n| `adaptive` (bool, optional)       | set this argument to True if you want to use an experimental implementation of element-wise Adaptive SAM *(default: False)* |\n| `**kwargs` | keyword arguments passed to the `__init__` method of `base_optimizer` |\n\n\u003cbr\u003e\n\n#### `SAM.first_step`\n\nPerforms the first optimization step, which moves the weights to the (approximate) point of highest loss in the local `rho`-neighborhood.\n\n| **Argument**    | **Description** |\n| :-------------- | :-------------- |\n| `zero_grad` (bool, optional) | set to True if you want to automatically zero out all gradients 
after this step *(default: False)* |\n\n\u003cbr\u003e\n\n#### `SAM.second_step`\n\nPerforms the second optimization step, which restores the original weights and updates them with the gradient computed at the (locally) highest point in the loss landscape.\n\n| **Argument**    | **Description** |\n| :-------------- | :-------------- |\n| `zero_grad` (bool, optional) | set to True if you want to automatically zero out all gradients after this step *(default: False)* |\n\n\u003cbr\u003e\n\n#### `SAM.step`\n\nPerforms both optimization steps in a single call. This function is an alternative to explicitly calling `SAM.first_step` and `SAM.second_step`.\n\n| **Argument**    | **Description** |\n| :-------------- | :-------------- |\n| `closure` (callable) | the closure should do an additional full forward and backward pass on the optimized model; it must be provided even though the signature defaults to None *(default: None)* |\n\n\n\n\n\u003cbr\u003e\n\n## Experiments\n\nI've verified that SAM works on a simple WRN 16-8 model trained on CIFAR10; you can replicate the experiment by running [train.py](example/train.py). The Wide-ResNet is enhanced only by label smoothing and the most basic image augmentations with cutout, so the errors are higher than those in the [SAM paper](https://arxiv.org/abs/2010.01412). In principle, you could get even lower errors by training for longer (1800 epochs instead of 200), because SAM should be less prone to overfitting. 
SAM uses `rho=0.05`, while ASAM is set to `rho=2.0`, as [suggested by its authors](https://github.com/davda54/sam/issues/37).\n\n| Optimizer             | Test error rate |\n| :-------------------- |   -----: |\n| SGD + momentum        |   3.20 % |\n| SAM + SGD + momentum  |   2.86 % |\n| ASAM + SGD + momentum |   2.55 % |\n\n\n\u003cbr\u003e\n\n## Cite\n\nPlease cite the original authors if you use this optimizer in your work:\n\n```bibtex\n@inproceedings{foret2021sharpnessaware,\n  title={Sharpness-aware Minimization for Efficiently Improving Generalization},\n  author={Pierre Foret and Ariel Kleiner and Hossein Mobahi and Behnam Neyshabur},\n  booktitle={International Conference on Learning Representations},\n  year={2021},\n  url={https://openreview.net/forum?id=6Tm1mposlrM}\n}\n```\n\n```bibtex\n@inproceedings{pmlr-v139-kwon21b,\n  title={ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks},\n  author={Kwon, Jungmin and Kim, Jeongseop and Park, Hyunseo and Choi, In Kwon},\n  booktitle={Proceedings of the 38th International Conference on Machine Learning},\n  pages={5905--5914},\n  year={2021},\n  editor={Meila, Marina and Zhang, Tong},\n  volume={139},\n  series={Proceedings of Machine Learning Research},\n  month={18--24 Jul},\n  publisher={PMLR},\n  pdf={http://proceedings.mlr.press/v139/kwon21b/kwon21b.pdf},\n  url={https://proceedings.mlr.press/v139/kwon21b.html},\n  abstract={Recently, learning algorithms motivated from sharpness of loss surface as an effective measure of generalization gap have shown state-of-the-art performances. Nevertheless, sharpness defined in a rigid region with a fixed radius, has a drawback in sensitivity to parameter re-scaling which leaves the loss unaffected, leading to weakening of the connection between sharpness and generalization gap. In this paper, we introduce the concept of adaptive sharpness which is scale-invariant and propose the corresponding generalization bound. 
We suggest a novel learning method, adaptive sharpness-aware minimization (ASAM), utilizing the proposed generalization bound. Experimental results in various benchmark datasets show that ASAM contributes to significant improvement of model generalization performance.}\n}\n```\n","funding_links":[],"categories":["Python"],"sub_categories":[],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdavda54%2Fsam","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fdavda54%2Fsam","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdavda54%2Fsam/lists"}