https://github.com/devzhk/cgds-package
Package for CGD and ACGD optimizers
- Host: GitHub
- URL: https://github.com/devzhk/cgds-package
- Owner: devzhk
- License: MIT
- Created: 2020-06-27T14:15:47.000Z (almost 5 years ago)
- Default Branch: master
- Last Pushed: 2022-08-11T21:49:19.000Z (almost 3 years ago)
- Last Synced: 2025-04-30T05:43:56.544Z (23 days ago)
- Topics: optimizers, pytorch
- Language: Python
- Homepage: https://pypi.org/project/CGDs
- Size: 50.8 KB
- Stars: 20
- Watchers: 4
- Forks: 4
- Open Issues: 2
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# CGDs
## Overview
`CGDs` is a package implementing optimization algorithms, including three variants of [CGD](https://arxiv.org/abs/1905.12103), in [PyTorch](https://pytorch.org/) using Hessian-vector products and the conjugate gradient method.
`CGDs` is for competitive optimization problems such as generative adversarial networks (GANs), which take the form
$$
\min_{\mathbf{x}} f(\mathbf{x}, \mathbf{y}), \qquad \min_{\mathbf{y}} g(\mathbf{x}, \mathbf{y})
$$
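As background (this is standard PyTorch autograd, not part of the package API): the Hessian-vector products mentioned above can be computed by double backpropagation, which is what lets CGD-style solvers combine them with conjugate gradient without ever forming the Hessian explicitly. A minimal self-contained illustration:

```python
import torch

# Mixed Hessian-vector product D_xy f(x, y) @ v via double backprop.
x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
f = (x * y).sum() + (x ** 2).sum()

grad_x, = torch.autograd.grad(f, x, create_graph=True)  # df/dx = y + 2x
v = torch.randn(3)
hvp_y, = torch.autograd.grad(grad_x @ v, y)  # d/dy <df/dx, v> = v here
print(torch.allclose(hvp_y, v))  # True: D_xy f is the identity for this f
```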
## Installation

```bash
pip3 install CGDs
```
You can also directly download the `CGDs` directory and copy it to your project.

## Package description
The `CGDs` package implements the following optimization algorithms with PyTorch:
- `BCGD`: the CGD algorithm from [Competitive Gradient Descent](https://arxiv.org/abs/1905.12103) (a toy usage sketch follows this list).
- `ACGD`: the ACGD algorithm from [Implicit competitive regularization in GANs](https://arxiv.org/abs/1910.05852).
- `GACGD`: works for general-sum problems, where each player minimizes its own loss (see the two-loss update below).
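As a toy illustration (not from the original README), the sketch below runs `BCGD` on the classic bilinear game $\min_x \max_y xy$, on which simultaneous gradient descent oscillates while CGD converges to the equilibrium at the origin. It assumes the constructor and `step(loss=...)` interface shown in the examples later in this README.

```python
import torch
import CGDs

# Bilinear zero-sum game: x minimizes x*y, y maximizes it; equilibrium is (0, 0).
x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([1.0], requires_grad=True)
optimizer = CGDs.BCGD(max_params=[y], min_params=[x],
                      lr_max=0.1, lr_min=0.1,
                      device=torch.device('cpu'))
for _ in range(200):
    optimizer.zero_grad()
    loss = (x * y).sum()  # single loss: minimized over x, maximized over y
    optimizer.step(loss=loss)
print(x.item(), y.item())  # both should shrink toward 0
```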
## How to use
Quickstart with notebook: [Examples of using ACGD](https://colab.research.google.com/drive/1-52aReaBAPNBtq2NcHxKkVIbdVXdyqtH?usp=sharing).

Similar to the PyTorch package `torch.optim`, using optimizers in `CGDs` involves two main steps: construction and update steps.
### Construction
To construct an optimizer, you have to give it two iterables containing the parameters to be maximized and minimized (all should be `Tensor`s with `requires_grad=True`).
You also need to specify the `device` and the learning rates. Example:
```python
import CGDs
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
optimizer = CGDs.ACGD(max_params=model_G.parameters(), min_params=model_D.parameters(),
lr_max=1e-3, lr_min=1e-3, device=device)
optimizer = CGDs.BCGD(max_params=[var1, var2], min_params=[var3, var4, var5],
lr_max=0.01, lr_min=0.01, device=device)
```

### Update step
Both optimizers have a `step()` method, which updates the parameters according to the corresponding update rule. The method can be called once the computation graph has been created. You pass in the loss but do not compute gradients before `step()`, which is *different* from `torch.optim`: the optimizer runs its own backward passes internally, since its update requires Hessian-vector products rather than plain gradients.
Example:
```python
for data in dataset:
    optimizer.zero_grad()
    real_pred = model_D(data)
    latent = torch.randn((batch_size, latent_dim), device=device)
    fake_pred = model_D(model_G(latent))
    loss = loss_fn(real_pred, fake_pred)
    optimizer.step(loss=loss)
```
For general-sum competitive optimization (e.g., with `GACGD`), two losses should be defined and passed to `optimizer.step`:
```python
loss_x = loss_f(x, y)
loss_y = loss_g(x, y)
optimizer.step(loss_x, loss_y)
```
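A self-contained sketch of such a general-sum setup is below. Note that the `GACGD` constructor arguments here are an assumption (mirroring `BCGD`'s); check the package source for the actual signature.

```python
import torch
import CGDs

# Toy general-sum game: each player minimizes its own quadratic loss;
# the unique equilibrium is x = y = 0.
x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([-1.0], requires_grad=True)
# ASSUMPTION: GACGD is constructed like BCGD; the real signature may differ.
optimizer = CGDs.GACGD(max_params=[y], min_params=[x],
                       lr_max=0.05, lr_min=0.05,
                       device=torch.device('cpu'))
for _ in range(100):
    optimizer.zero_grad()
    loss_x = ((x + y) ** 2).sum()  # player x's objective
    loss_y = ((x - y) ** 2).sum()  # player y's objective
    optimizer.step(loss_x, loss_y)
```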
## Use with PyTorch DistributedDataParallel

For example:
```python
from torch.nn.parallel import DistributedDataParallel as DDP
from CGDs import ACGD

G = DDP(G, device_ids=[rank], broadcast_buffers=False)
D = DDP(D, device_ids=[rank], broadcast_buffers=False)
g_reducer = G.reducer
d_reducer = D.reducer

optimizer = ACGD(max_params=G.parameters(), min_params=D.parameters(),
max_reducer=g_reducer, min_reducer=d_reducer,
lr_max=1e-3, lr_min=1e-3,
tol=1e-4, atol=1e-8)
for data in dataloader:
    real_pred = D(data)
    latent = torch.randn((batch_size, latent_dim), device=rank)
    fake_img = G(latent)
    fake_pred = D(fake_img)
    # 'trigger' touches the outputs of both models so that DDP's gradient
    # communication (via the reducers) fires during the optimizer's
    # internal backward passes
    trigger = real_pred[0, 0] + fake_img[0, 0, 0, 0]
    loss = loss_fn(real_pred, fake_pred)
    optimizer.step(loss, trigger=trigger.mean())
```
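The snippet above assumes the usual process-group setup has already been performed. A minimal sketch of that setup (standard PyTorch, nothing specific to `CGDs`) looks like:

```python
import os
import torch
import torch.distributed as dist

# rank and world size are provided by the launcher, e.g. torchrun
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
```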
## Citation

Please cite this repository if you find the code useful.
```latex
@misc{cgds-package,
author = {Hongkai Zheng},
title = {CGDs},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/devzhk/cgds-package}},
}
```