{"id":15600934,"url":"https://github.com/lucidrains/ema-pytorch","last_synced_at":"2025-05-14T10:06:37.946Z","repository":{"id":37279805,"uuid":"505551757","full_name":"lucidrains/ema-pytorch","owner":"lucidrains","description":"A simple way to keep track of an Exponential Moving Average (EMA) version of your Pytorch model","archived":false,"fork":false,"pushed_at":"2024-12-03T21:56:45.000Z","size":68,"stargazers_count":571,"open_issues_count":5,"forks_count":35,"subscribers_count":3,"default_branch":"main","last_synced_at":"2025-04-19T13:45:37.259Z","etag":null,"topics":["artificial-intelligence","deep-learning","exponential-moving-average"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/lucidrains.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-06-20T18:18:17.000Z","updated_at":"2025-04-18T22:23:42.000Z","dependencies_parsed_at":"2023-12-23T04:25:55.752Z","dependency_job_id":"9d244a96-4e37-4ae0-b6a1-49ca7b00601b","html_url":"https://github.com/lucidrains/ema-pytorch","commit_stats":{"total_commits":77,"total_committers":12,"mean_commits":6.416666666666667,"dds":"0.18181818181818177","last_synced_commit":"dee87fb2281a6a69088e34460d2113c9e1dcb702"},"previous_names":[],"tags_count":54,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fema-pytorch","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fema-pytorch/tags","releases_url":"https://repos.ecosyste.ms
/api/v1/hosts/GitHub/repositories/lucidrains%2Fema-pytorch/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fema-pytorch/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/lucidrains","download_url":"https://codeload.github.com/lucidrains/ema-pytorch/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":254119480,"owners_count":22017951,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["artificial-intelligence","deep-learning","exponential-moving-average"],"created_at":"2024-10-03T02:09:40.545Z","updated_at":"2025-05-14T10:06:37.879Z","avatar_url":"https://github.com/lucidrains.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"## EMA - Pytorch\n\nA simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model\n\n## Install\n\n```bash\n$ pip install ema-pytorch\n```\n\n## Usage\n\n```python\nimport torch\nfrom ema_pytorch import EMA\n\n# your neural network as a pytorch module\n\nnet = torch.nn.Linear(512, 512)\n\n# wrap your neural network, specify the decay (beta)\n\nema = EMA(\n    net,\n    beta = 0.9999,              # exponential moving average factor\n    update_after_step = 100,    # only after this number of .update() calls will it start updating\n    update_every = 10,          # how often to actually update, to save on compute (updates every 10th .update() call)\n)\n\n# mutate your network, with SGD or otherwise\n\nwith torch.no_grad():\n    
net.weight.copy_(torch.randn_like(net.weight))\n    net.bias.copy_(torch.randn_like(net.bias))\n\n# you will call the update function on your moving average wrapper\n\nema.update()\n\n# then, later on, you can invoke the EMA model the same way as your network\n\ndata = torch.randn(1, 512)\n\noutput     = net(data)\nema_output = ema(data)\n\n# if you want to save your EMA model, it is recommended that you save the entire wrapper\n# as it contains the number of steps taken (it has warmup logic, recommended by @crowsonkb and validated across a number of projects)\n# however, if you only need the EMA copy of your model, it lives at ema.ema_model\n```\n\nTo use the post-hoc synthesized EMA proposed by Karras et al. in \u003ca href=\"https://arxiv.org/abs/2312.02696\"\u003ea recent paper\u003c/a\u003e, follow the example below:\n\n```python\nimport torch\nfrom ema_pytorch import PostHocEMA\n\n# your neural network as a pytorch module\n\nnet = torch.nn.Linear(512, 512)\n\n# wrap your neural network, specify the sigma_rels or gammas\n\nemas = PostHocEMA(\n    net,\n    sigma_rels = (0.05, 0.28),           # a tuple with the hyperparameter for the multiple EMAs. 
You need at least 2 here to synthesize a new one\n    update_every = 10,                  # how often to actually update, to save on compute (updates every 10th .update() call)\n    checkpoint_every_num_steps = 10,    # how often to save a checkpoint of each EMA to checkpoint_folder\n    checkpoint_folder = './post-hoc-ema-checkpoints'  # the folder of saved checkpoints for each sigma_rel (gamma) across timesteps with the hparam above, used to synthesize a new EMA model after training\n)\n\nnet.train()\n\nfor _ in range(1000):\n    # mutate your network, with SGD or otherwise\n\n    with torch.no_grad():\n        net.weight.copy_(torch.randn_like(net.weight))\n        net.bias.copy_(torch.randn_like(net.bias))\n\n    # you will call the update function on your moving average wrapper\n\n    emas.update()\n\n# now that you have a few checkpoints\n# you can synthesize an EMA model with a different sigma_rel (say 0.15)\n\nsynthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)\n\n# output with synthesized EMA\n\ndata = torch.randn(1, 512)\n\nsynthesized_ema_output = synthesized_ema(data)\n\n```\n\nTo test the free-lunch claims of the \u003ca href=\"https://arxiv.org/abs/2402.09240\"\u003e`Switch EMA`\u003c/a\u003e paper, set `update_model_with_ema_every` as follows:\n\n```python\n\nema = EMA(\n    net,\n    ...,\n    update_model_with_ema_every = 10000 # say 10k steps is 1 epoch\n)\n\n# or you can do it manually at the end of each epoch\n\nema.update_model_with_ema()\n\n```\n\n## Citations\n\n```bibtex\n@article{Karras2023AnalyzingAI,\n    title   = {Analyzing and Improving the Training Dynamics of Diffusion Models},\n    author  = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},\n    journal = {ArXiv},\n    year    = {2023},\n    volume  = {abs/2312.02696},\n    url     = {https://api.semanticscholar.org/CorpusID:265659032}\n}\n```\n\n```bibtex\n@article{Lee2024SlowAS,\n    title   = {Slow and Steady Wins the Race: Maintaining Plasticity with Hare and Tortoise 
Networks},\n    author  = {Hojoon Lee and Hyeonseo Cho and Hyunseung Kim and Donghu Kim and Dugki Min and Jaegul Choo and Clare Lyle},\n    journal = {ArXiv},\n    year    = {2024},\n    volume  = {abs/2406.02596},\n    url     = {https://api.semanticscholar.org/CorpusID:270258586}\n}\n```\n\n```bibtex\n@article{Li2024SwitchEA,\n    title   = {Switch EMA: A Free Lunch for Better Flatness and Sharpness},\n    author  = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li},\n    journal = {ArXiv},\n    year    = {2024},\n    volume  = {abs/2402.09240},\n    url     = {https://api.semanticscholar.org/CorpusID:267657558}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Fema-pytorch","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Flucidrains%2Fema-pytorch","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Fema-pytorch/lists"}