{"id":13578094,"url":"https://github.com/fadel/pytorch_ema","last_synced_at":"2025-05-15T22:07:52.612Z","repository":{"id":37724506,"uuid":"175882737","full_name":"fadel/pytorch_ema","owner":"fadel","description":"Tiny PyTorch library for maintaining a moving average of a collection of parameters.","archived":false,"fork":false,"pushed_at":"2024-10-02T07:50:49.000Z","size":28,"stargazers_count":428,"open_issues_count":5,"forks_count":26,"subscribers_count":4,"default_branch":"master","last_synced_at":"2025-04-13T04:58:03.374Z","etag":null,"topics":["deep-learning","neural-networks","pytorch"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/fadel.png","metadata":{"files":{"readme":"README.md","changelog":"CHANGELOG.md","contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2019-03-15T19:54:50.000Z","updated_at":"2025-04-04T09:45:21.000Z","dependencies_parsed_at":"2025-01-24T04:34:42.025Z","dependency_job_id":null,"html_url":"https://github.com/fadel/pytorch_ema","commit_stats":null,"previous_names":[],"tags_count":2,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/fadel%2Fpytorch_ema","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/fadel%2Fpytorch_ema/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/fadel%2Fpytorch_ema/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/fadel%2Fpytorch_ema/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/fadel","download_url":
"https://codeload.github.com/fadel/pytorch_ema/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248665761,"owners_count":21142123,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","neural-networks","pytorch"],"created_at":"2024-08-01T15:01:27.338Z","updated_at":"2025-04-13T04:58:11.077Z","avatar_url":"https://github.com/fadel.png","language":"Python","funding_links":[],"categories":["Python"],"sub_categories":[],"readme":"# pytorch_ema\n\nA small library for computing exponential moving averages of model\nparameters.\n\nThis library was originally written for personal use. 
Nevertheless, if you run into issues\nor have suggestions for improvement, feel free to open either a new issue or\npull request.\n\n## Installation\nFor the stable version from PyPI:\n```bash\npip install torch-ema\n```\n\nFor the latest GitHub version:\n```bash\npip install -U git+https://github.com/fadel/pytorch_ema\n```\n\n## Usage\n\n### Example\n\n```python\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_ema import ExponentialMovingAverage\n\ntorch.manual_seed(0)\nx_train = torch.rand((100, 10))\ny_train = torch.rand(100).round().long()\nx_val = torch.rand((100, 10))\ny_val = torch.rand(100).round().long()\nmodel = torch.nn.Linear(10, 2)\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\nema = ExponentialMovingAverage(model.parameters(), decay=0.995)\n\n# Train for a few epochs\nmodel.train()\nfor _ in range(20):\n    logits = model(x_train)\n    loss = F.cross_entropy(logits, y_train)\n    optimizer.zero_grad()\n    loss.backward()\n    optimizer.step()\n    # Update the moving average with the new parameters from the last optimizer step\n    ema.update()\n\n# Validation: original\nmodel.eval()\nlogits = model(x_val)\nloss = F.cross_entropy(logits, y_val)\nprint(loss.item())\n\n# Validation: with EMA\n# the .average_parameters() context manager\n# (1) saves the original parameters before replacing them with the EMA version\n# (2) copies the EMA parameters to the model\n# (3) restores the original parameters after exiting the `with` block, so training can resume later\nwith ema.average_parameters():\n    logits = model(x_val)\n    loss = F.cross_entropy(logits, y_val)\n    print(loss.item())\n```\n\n### Manual validation mode\n\nWhile the `average_parameters()` context manager is convenient, you can also manually execute the same series of operations:\n```python\nema.store()    # save the current model parameters\nema.copy_to()  # copy the EMA parameters into the model\n# ... run validation ...\nema.restore()  # restore the saved parameters to resume training\n```\n\n### Custom parameters\n\nBy default, the methods of `ExponentialMovingAverage` act on the model parameters the object was constructed with, but any compatible 
iterable of parameters can be passed to any method (such as `store()`, `copy_to()`, `update()`, `restore()`, and `average_parameters()`):\n```python\nimport torch\nfrom torch_ema import ExponentialMovingAverage\n\nmodel = torch.nn.Linear(10, 2)\nmodel2 = torch.nn.Linear(10, 2)\nema = ExponentialMovingAverage(model.parameters(), decay=0.995)\n# ... train `model` here ...\n# calling `ema.update()` with no arguments uses `model.parameters()`\nema.copy_to(model2.parameters())\n# model2 now contains the averaged weights\n```\n\n### Resuming training\n\nLike a PyTorch optimizer, `ExponentialMovingAverage` objects have `state_dict()`/`load_state_dict()` methods to allow pausing, serializing, and restarting training without losing shadow parameters, stored parameters, or the update count.\n\n### GPU/device support\n\n`ExponentialMovingAverage` objects have a `.to()` method (like `torch.Tensor`) that can move the object's internal state to a different device or floating-point dtype.\n\nFor more details on individual methods, please check the docstrings.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Ffadel%2Fpytorch_ema","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Ffadel%2Fpytorch_ema","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Ffadel%2Fpytorch_ema/lists"}