https://github.com/notedance/pytorch-pcgrad

PyTorch Implementation of "Gradient Surgery for Multi-Task Learning" using multiprocessing

deep-learning deep-reinforcement-learning multi-task-learning multi-task-reinforcement-learning multi-task-rl optimizer pytorch reinforcement-learning

# Pytorch-PCGrad
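A PyTorch implementation of PCGrad (projecting conflicting gradients), the method proposed in "Gradient Surgery for Multi-Task Learning". It wraps a standard PyTorch optimizer and uses multiprocessing.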
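For reference, the projection rule from the paper can be summarized as follows, in our own notation: $g_i$ is the gradient of task $i$'s loss with respect to the shared parameters, and each $g_i^{\mathrm{PC}}$ starts as a copy of $g_i$. For every task $i$, the other tasks $j \neq i$ are visited in random order, and whenever the two gradients conflict (negative dot product), the component of $g_i^{\mathrm{PC}}$ along $g_j$ is removed. The paper then sums the projected gradients into a single update; whether this package sums or averages them is not stated here.

```math
g_i^{\mathrm{PC}} \leftarrow g_i^{\mathrm{PC}} - \frac{g_i^{\mathrm{PC}} \cdot g_j}{\lVert g_j \rVert^{2}}\, g_j
\quad \text{if } g_i^{\mathrm{PC}} \cdot g_j < 0,
\qquad
\Delta\theta = \sum_{i} g_i^{\mathrm{PC}}
```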

# Usage
```python
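import torch
import torch.nn as nn
import torch.optim as optim
from ppcgrad import PPCGrad

# illustrative two-task setup: a shared linear layer with one output per task
num_tasks = 2
net = nn.Linear(10, num_tasks)
out = net(torch.randn(4, 10))

# wrap your favorite optimizer
optimizer = PPCGrad(optim.Adam(net.parameters()))
losses = [out[:, k].pow(2).mean() for k in range(num_tasks)]  # a list of per-task losses
assert len(losses) == num_tasks
optimizer.pc_backward(losses)  # compute the per-task gradients and apply the gradient modification
optimizer.step()  # apply the gradient step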
```
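The wrapper's internals are not documented in this README, so the following is only a minimal sketch of what a PCGrad backward pass does. It uses plain PyTorch autograd rather than the `ppcgrad` package. The helper `pcgrad_backward` is hypothetical and not part of the PPCGrad API. It treats every parameter as shared, sums the projected gradients as the paper does, and runs in a single process, so it does not reflect this repository's multiprocessing.

```python
import random

import torch
import torch.nn as nn


def pcgrad_backward(params, losses):
    """Hypothetical sketch of PCGrad (not the PPCGrad API).

    Computes one gradient per task, projects away conflicting components,
    sums the projected gradients, and writes the result into each p.grad.
    """
    params = [p for p in params if p.requires_grad]

    # One flattened gradient vector per task; parameters a task does not use get zeros.
    grads = []
    for loss in losses:
        task_grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        flat = torch.cat([
            (g if g is not None else torch.zeros_like(p)).reshape(-1)
            for g, p in zip(task_grads, params)
        ])
        grads.append(flat)

    projected = []
    for i, g_i in enumerate(grads):
        g_pc = g_i.clone()
        others = [g_j for j, g_j in enumerate(grads) if j != i]
        random.shuffle(others)  # visit the other tasks in random order
        for g_j in others:
            dot = torch.dot(g_pc, g_j)
            if dot < 0:  # conflict: remove the component of g_pc along g_j
                g_pc = g_pc - dot / (g_j.norm() ** 2) * g_j
        projected.append(g_pc)

    merged = torch.stack(projected).sum(dim=0)

    # Unflatten the merged vector back into per-parameter .grad tensors.
    offset = 0
    for p in params:
        n = p.numel()
        p.grad = merged[offset:offset + n].view_as(p).clone()
        offset += n


# Example: two exactly opposing toy tasks on one linear layer (illustrative only).
torch.manual_seed(0)
net = nn.Linear(3, 2)
out = net(torch.randn(8, 3))
losses = [out[:, 0].mean(), -out[:, 0].mean()]
pcgrad_backward(net.parameters(), losses)
print(net.weight.grad.abs().max())  # prints a value close to 0: each gradient cancels the other
```

With only two tasks the random visiting order has no effect. With three or more mutually conflicting tasks, repeated calls can give slightly different updates. When no pair of task gradients conflicts, no projection happens and the result equals the ordinary gradient of the summed loss.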