https://github.com/jeremyfix/pytorch_feature_extraction
Scripts to extract intermediate features from pytorch models
- Host: GitHub
- URL: https://github.com/jeremyfix/pytorch_feature_extraction
- Owner: jeremyfix
- License: gpl-3.0
- Created: 2020-08-07T07:35:17.000Z (almost 5 years ago)
- Default Branch: master
- Last Pushed: 2020-08-07T16:56:05.000Z (almost 5 years ago)
- Last Synced: 2025-02-12T22:19:18.987Z (4 months ago)
- Language: Python
- Size: 41 KB
- Stars: 0
- Watchers: 3
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Pytorch intermediate layer feature extraction
This script extracts the features of intermediate layers of PyTorch networks.
The layers to export are selected by providing their indices in the module list of a PyTorch model. We use the pretrained models of PyTorch_CIFAR10. For the `dltools.py` script to work, you need to clone this repository recursively:

```bash
git clone --recurse-submodules git@github.com:jeremyfix/pytorch_feature_extraction.git
```

You also need to download the pretrained networks from [PyTorch_CIFAR10](https://github.com/huyvnphan/PyTorch_CIFAR10). Check the available options with

```bash
python3 dltools.py --help
```
Then you can process a single image with

```bash
python3 dltools.py --image path/to/an/image
```

or the whole CIFAR-10 validation dataset with

```bash
python3 dltools.py
```
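Layers are selected by their index in the model's module list. Below is a minimal sketch of how such index-based extraction can be done with PyTorch forward hooks, assuming the indices are positions in `list(model.modules())` (index 0 being the model itself). `extract_features` and the toy network are hypothetical and not taken from `dltools.py`.

```python
import torch
import torch.nn as nn


def extract_features(model, inputs, modules_idx):
    """Return {idx: output tensor} for the modules at the given indices.

    Assumption: idx is a position in list(model.modules()).
    """
    all_modules = list(model.modules())
    features = {}
    handles = []

    def make_hook(idx):
        def hook(module, inp, out):
            # Detach so the stored tensor does not keep an autograd graph alive
            features[idx] = out.detach()
        return hook

    for idx in modules_idx:
        handles.append(all_modules[idx].register_forward_hook(make_hook(idx)))
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Remove the hooks so later forward passes are unaffected
        for handle in handles:
            handle.remove()
    return features


# Toy network standing in for a pretrained CIFAR-10 model (hypothetical)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # index 1
    nn.ReLU(),                                  # index 2
    nn.AdaptiveAvgPool2d(1),                    # index 3
    nn.Flatten(),                               # index 4
    nn.Linear(8, 10),                           # index 5
).eval()

x = torch.randn(4, 3, 32, 32)
feats = extract_features(model, x, [1, 5])
for idx, f in feats.items():
    print(idx, tuple(f.shape))
```

In this sketch, all selected outputs are captured during a single forward pass and are held in memory together.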
By default, the script processes `mobilenet_v2` (the PyTorch_CIFAR10 version) with a batch size of 128, and saves the features of modules 5, 35, 67, 139 and 212 in NumPy `.npy` files. If your CPU/GPU does not have enough memory, consider passing the `--sequential` flag, which performs one forward pass per intermediate layer so that the features of all the intermediate layers do not have to be stored in memory at the same time.

# Example usage
For example, to save the features of modules 5, 35, 67, 139 and 212 of a `mobilenet_v2` (212 being the last linear layer) when processing the image `coq.png`:

```bash
python3 dltools.py --model_name mobilenet_v2 --modules_idx 5 35 67 139 212 --image coq.png
```
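When memory is tight, the `--sequential` option trades run time for memory. A minimal sketch of that idea, assuming forward hooks and the same index convention (positions in `list(model.modules())`): one forward pass per selected module, with each feature map written to disk before the next pass. `extract_sequential`, the `module_<idx>.npy` naming and the toy model are hypothetical, not `dltools.py`'s actual code or file layout.

```python
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn


def extract_sequential(model, inputs, modules_idx, out_dir):
    """One forward pass per selected module; each feature map is saved to
    its own .npy file so that only one is held in memory at a time.

    Hypothetical helper: the module_<idx>.npy naming is an assumption.
    """
    all_modules = list(model.modules())
    paths = []
    for idx in modules_idx:
        captured = {}

        def hook(module, inp, out, store=captured):
            # Copy this module's output to host memory as a NumPy array
            store["out"] = out.detach().cpu().numpy()

        handle = all_modules[idx].register_forward_hook(hook)
        try:
            with torch.no_grad():
                model(inputs)
        finally:
            handle.remove()  # only one hook is registered at any time
        path = os.path.join(out_dir, f"module_{idx}.npy")
        np.save(path, captured.pop("out"))  # write, then drop the array
        paths.append(path)
    return paths


# Toy model standing in for a pretrained network (hypothetical)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # index 1
    nn.ReLU(),                                  # index 2
    nn.Flatten(),                               # index 3
    nn.Linear(8 * 32 * 32, 10),                 # index 4
).eval()

with tempfile.TemporaryDirectory() as out_dir:
    paths = extract_sequential(model, torch.randn(2, 3, 32, 32), [1, 4], out_dir)
    shapes = {os.path.basename(p): np.load(p).shape for p in paths}
print(shapes)
```

The cost of this design is that the network is evaluated once per requested module instead of once overall.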
To save the features of the `maxpool3` and `maxpool4` layers and of the last linear layer of a `googlenet`, processing the whole CIFAR-10 validation set:

```bash
python3 dltools.py --model_name googlenet --modules_idx 50 166 215
```
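Module indices such as 50, 166 and 215 depend on the model definition. Below is a hypothetical helper to look up the index of a named layer, assuming the indices follow the order of `model.named_modules()`, which is the same order as `model.modules()`. The toy model only borrows the `maxpool3`/`maxpool4` names and does not reproduce googlenet's actual indices.

```python
from collections import OrderedDict

import torch.nn as nn


def module_index_table(model):
    """(index, qualified name, class name) for every module, in the same
    order as list(model.modules())."""
    return [
        (idx, name or "<root>", type(module).__name__)
        for idx, (name, module) in enumerate(model.named_modules())
    ]


def find_module_idx(model, name):
    """Index of the submodule with the given qualified name."""
    for idx, (mod_name, _) in enumerate(model.named_modules()):
        if mod_name == name:
            return idx
    raise KeyError(name)


# Hypothetical toy model with named layers, standing in for googlenet
model = nn.Sequential(OrderedDict([
    ("conv1", nn.Conv2d(3, 8, 3, padding=1)),
    ("maxpool3", nn.MaxPool2d(2)),
    ("conv2", nn.Conv2d(8, 16, 3, padding=1)),
    ("maxpool4", nn.MaxPool2d(2)),
    ("flatten", nn.Flatten()),
    ("linear", nn.Linear(16 * 8 * 8, 10)),
]))

for row in module_index_table(model):
    print(*row)
print(find_module_idx(model, "maxpool3"))  # 2
```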
Note that the validation data are shuffled: the first labels are 3, 8, 8, 0, 6, 6, ... If you do not need the whole dataset, you can process only the first validation samples of CIFAR-10 by specifying the `--size` option.
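A minimal sketch of restricting processing to the first `size` samples while preserving the dataset order, assuming a `torch.utils.data` dataset. `first_n_loader` and the random stand-in dataset are hypothetical, and the real script may implement `--size` differently.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def first_n_loader(dataset, size, batch_size=128):
    """DataLoader over the first `size` samples of `dataset`, in order."""
    indices = list(range(min(size, len(dataset))))
    # shuffle=False keeps the original order: sample i here is sample i
    # of the full dataset
    return DataLoader(Subset(dataset, indices), batch_size=batch_size, shuffle=False)


# Hypothetical stand-in for the CIFAR-10 validation set
images = torch.randn(1000, 3, 32, 32)
labels = torch.arange(1000) % 10
dataset = TensorDataset(images, labels)

loader = first_n_loader(dataset, size=300)
batch_sizes = [x.shape[0] for x, _ in loader]
first_labels = next(iter(loader))[1]
print(batch_sizes)  # [128, 128, 44]
```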