# Overparam layers
PyTorch linear over-parameterization layers with automatic graph reduction.

Official codebase used in:

**The Low-Rank Simplicity Bias in Deep Networks**
[Minyoung Huh](http://minyounghuh.com/)   [Hossein Mobahi]()   [Richard Zhang](https://richzhang.github.io/)   [Brian Cheung]()   [Pulkit Agrawal]()   [Phillip Isola]()
MIT CSAIL   Google Research   Adobe Research   MIT BCS
TMLR 2023 (arXiv 2021).
**[[project page]](https://minyoungg.github.io/overparam/) | [[paper]](https://openreview.net/pdf?id=bCiNWDmlY2) | [[arXiv]](https://arxiv.org/abs/2103.10427)**

## 1. Installation
Developed on
- Python 3.7 :snake:
- PyTorch 1.7 :fire:

```bash
> git clone https://github.com/minyoungg/overparam
> cd overparam
> pip install .
```
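
To verify the install, the imports used throughout this README should resolve:

```bash
> python -c "from overparam import OverparamLinear, OverparamConv2d"
```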

## 2. Usage
The layers work exactly like standard `torch.nn` layers.

### Getting started

#### (1a) OverparamLinear layer (equivalent to `nn.Linear`)

```python
import torch
from overparam import OverparamLinear

layer = OverparamLinear(16, 32, width=1, depth=2)
x = torch.randn(1, 16)
```
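
As a quick sanity check (a sketch using nothing beyond the API shown above), the layer should behave like `nn.Linear(16, 32)` from the caller's side:

```python
y = layer(x)
print(y.shape)  # torch.Size([1, 32]), same as nn.Linear(16, 32)
```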

#### (1b) OverparamConv2d layer (equivalent to `nn.Conv2d`)

```python
from overparam import OverparamConv2d
import numpy as np
```

We can compose 3 convolution layers with kernel sizes of `5x5`, `3x3`, and `1x1`:
```python
kernel_sizes = [5, 3, 1]

# "Same" padding for the composed (effective) kernel
padding = max((np.sum(kernel_sizes) - len(kernel_sizes) + 1) // 2, 0)

layer = OverparamConv2d(2, 4, kernel_sizes=kernel_sizes, padding=padding,
                        depth=len(kernel_sizes))

# Get the effective kernel size
print(layer.kernel_size)
```
When `kernel_sizes` is an integer, all subsequent layers are assumed to have a kernel size of `1x1`.
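
For example, the following sketch builds a depth-3 stack from an integer kernel size; by the convention above, the first factor is `3x3` and the remaining two are `1x1`, so the effective kernel should stay `3x3`:

```python
# one 3x3 factor followed by two 1x1 factors
layer = OverparamConv2d(2, 4, kernel_sizes=3, padding=1, depth=3)
print(layer.kernel_size)  # effective kernel size: 3x3
```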

#### (2) Forward computation

```python
# Forward pass (expanded form)
layer.train()
y = layer(x)
```

When `eval()` is called, the layer automatically collapses its computation graph into the effective single-layer counterpart; forward passes in `eval` mode use the effective weights instead.

```python
# Forward pass (collapsed form) [automatic]
layer.eval()
y = layer(x)
```
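
For layers without batch norm, the collapse is exact up to floating-point error, so the two forms can be checked against each other (a sanity-check sketch, continuing the linear example from (1a)):

```python
import torch

layer.train()
y_expanded = layer(x)   # product of factor layers

layer.eval()
y_collapsed = layer(x)  # single effective layer

# both compute the same linear map, up to accumulated float error
print(torch.allclose(y_expanded, y_collapsed, atol=1e-5))  # True
```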

You can access the effective weights as follows:

```python
print(layer.weight)
print(layer.bias)
```
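
In `eval` mode these effective parameters define the collapsed map directly. For the linear example, the output can be reproduced by hand (a sketch assuming the usual `nn.Linear` convention `y = x W^T + b`):

```python
layer.eval()
y_manual = x @ layer.weight.T + layer.bias
print(torch.allclose(layer(x), y_manual, atol=1e-6))  # True
```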

#### (3) Automatic conversion

```python
import torchvision.models as models
from overparam.utils import overparameterize

model = models.alexnet() # Replace this with YOUR_PYTORCH_MODEL()
model = overparameterize(model, depth=2)
```
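
One way to see the effect of the conversion is to compare parameter counts: the expanded model carries extra factor layers during training, while `eval` mode still collapses each one back to a single layer (a sketch using only the API shown above):

```python
import torchvision.models as models
from overparam.utils import overparameterize

base = models.alexnet()
expanded = overparameterize(models.alexnet(), depth=2)

n_base = sum(p.numel() for p in base.parameters())
n_expanded = sum(p.numel() for p in expanded.parameters())
print(n_base, n_expanded)  # the expanded model has more parameters
```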

#### (4) Batch-norm and residual connections
We also provide support for batch-norm and linear residual connections.

- batch-normalization (pseudo-linear layer: linear during `eval` mode)
```python
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
                        batch_norm=True)
```

- residual connection
```python
# every 2 layers, a residual connection is added
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
                        residual=True, residual_intervals=2)
```

- multiple residual connections
```python
# residual connections are added at intervals of 1, 2, and 3 layers
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
                        residual=True, residual_intervals=[1, 2, 3])
```

- batch-norm and residual connection
```python
# mimics `BasicBlock` in ResNets
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
                        batch_norm=True, residual=True, residual_intervals=2)
```
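
Because these are ordinary `nn.Module`s, they drop into a standard training loop; a minimal sketch (loss, optimizer, and tensor shapes are placeholders):

```python
import torch
import torch.nn.functional as F
from overparam import OverparamConv2d

layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
                        batch_norm=True, residual=True, residual_intervals=2)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

x = torch.randn(8, 32, 16, 16)
target = torch.randn(8, 32, 16, 16)

layer.train()                       # expanded form during training
loss = F.mse_loss(layer(x), target)
opt.zero_grad()
loss.backward()
opt.step()

layer.eval()                        # collapsed single-layer form at inference
y = layer(x)
```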

## 3. Cite
```bibtex
@article{huh2023simplicitybias,
  title   = {The Low-Rank Simplicity Bias in Deep Networks},
  author  = {Minyoung Huh and Hossein Mobahi and Richard Zhang and Brian Cheung and Pulkit Agrawal and Phillip Isola},
  journal = {Transactions on Machine Learning Research},
  issn    = {2835-8856},
  year    = {2023},
  url     = {https://openreview.net/forum?id=bCiNWDmlY2},
}
```