https://github.com/szagoruyko/binary-wide-resnet

PyTorch implementation of Wide Residual Networks with 1-bit weights by McDonnell (ICLR 2018)

1-bit Wide ResNet
===========

PyTorch implementation of training 1-bit Wide ResNets from this paper:

*Training wide residual networks for deployment using a single bit for each weight* by **Mark D. McDonnell** at ICLR 2018

The idea is very simple but surprisingly effective for training ResNets with binary weights. Here is the proposed weight parameterization as a PyTorch autograd function:

```python
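import math

import torch


class ForwardSign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w):
        # w is a conv weight (out, in, kh, kw): sign(w) scaled by the He-init std
        return math.sqrt(2. / (w.shape[1] * w.shape[2] * w.shape[3])) * w.sign()

    @staticmethod
    def backward(ctx, g):
        # straight-through: pass the incoming gradient on unchanged
        return g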
```
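
Written as an equation, the function above amounts to the following. The symbols are notation introduced here for illustration, not taken from the repository: $W$ is a conv weight of shape $(C_{\text{out}}, C_{\text{in}}, k_h, k_w)$, $\hat{W}$ is the tensor actually used in the convolution, and $L$ is the training loss.

$$
\hat{W} = \sqrt{\frac{2}{C_{\text{in}}\, k_h\, k_w}}\;\operatorname{sign}(W),
\qquad
\frac{\partial L}{\partial W} := \frac{\partial L}{\partial \hat{W}}
$$

The second relation is a definition rather than a true derivative: the sign function has zero gradient almost everywhere, so the backward pass simply copies the gradient of $\hat{W}$ onto $W$.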

On the forward pass, we take the sign of the weights and scale it by the He-init constant. On the backward pass, the gradient is propagated unchanged. WRN-20-10 trained with this parameterization is only slightly behind its full-precision variant. Here is what I got with this code on CIFAR-100:

| network | accuracy, % (mean ± std over 5 runs) | checkpoint size (MB) |
|:---|:---:|:---:|
| WRN-20-10 | 80.5 ± 0.24 | 205 |
| WRN-20-10-1bit | 80.0 ± 0.26 | 3.5 |
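
To make the parameterization concrete, here is a minimal sketch of applying `ForwardSign` to a single conv layer. The class is repeated so the snippet runs on its own. The tensor shapes and the names `w`, `x`, `w_bin` and `scale` are hypothetical and chosen only for illustration; this is not how `main.py` wires things up.

```python
import math

import torch
import torch.nn.functional as F


class ForwardSign(torch.autograd.Function):
    """Same function as above, repeated so the sketch is self-contained."""

    @staticmethod
    def forward(ctx, w):
        return math.sqrt(2. / (w.shape[1] * w.shape[2] * w.shape[3])) * w.sign()

    @staticmethod
    def backward(ctx, g):
        return g


torch.manual_seed(0)
w = torch.randn(8, 4, 3, 3, requires_grad=True)  # hypothetical full-precision weight
x = torch.randn(2, 4, 16, 16)                    # hypothetical input batch

w_bin = ForwardSign.apply(w)   # every entry is +scale or -scale
w_bin.retain_grad()            # keep the gradient of the binarized weight for inspection
y = F.conv2d(x, w_bin, padding=1)
y.sum().backward()

scale = math.sqrt(2. / (4 * 3 * 3))
print(torch.allclose(w_bin.abs(), torch.full_like(w_bin, scale)))  # magnitudes are all equal
print(torch.equal(w.grad, w_bin.grad))                             # gradient passed through as-is
```

The full-precision `w` is what the optimizer updates during training; only its sign, times one constant per layer, is needed at inference time.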

## Details

Here are the differences from the WRN code:

* BatchNorm has no affine weight and bias parameters
* First layer has 16 * width channels
* Last fc layer is removed in favor of 1x1 conv + F.avg_pool2d
* Downsample is done by F.avg_pool2d + torch.cat instead of strided conv
* SGD with cosine annealing and warm restarts
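
The architectural items in this list can be sketched in a few lines. This is a rough illustration, not the repository's code. In particular, it assumes that the `torch.cat` in the downsample shortcut appends all-zero channels to double the channel count; the README does not say what is concatenated. The variable names and the input shape are also hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes, width = 100, 10
x = torch.randn(2, 16 * width, 32, 32)   # stand-in for the first layer's output (16 * width channels)

# BatchNorm without affine weight and bias: no learnable parameters at all
bn = nn.BatchNorm2d(16 * width, affine=False)

# Parameter-free downsample shortcut: halve the resolution with average pooling,
# then double the channels. ASSUMPTION: the extra channels are zeros.
s = F.avg_pool2d(bn(x), 2)
s = torch.cat([s, torch.zeros_like(s)], dim=1)   # shape (2, 320, 16, 16)

# Classifier head: 1x1 conv to per-class maps, then global average pooling
head = nn.Conv2d(s.shape[1], num_classes, kernel_size=1, bias=False)
logits = F.avg_pool2d(head(s), s.shape[-1])
logits = logits.view(logits.size(0), -1)         # shape (2, 100)
```

With this layout every learnable tensor in the network is a conv weight, which is presumably what lets the whole checkpoint be stored as packed bits.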

I used PyTorch 0.4.1 and Python 3.6 to run the code.

Reproduce WRN-20-10 with 1-bit training on CIFAR-100:

```bash
python main.py --binarize --save ./logs/WRN-20-10-1bit_$RANDOM --width 10 --dataset CIFAR100
```
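
Training uses SGD with cosine annealing and warm restarts, as listed under Details. As a rough illustration of what such a schedule looks like, the sketch below implements the standard SGDR formula in plain Python. The function name `sgdr_lr` and all hyperparameter values (`lr_max`, `lr_min`, `t0`, `t_mult`) are hypothetical; the values actually used by `main.py` are not stated in this README.

```python
import math


def sgdr_lr(epoch, lr_max=0.1, lr_min=0.0, t0=2, t_mult=2):
    """Learning rate at `epoch` under cosine annealing with warm restarts.

    The first cycle lasts `t0` epochs; each later cycle is `t_mult` times longer.
    """
    t_i, t_cur = t0, epoch
    # Skip over completed cycles to find the position inside the current one
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_cur / t_i))


for epoch in range(7):
    print(epoch, round(sgdr_lr(epoch), 4))
```

With these placeholder settings the rate restarts at `lr_max` at epochs 0, 2 and 6, and decays along a half cosine in between.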

[Figure: convergence plot; train error shown as dashed lines]

I've also provided a 3.5 MB checkpoint with the binary weights packed with `np.packbits`, and a very short script to evaluate it:

```bash
python evaluate_packed.py --checkpoint wrn20-10-1bit-packed.pth.tar --width 10 --dataset CIFAR100
```
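
For readers unfamiliar with `np.packbits`, here is a small sketch of how weight signs can be stored at one bit each and recovered afterwards. It only illustrates the idea. The exact layout used by the packed checkpoint and by `evaluate_packed.py` may differ, and the array shape and variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((16, 8, 3, 3)).astype(np.float32)  # hypothetical trained weight

# One bit per weight: 1 encodes +1, 0 encodes -1.
# (An exact zero would be stored as -1 under this encoding.)
bits = w > 0
packed = np.packbits(bits.ravel())   # uint8 array, 8 weights per byte

# Unpack, drop any padding bits, and map {0, 1} back to {-1, +1}
unpacked = np.unpackbits(packed)[:w.size].reshape(w.shape)
signs = unpacked.astype(np.float32) * 2 - 1

print(w.nbytes, packed.nbytes)       # float32 storage vs packed storage, in bytes
```

Multiplying `signs` by the per-layer He-init constant then reproduces the weights that `ForwardSign` computes on the forward pass.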

S3 URL to checkpoint: