https://github.com/ahirner/pytorch-retraining
Transfer Learning Shootout for PyTorch's model zoo (torchvision)
benchmark pytorch transfer-learning
Last synced: 7 months ago
- Host: GitHub
- URL: https://github.com/ahirner/pytorch-retraining
- Owner: ahirner
- License: bsd-3-clause
- Created: 2017-05-29T17:09:42.000Z (over 8 years ago)
- Default Branch: master
- Last Pushed: 2020-09-20T11:06:49.000Z (about 5 years ago)
- Last Synced: 2024-08-04T03:08:09.827Z (over 1 year ago)
- Topics: benchmark, pytorch, transfer-learning
- Language: Jupyter Notebook
- Size: 578 KB
- Stars: 170
- Watchers: 10
- Forks: 41
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- Awesome-pytorch-list-CNVersion - pytorch-retraining
- Awesome-pytorch-list - pytorch-retraining
README
# pytorch-retraining
Transfer Learning shootout for PyTorch's model zoo (torchvision).
* **Load** any pretrained model from PyTorch's model zoo with a custom final layer (`num_classes`) in one line:
```python
model_pretrained, diff = load_model_merged('inception_v3', num_classes)
```
* **Retrain** a minimal set of layers (as inferred on load) or a custom number of layers on multiple GPUs, optionally with a _Cyclical Learning Rate_ [(Smith 2017)](http://arxiv.org/abs/1506.01186):
```python
final_param_names = [d[0] for d in diff]
stats = train_eval(model_pretrained, trainloader, testloader, final_param_names)
```
* **Chart** `training_time`, `evaluation_time` (fps), and top-1 `accuracy` for varying levels of retraining depth (shallow, deep, and from scratch)
|  |
|:---:|
| *Transfer learning on the example dataset [Bees vs Ants](http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) with 2x V100 GPUs* |
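
The load and retrain steps listed above boil down to two ideas: swap the final layer for one with `num_classes` outputs, then train only the parameters that changed. The following is a minimal sketch of that idea, not the repository's implementation. The helpers `replace_head` and `freeze_except` are hypothetical names, and a tiny `nn.Sequential` stands in for a pretrained torchvision model so that nothing has to be downloaded.

```python
import torch
from torch import nn


def replace_head(model: nn.Sequential, num_classes: int) -> list:
    """Swap the last Linear layer for a fresh one; return the new parameter names."""
    last = len(model) - 1
    old_head = model[last]
    model[last] = nn.Linear(old_head.in_features, num_classes)
    prefix = f"{last}."
    return [name for name, _ in model.named_parameters() if name.startswith(prefix)]


def freeze_except(model: nn.Module, trainable_names) -> list:
    """Enable gradients only for the named parameters; return those parameters."""
    keep = set(trainable_names)
    trainable = []
    for name, param in model.named_parameters():
        param.requires_grad = name in keep
        if param.requires_grad:
            trainable.append(param)
    return trainable


# Hypothetical stand-in for a pretrained network with a 1000-class head.
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1000))

head_names = replace_head(backbone, num_classes=2)   # names of the new head's parameters
params = freeze_except(backbone, head_names)         # everything else is frozen

# Only the unfrozen parameters are handed to the optimizer.
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9)
```

Passing a longer list of names to `freeze_except` would correspond to a deeper level of retraining, while passing every parameter name amounts to updating the whole network.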
## Results on a more elaborate dataset
*num_classes = 23, slightly unbalanced, high variance in rotation and motion blur artifacts*, on 1x GTX 1080 Ti
|  |
|:---:|
| *Constant LR with momentum* |
|  |
|:---:|
| *Cyclical Learning Rate* |
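
For readers unfamiliar with the cyclical schedule compared above, here is a minimal sketch of the triangular policy described in the cited Smith paper. It is not taken from this repository: the function name `triangular_lr` and the values of `base_lr`, `max_lr`, and `step_size` are illustrative assumptions. The learning rate climbs linearly from `base_lr` to `max_lr` over `step_size` iterations, then falls back over the next `step_size`.

```python
import math


def triangular_lr(iteration: int, base_lr: float, max_lr: float, step_size: int) -> float:
    """Triangular cyclical learning rate for a given training iteration."""
    # Index of the current cycle (1-based); one cycle spans 2 * step_size iterations.
    cycle = math.floor(1 + iteration / (2 * step_size))
    # Distance from the peak of the current cycle, scaled to the range [0, 1].
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)


# Illustrative values only: one full cycle takes 8 iterations.
schedule = [triangular_lr(i, base_lr=0.001, max_lr=0.006, step_size=4) for i in range(9)]
```

With these values the schedule starts at `base_lr`, peaks at `max_lr` on iteration 4, and returns to `base_lr` on iteration 8. PyTorch also provides `torch.optim.lr_scheduler.CyclicLR`, which could be used instead of a hand-written schedule.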