https://github.com/k2-gc/simple-cnn-example
Simple CNN classification example using PyTorch: training, exporting from PyTorch to ONNX, and inference with both PyTorch and ONNX Runtime.
- Host: GitHub
- URL: https://github.com/k2-gc/simple-cnn-example
- Owner: k2-gc
- Created: 2023-12-01T15:06:32.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2023-12-03T14:16:23.000Z (over 1 year ago)
- Last Synced: 2025-01-11T21:29:20.744Z (5 months ago)
- Topics: classification, cnn-classification, cnn-pytorch, docker, onnx, onnxruntime, python, python3, pytorch
- Language: Python
- Homepage:
- Size: 6.84 KB
- Stars: 0
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# Simple-CNN-Example
## Introduction
This repository shows how to train a deep learning classification model with PyTorch, export it to ONNX, and run it with ONNX Runtime, using MNIST, a well-known dataset of handwritten digit images, as an example.
CNN models commonly expect 3-channel (RGB) input, but MNIST images have a single channel. To handle this, a custom MNIST dataset class inherits from `torchvision.datasets.MNIST` and returns 3-channel tensors.

## Prerequisites
* Docker
* Docker Compose
* Access to NVIDIA's container registry (run `docker login nvcr.io`)
* dGPU (recommended)

## How to train
### Train with dGPU
```bash
docker compose -f docker-compose-gpu.yaml up -d
docker exec -it mnist_train /bin/bash
python train.py
```

### Train with CPU
```bash
docker compose -f docker-compose.yaml up -d
docker exec -it mnist_train /bin/bash
python train.py
```

## Export model from PyTorch to ONNX
After training, run the command below.
```bash
python export.py
```

## Run the ONNX model with ONNX Runtime
```bash
python check_onnx_inferenc.py
```
This script picks three sample images from the MNIST dataset, runs inference on them with both the PyTorch model and the ONNX model, and shows the results of each.