{"id":15601054,"url":"https://github.com/lucidrains/gradnorm-pytorch","last_synced_at":"2025-08-21T02:32:17.784Z","repository":{"id":207420527,"uuid":"719206362","full_name":"lucidrains/gradnorm-pytorch","owner":"lucidrains","description":"A practical implementation of GradNorm, Gradient Normalization for Adaptive Loss Balancing, in Pytorch","archived":false,"fork":false,"pushed_at":"2024-01-22T14:36:31.000Z","size":776,"stargazers_count":77,"open_issues_count":4,"forks_count":3,"subscribers_count":2,"default_branch":"main","last_synced_at":"2024-12-10T03:32:08.858Z","etag":null,"topics":["artificial-intelligence","deep-learning","gradient-normalization","loss-balancing"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/lucidrains.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2023-11-15T17:20:41.000Z","updated_at":"2024-10-14T13:03:56.000Z","dependencies_parsed_at":"2024-10-22T18:35:02.899Z","dependency_job_id":null,"html_url":"https://github.com/lucidrains/gradnorm-pytorch","commit_stats":{"total_commits":38,"total_committers":1,"mean_commits":38.0,"dds":0.0,"last_synced_commit":"4f793905161b471b0d63c3e4629d265c908e1dfd"},"previous_names":["lucidrains/gradnorm-pytorch"],"tags_count":25,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fgradnorm-pytorch","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fgradnorm-pytorch/tags","releases_url":"https://repos.eco
syste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fgradnorm-pytorch/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fgradnorm-pytorch/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/lucidrains","download_url":"https://codeload.github.com/lucidrains/gradnorm-pytorch/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":230479864,"owners_count":18232630,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["artificial-intelligence","deep-learning","gradient-normalization","loss-balancing"],"created_at":"2024-10-03T02:13:14.512Z","updated_at":"2025-08-21T02:32:17.779Z","avatar_url":"https://github.com/lucidrains.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"\u003cimg src=\"./gradnorm.png\" width=\"400px\"\u003e\u003c/img\u003e\n\n## GradNorm - Pytorch\n\nA practical implementation of \u003ca href=\"https://arxiv.org/abs/1711.02257\"\u003eGradNorm\u003c/a\u003e, Gradient Normalization for Adaptive Loss Balancing, in Pytorch.\n\nI am increasingly coming across neural network architectures that require more than 3 auxiliary losses, so I will build out an installable package that easily handles loss balancing in distributed settings, gradient accumulation, etc. 
I am also open to incorporating any follow-up research; just let me know in the issues.\n\nIt will be dog-fooded for \u003ca href=\"http://github.com/lucidrains/audiolm-pytorch\"\u003eSoundStream\u003c/a\u003e, \u003ca href=\"https://github.com/lucidrains/magvit2-pytorch\"\u003eMagViT2\u003c/a\u003e, and \u003ca href=\"https://github.com/lucidrains/metnet-3\"\u003eMetNet3\u003c/a\u003e.\n\n## Appreciation\n\n- \u003ca href=\"https://stability.ai/\"\u003eStabilityAI\u003c/a\u003e, \u003ca href=\"https://a16z.com/supporting-the-open-source-ai-community/\"\u003eA16Z Open Source AI Grant Program\u003c/a\u003e, and \u003ca href=\"https://huggingface.co/\"\u003e🤗 Huggingface\u003c/a\u003e for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research\n\n## Install\n\n```bash\n$ pip install gradnorm-pytorch\n```\n\n## Usage\n\n```python\nimport torch\nfrom torch.optim import Adam\n\nfrom gradnorm_pytorch import (\n    GradNormLossWeighter,\n    MockNetworkWithMultipleLosses\n)\n\n# a mock network with multiple discriminator losses\n\nnetwork = MockNetworkWithMultipleLosses(\n    dim = 512,\n    num_losses = 4\n)\n\noptim = Adam(network.parameters(), lr = 3e-4)\n\n# shared backbone parameter, with respect to which the gradient norms are computed\n\nbackbone_parameter = network.backbone[-1].weight\n\n# GradNorm-based loss weighter\n\nloss_weighter = GradNormLossWeighter(\n    num_losses = 4,\n    learning_rate = 1e-4,\n    restoring_force_alpha = 0.,                  # 0. pins the gradient norms of all losses to be equal, while anything greater than 0 also accounts for the relative training rates of each loss. 
In the paper, they go as high as 3.\n    grad_norm_parameters = backbone_parameter\n)\n\n# mock input\n\nmock_input = torch.randn(2, 512)\nlosses, backbone_output_activations = network(mock_input)\n\n# backward pass with the loss weights\n# the loss weights are updated on each backward call according to the GradNorm algorithm\n\nloss_weighter.backward(losses)\n\n# the usual\n\noptim.step()\noptim.zero_grad()\n```\n\nYou can also compute the gradient norms with respect to an intermediate activation, such as a generated modality, instead of a shared parameter.\n\n```python\n\n# same as above ...\n\nloss_weighter = GradNormLossWeighter(\n    num_losses = 4,\n    learning_rate = 1e-4,\n    restoring_force_alpha = 0.,\n    grad_norm_parameters = None # now None: the network must return the activations on forward, and they are passed in on backward\n)\n\n# mock input\n\nmock_input = torch.randn(2, 512)\nlosses, backbone_output_activations = network(mock_input)\n\n# backward pass with the loss weights and the backbone activations, through which the gradients of all losses backpropagate\n\nloss_weighter.backward(losses, backbone_output_activations)\n\n# optimizer\n\noptim.step()\noptim.zero_grad()\n```\n\nYou can also switch to static loss weighting, for example to run baseline experiments with fixed weights.\n\n```python\nloss_weighter = GradNormLossWeighter(\n    loss_weights = [1., 10., 5., 2.],\n    ...,\n    frozen = True\n)\n\n# or freeze it when calling backward\n\nloss_weighter.backward(..., freeze = True)\n```\n\nFor use with \u003ca href=\"https://huggingface.co/\"\u003e🤗 Huggingface Accelerate\u003c/a\u003e, pass the `Accelerator` instance to the `accelerator` keyword argument on initialization.\n\nFor example:\n\n```python\nfrom accelerate import Accelerator\n\naccelerator = Accelerator()\n\nnetwork = accelerator.prepare(network)\n\nloss_weighter = GradNormLossWeighter(\n    ...,\n    accelerator = accelerator\n)\n\n# backward will now use the accelerator\n```\n\n## Todo\n\n- [x] take care of gradient accumulation\n- [ ] handle sets of loss 
weights\n- [ ] handle freezing of some loss weights, but not others\n- [ ] allow for a prior weighting, accounted for when calculating gradient targets\n\n## Citations\n\n```bibtex\n@article{Chen2017GradNormGN,\n    title   = {GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks},\n    author  = {Zhao Chen and Vijay Badrinarayanan and Chen-Yu Lee and Andrew Rabinovich},\n    journal = {ArXiv},\n    year    = {2017},\n    volume  = {abs/1711.02257},\n    url     = {https://api.semanticscholar.org/CorpusID:4703661}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Fgradnorm-pytorch","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Flucidrains%2Fgradnorm-pytorch","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Fgradnorm-pytorch/lists"}