Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/logcreative/quickdraw-classifier
CS420 Project
- Host: GitHub
- URL: https://github.com/logcreative/quickdraw-classifier
- Owner: LogCreative
- Created: 2022-05-20T12:44:13.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2022-06-19T09:05:29.000Z (over 2 years ago)
- Last Synced: 2024-12-06T18:57:14.274Z (17 days ago)
- Language: Python
- Homepage:
- Size: 3.94 MB
- Stars: 0
- Watchers: 2
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# quickdraw-classifier
CS420 Project

## Data
Unzip the data into the `dataset/seq` folder.
Use `RPCL-pix2seq` to convert the sequence data into PNG images. (You may need to clone this repo with the `--recursive` flag to download the submodule.)
```cmd
cd RPCL-pix2seq
python seq2png.py --input_dir=../dataset/seq --output_dir=../dataset/png --png_width=28 --categories={'bear'}
```

> **NOTICE** You need Python 3.7 or earlier to install TensorFlow 1.15.
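`seq2png.py` performs this conversion for you. To show what the step involves, here is a minimal pure-Python sketch, not the submodule's actual code. It assumes the sequences use the stroke-3 format from sketch-rnn, where each row is `(dx, dy, pen_lifted)`. It draws strokes with simple linear interpolation instead of an anti-aliased renderer and uses 255 for ink on a 0 background, which may not match the real PNGs. The function names `strokes_to_polylines` and `rasterize` are hypothetical.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def strokes_to_polylines(strokes: Sequence[Sequence[float]]) -> List[List[Point]]:
    """Turn stroke-3 rows (dx, dy, pen_lifted) into absolute-coordinate polylines.

    Assumption: the pen starts at (0, 0) and a row with pen_lifted == 1 ends
    the current stroke, so no line is drawn from it to the next point.
    """
    polylines: List[List[Point]] = []
    current: List[Point] = []
    x = y = 0.0
    for dx, dy, pen_lifted in strokes:
        x += dx
        y += dy
        current.append((x, y))
        if pen_lifted == 1:
            polylines.append(current)
            current = []
    if current:  # the last stroke may not end with a pen lift
        polylines.append(current)
    return polylines


def rasterize(strokes: Sequence[Sequence[float]], size: int = 28, margin: int = 2) -> List[List[int]]:
    """Draw the sketch into a size x size grid (255 = ink, 0 = background)."""
    image = [[0] * size for _ in range(size)]
    polylines = strokes_to_polylines(strokes)
    points = [p for line in polylines for p in line]
    if not points:
        return image

    # Scale uniformly so the longer side of the bounding box fits inside the margins.
    min_x = min(p[0] for p in points)
    min_y = min(p[1] for p in points)
    span = max(max(p[0] for p in points) - min_x,
               max(p[1] for p in points) - min_y) or 1.0
    scale = (size - 1 - 2 * margin) / span

    def to_pixel(p: Point) -> Point:
        return ((p[0] - min_x) * scale + margin, (p[1] - min_y) * scale + margin)

    for line in polylines:
        pixels = [to_pixel(p) for p in line]
        # A one-point stroke becomes a zero-length segment, i.e. a dot.
        segments = list(zip(pixels, pixels[1:])) or [(pixels[0], pixels[0])]
        for (x0, y0), (x1, y1) in segments:
            # Sample at steps shorter than one pixel so no pixel on the segment is skipped.
            steps = int(max(abs(x1 - x0), abs(y1 - y0))) + 1
            for i in range(steps + 1):
                t = i / steps
                col = round(x0 + t * (x1 - x0))
                row = round(y0 + t * (y1 - y0))
                image[row][col] = 255
    return image
```

For example, `rasterize([(0, 0, 0), (10, 0, 1)])` draws a single horizontal stroke along row 2 of a 28x28 grid, from column 2 to column 25.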
## Train
Once the data is prepared, you can train a model by running one of the following scripts:
```
python train_cnn.py
python train_rnn.py
python train_cnnrnn.py
```
For the CNN model, you may need to change the architecture in `config_train_cnn.py`: the value of `model` can be `resnet18`, `ResNet`, or `sketchnet`. For the RNN model, we use a bidirectional LSTM (BiLSTM). For the CNN-RNN model, we use Sketch-a-Net for the CNN branch and a BiLSTM for the RNN branch.

Training uses PyTorch. During training, the best model is saved as `best_{model}.pth` in the root folder. The test accuracy can be found in [report/result.dat](report/result.dat).
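The keep-best checkpointing described above can be sketched independently of any framework. This is a minimal sketch, not the repository's training loop. The function `train_keep_best` and its `run_epoch`, `evaluate` and `save` hooks are hypothetical. This README does not state which split the scripts evaluate when they pick the best model.

```python
from typing import Callable, Tuple


def train_keep_best(
    model_name: str,
    epochs: int,
    run_epoch: Callable[[int], None],
    evaluate: Callable[[], float],
    save: Callable[[str], None],
) -> Tuple[float, str]:
    """Train for `epochs` epochs and save a checkpoint whenever accuracy improves.

    The checkpoint path follows the README's `best_{model}.pth` naming.
    `run_epoch`, `evaluate` and `save` are hypothetical hooks supplied by the caller.
    """
    best_acc = float("-inf")
    best_path = f"best_{model_name}.pth"
    for epoch in range(epochs):
        run_epoch(epoch)          # one pass over the training data
        acc = evaluate()          # accuracy on a held-out split
        if acc > best_acc:        # strictly better: overwrite the checkpoint
            best_acc = acc
            save(best_path)
    return best_acc, best_path
```

With PyTorch, `save` could be `lambda path: torch.save(net.state_dict(), path)`, where `net` is the model being trained. Overwriting a single file keeps disk usage constant but discards earlier, worse checkpoints.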
## Report
The report (in Chinese) is available at [report/ML_CS420_Project_report.pdf](report/ML_CS420_Project_report.pdf).