https://github.com/Oulu-IMEDS/pytorch_bn_fusion
Batch normalization fusion for PyTorch
- Host: GitHub
- URL: https://github.com/Oulu-IMEDS/pytorch_bn_fusion
- Owner: Oulu-IMEDS
- License: mit
- Created: 2018-07-24T08:09:01.000Z (over 6 years ago)
- Default Branch: master
- Last Pushed: 2020-04-06T07:31:05.000Z (over 4 years ago)
- Last Synced: 2024-08-28T17:12:52.197Z (4 months ago)
- Topics: batch-normalization, deep-learning, deep-neural-networks, inference-optimization, pytorch
- Language: Python
- Size: 54.7 KB
- Stars: 194
- Watchers: 8
- Forks: 29
- Open Issues: 3
Metadata Files:
- Readme: README.md
- License: LICENSE
# Batch Norm Fusion for PyTorch
## About
In this repository, we present a simple implementation of batch norm fusion for the most popular CNN architectures in PyTorch.
This package aims to speed up inference at test time: **the expected boost is 30%!** In the future…

## How it works
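Both convolution and batch norm (in inference mode, where batch norm uses its fixed running statistics) are linear operations on the data point x (affine, once the shift is included), so they can be written in terms of matrix multiplications:
![T_{bn}*S_{bn}*W_{conv}*x](https://latex.codecogs.com/gif.latex?T_{bn}*S_{bn}*W_{conv}*x),
where we first apply the convolution to the data, then scale the result and finally shift it using the parameters learned by batch norm. Because the product of these matrices is itself a single affine map, the three steps collapse into one convolution with rescaled weights and a new bias, which removes the batch norm layers from the inference graph.

## Supported architectures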
We support any architecture where Conv and BN are combined in a Sequential module.
If you want to optimize your own networks with this tool, just follow this design.
For convenience, we wrapped the VGG, ResNet and SENet families to demonstrate how your models can be converted into this format.

- [x] VGG family from `torchvision`.
- [x] ResNet family from `torchvision`.
- [x] SENet family from `pretrainedmodels`.

## How to use
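As a minimal sketch of that design for a custom network, assuming the fusion expects each `nn.Conv2d` to be immediately followed by its `nn.BatchNorm2d` inside the same `nn.Sequential`, a model could be laid out as follows. The helper `conv_bn_relu` and the model `TinyNet` are hypothetical names for illustration, not part of this package:

```python
import torch
import torch.nn as nn


def conv_bn_relu(in_ch: int, out_ch: int, stride: int = 1) -> nn.Sequential:
    """Hypothetical helper: a 3x3 Conv2d directly followed by its BatchNorm2d."""
    return nn.Sequential(
        # bias=False: the BN shift (beta) already provides a per-channel bias
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class TinyNet(nn.Module):
    """Hypothetical custom CNN where every Conv/BN pair lives in one Sequential."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            conv_bn_relu(3, 16),
            conv_bn_relu(16, 32, stride=2),
            conv_bn_relu(32, 64, stride=2),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.features(x))
        return self.classifier(torch.flatten(x, 1))


net = TinyNet()
out = net(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 10])
```

Dropping the convolution bias is a common choice in this layout: after fusion, the single convolution gets its bias back from the batch norm parameters.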
```python
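import torch
import torchvision.models as models
from bn_fusion import fuse_bn_recursively

# Load an ImageNet-pretrained VGG-16 that contains BatchNorm layers
net = getattr(models, 'vgg16_bn')(pretrained=True)
# Fold every Conv + BN pair into a single Conv
net = fuse_bn_recursively(net)
net.eval()
# Make inference with the converted model
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224))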
```
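`fuse_bn_recursively` comes from this package. To illustrate what folding a Conv → BN pair involves, here is a minimal, self-contained sketch; it is not the package's implementation, the names `fuse_conv_bn` and `fuse_sequential` are hypothetical, and it assumes `nn.Conv2d` layers immediately followed by `nn.BatchNorm2d` (with `affine=True` and tracked running statistics) inside `nn.Sequential` containers. With running mean $\mu$, running variance $\sigma^2$, BN parameters $\gamma$, $\beta$ and $S_{bn} = \mathrm{diag}(\gamma / \sqrt{\sigma^2 + \epsilon})$, it computes $W_{fused} = S_{bn} W_{conv}$ and $b_{fused} = S_{bn}(b_{conv} - \mu) + \beta$, taking $b_{conv} = 0$ when the convolution has no bias.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return one Conv2d equal to bn(conv(x)) when bn uses its running statistics."""
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=conv.groups, bias=True, padding_mode=conv.padding_mode,
    )
    with torch.no_grad():
        # S_bn: per-output-channel scale gamma / sqrt(running_var + eps)
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        # Treat a bias-free convolution as having b_conv = 0
        conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        # W_fused = S_bn * W_conv: rescale each output filter
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        # b_fused = S_bn * (b_conv - mu) + beta
        fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused


def fuse_sequential(seq: nn.Sequential) -> nn.Sequential:
    """Fold every Conv2d directly followed by BatchNorm2d; recurse into nested Sequentials."""
    items = list(seq.named_children())
    fused_layers = OrderedDict()
    i = 0
    while i < len(items):
        name, layer = items[i]
        nxt = items[i + 1][1] if i + 1 < len(items) else None
        if isinstance(layer, nn.Conv2d) and isinstance(nxt, nn.BatchNorm2d):
            fused_layers[name] = fuse_conv_bn(layer, nxt)
            i += 2  # the BatchNorm2d has been folded away
        else:
            if isinstance(layer, nn.Sequential):
                layer = fuse_sequential(layer)
            fused_layers[name] = layer
            i += 1
    return nn.Sequential(fused_layers)


# Usage: give BN non-trivial running statistics, then compare outputs in eval mode
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
model.train()
with torch.no_grad():
    for _ in range(5):
        model(torch.randn(4, 3, 16, 16))  # updates running_mean / running_var
model.eval()
fused = fuse_sequential(model).eval()
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    print(torch.allclose(model(x), fused(x), atol=1e-5))  # True
```

The fused model must only be compared in eval mode: in training mode batch norm normalizes with batch statistics rather than the running ones that were folded into the weights.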
## TODO

- [ ] Tests.
- [ ] Performance benchmarks.

## Acknowledgements
Thanks to [@ZFTurbo](https://github.com/ZFTurbo) for the idea, discussions and his [implementation for Keras](https://github.com/ZFTurbo/Keras-inference-time-optimizer).