Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/1adrianb/pytorch-estimate-flops
Estimate/count FLOPS for a given neural network using pytorch
convolutional-neural-networks deep-learning flops pytorch pytorch-estimate-flops
- Host: GitHub
- URL: https://github.com/1adrianb/pytorch-estimate-flops
- Owner: 1adrianb
- License: bsd-3-clause
- Created: 2019-02-08T20:40:01.000Z (almost 6 years ago)
- Default Branch: master
- Last Pushed: 2022-05-20T23:36:38.000Z (over 2 years ago)
- Last Synced: 2024-11-07T16:06:41.780Z (7 days ago)
- Topics: convolutional-neural-networks, deep-learning, flops, pytorch, pytorch-estimate-flops
- Language: Python
- Homepage: https://www.adrianbulat.com
- Size: 43.9 KB
- Stars: 304
- Watchers: 9
- Forks: 22
- Open Issues: 6
Metadata Files:
- Readme: README.md
- License: LICENSE
README
[![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) [![Test Pytorch Flops Counter](https://github.com/1adrianb/pytorch-estimate-flops/workflows/Test%20Pytorch%20Flops%20Counter/badge.svg)](https://travis-ci.com/1adrianb/pytorch-estimate-flops)
[![PyPI](https://img.shields.io/pypi/v/pthflops.svg?style=flat)](https://pypi.org/project/pthflops/)

# pytorch-estimate-flops
A simple PyTorch utility that estimates the number of FLOPs for a given network. For now, only some basic operations are supported (mainly the ones I needed for my own models). More will be added soon.

All contributions are welcome.
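To make "number of FLOPs" concrete, here is a hand-computed sketch for a single `nn.Conv2d`, using the ResNet-18 stem convolution from the example further down (3 to 64 channels, 7x7 kernel, stride 2, padding 3, 224x224 input). It assumes one multiply-accumulate (MAC) per kernel weight per output element and ignores the bias. Whether pthflops reports MACs or 2x MACs, and how it treats bias, is not stated here, so the helpers `conv2d_output_size` and `conv2d_macs` are illustrative only and not part of the library.

```python
def conv2d_output_size(size: int, kernel: int, stride: int = 1,
                       padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_macs(in_ch: int, out_ch: int, kernel: int, in_h: int, in_w: int,
                stride: int = 1, padding: int = 0, groups: int = 1) -> int:
    """MAC count for a square-kernel Conv2d (bias ignored, by assumption)."""
    out_h = conv2d_output_size(in_h, kernel, stride, padding)
    out_w = conv2d_output_size(in_w, kernel, stride, padding)
    # Each output element needs (in_ch / groups) * kernel * kernel MACs.
    macs_per_output = (in_ch // groups) * kernel * kernel
    return out_ch * out_h * out_w * macs_per_output


# ResNet-18 stem: 3 -> 64 channels, 7x7 kernel, stride 2, padding 3, 224x224 input.
stem_macs = conv2d_macs(3, 64, 7, 224, 224, stride=2, padding=3)
print(stem_macs)      # 118013952
print(2 * stem_macs)  # 236027904 if multiply and add are counted separately
```

The factor of two between the two printed numbers is the usual source of disagreement between FLOP counters, so it is worth checking which convention a reported figure uses.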
## Installation
You can install the package using pip:
```bash
pip install pthflops
```
or directly from the github repository:
```bash
git clone https://github.com/1adrianb/pytorch-estimate-flops && cd pytorch-estimate-flops
python setup.py install
```

Note: PyTorch 1.8 or newer is recommended.
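The PyTorch version recommendation can be checked programmatically. The sketch below uses hypothetical helpers (`major_minor`, `meets_recommended`, not part of pthflops) that compare only the major and minor version components, so local build tags such as `+cu111` are ignored.

```python
import re


def major_minor(version: str) -> tuple:
    """Extract (major, minor) from strings such as '1.8.1+cu111' or '2.1.0.dev20230901'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def meets_recommended(version: str, minimum: tuple = (1, 8)) -> bool:
    """True if `version` is at least `minimum`, compared on (major, minor) only."""
    return major_minor(version) >= minimum
```

For example, call `meets_recommended(torch.__version__)` after `import torch`. Comparing integer tuples avoids the pitfall of plain string comparison, where `'1.10' < '1.8'` evaluates to `True`.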
## Example
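```python
import torch
from torchvision.models import resnet18

from pthflops import count_ops

# Create a network and a corresponding input
# (falls back to the CPU when no GPU is available)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = resnet18().to(device)
inp = torch.rand(1, 3, 224, 224).to(device)

# Count the number of FLOPs
count_ops(model, inp)
```

To exclude certain layers from the count, pass their names via `ignore_layers`: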
```python
import torch
from torch import nn

from pthflops import count_ops


class CustomLayer(nn.Module):
    def __init__(self):
        super(CustomLayer, self).__init__()
        self.conv1 = nn.Conv2d(5, 5, 1, 1, 0)
        # ... other layers present inside will also be ignored

    def forward(self, x):
        return self.conv1(x)


# Create a network and a corresponding input
inp = torch.rand(1, 5, 7, 7)
net = nn.Sequential(
    nn.Conv2d(5, 5, 1, 1, 0),
    nn.ReLU(inplace=True),
    CustomLayer()
)

# Count the number of FLOPs, jit mode:
count_ops(net, inp, ignore_layers=['CustomLayer'])

# Note: if you are using PyTorch 1.8 or newer with fx instead of jit, the naming
# convention changed. As such, you will have to pass ['_2_conv1'] instead.
# Please check your model definition to account for this.
# Count the number of FLOPs, fx mode:
count_ops(net, inp, ignore_layers=['_2_conv1'])
```
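The fx-mode name `_2_conv1` corresponds to the module path `2.conv1`, i.e. the `conv1` inside the third entry of the `nn.Sequential`. The hypothetical helper `fx_style_name` below approximates, as an assumption, how torch.fx turns such a path into a node name: characters that are illegal in a Python identifier become `_`, and a leading digit gets a `_` prefix. It deliberately omits other fx naming details (camelCase-to-snake_case conversion, numeric suffixes for duplicate names), so treat its output as a guess rather than the authoritative name.

```python
import re

# Runs of characters not allowed in a Python identifier (assumed to mirror torch.fx).
_ILLEGAL_CHARS = re.compile(r"[^0-9a-zA-Z_]+")


def fx_style_name(module_path: str) -> str:
    """Hypothetical helper: approximate the fx node name for a submodule path.

    Example: '2.conv1' -> '_2_conv1'. Does not reproduce fx's camelCase
    conversion or the suffixes fx adds to keep names unique.
    """
    name = _ILLEGAL_CHARS.sub("_", module_path)
    if not name:
        name = "_unnamed"
    # Identifiers cannot start with a digit, so prefix an underscore.
    if name[0].isdigit():
        name = "_" + name
    return name


print(fx_style_name("2.conv1"))  # _2_conv1
```

To see the exact names for your own model, trace it with `torch.fx.symbolic_trace(net)` and print `node.name` for every node in the traced module's `graph.nodes` whose `node.op == 'call_module'`.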