Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
Implementation for paper MLP-Mixer
- Host: GitHub
- URL: https://github.com/giochem/mlp-mixer-pytorch
- Owner: giochem
- Created: 2024-06-29T09:41:10.000Z (6 months ago)
- Default Branch: master
- Last Pushed: 2024-06-30T03:21:23.000Z (6 months ago)
- Last Synced: 2024-07-01T00:30:53.987Z (6 months ago)
- Topics: classification, mlp-mixer, pytorch
- Language: Python
- Size: 240 KB
- Stars: 0
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# [PYTORCH] MLP-Mixer: An all-MLP Architecture for Vision
## Introduction
Implementation of the paper [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/pdf/2105.01601.pdf). Give me a star if you like this repo.
Author:
- Github: giochem
- Email: [email protected]
![image](./image/mlp-mixer.png)
[Note] You can use your own data to train this model.
## I. Set up environment
1. Make sure you have installed PyTorch and related libraries.
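If you are not sure whether the environment is ready, a quick sanity check (assuming a standard PyTorch/torchvision install) is:

```python
# Quick sanity check that PyTorch and torchvision are importable
import torch
import torchvision

print(torch.__version__)          # the installed PyTorch version
print(torchvision.__version__)    # the installed torchvision version
print(torch.cuda.is_available())  # True if a CUDA GPU can be used for training
```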
## II. Set up your dataset.
Create two folders, `train` and `valid`, in the `data` folder (which was created already). Then copy your images, with the corresponding class names, into these folders.
- the `train` folder is used for the training process
- the `valid` folder is used for validating the training result after each epoch

This repository uses the `torchvision.datasets.ImageFolder` API from PyTorch to load images. Make sure you have some understanding of how it works via [its documentation](https://pytorch.org/vision/0.16/generated/torchvision.datasets.ImageFolder.html); a minimal sketch is shown below.
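As a rough sketch of how `ImageFolder` behaves (assuming the folder layout shown below), each subfolder of the root becomes one class; the `Resize` size here is an assumption and should match your `--image-size`:

```python
from torchvision import datasets, transforms

# ImageFolder infers one class per subdirectory of the root folder.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # assumed size; match your --image-size flag
    transforms.ToTensor(),
])

train_set = datasets.ImageFolder(root="data/train", transform=transform)
print(train_set.classes)       # e.g. ['class_a', 'class_b', 'class_c']
print(train_set.class_to_idx)  # e.g. {'class_a': 0, 'class_b': 1, 'class_c': 2}
image, label = train_set[0]    # a (C, H, W) tensor and an integer class label
```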
The structure of these folders:
```
train/
...class_a/
......a_image_1.jpg
......a_image_2.jpg
...class_b/
......b_image_1.jpg
......b_image_2.jpg
...class_c/
......c_image_1.jpg
......c_image_2.jpg
```

```
valid/
...class_a/
......a_image_1.jpg
......a_image_2.jpg
...class_b/
......b_image_1.jpg
......b_image_2.jpg
...class_c/
......c_image_1.jpg
......c_image_2.jpg
```

## III. Train your model by running this command line
```bash
python train.py --epochs ${epochs} --num-classes ${num_classes}
```
Suppose you want to train a model for 10 epochs on a binary classification problem (with 2 classes).
Example:
```bash
python train.py --epochs 10 --num-classes 2
```
You can also train on the CIFAR10 dataset without putting data into the folders:
```bash
python train.py --dataset CIFAR10 --epochs 10 --num-classes 10 --image-size 32 --patch-size 4 --batch-size 4 --tokens-mlp-dim 1024 --channels-mlp-dim 128 --hidden-dim 256 --num-of-mlp-blocks 2
```
There are some important arguments for the script you should consider when running it:
- `dataset`: Train on the CIFAR10 dataset or on your own dataset
- `train-folder`: The folder of training images
- `valid-folder`: The folder of validation images
- `num-classes`: The number of classes in your problem
- `batch-size`: The batch size of the dataset
- `patch-size`: The patch size of the model (patch resolution P x P)
- `hidden-dim`: Patch projection dimension (hidden size C)
- `tokens-mlp-dim`: Token-mixing units (MLP dimension Ds)
- `channels-mlp-dim`: Channel-mixing units (MLP dimension Dc)
- `num-of-mlp-blocks`: The number of MLP blocks (number of layers)
- `log-path`: Where training metrics (loss, accuracy) are written during training
- `model-folder`: Where the model is saved after training

After training finishes successfully, your model will be saved to the `model-folder` defined above.
Some important parameter settings you can follow:
![image](./image/params.png)
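To make these flags concrete, here is a minimal sketch of a single Mixer block (not the repository's exact code): `hidden-dim` is the channel width C, `tokens-mlp-dim` is Ds for the token-mixing MLP, `channels-mlp-dim` is Dc for the channel-mixing MLP, and `num-of-mlp-blocks` controls how many such blocks are stacked.

```python
import torch
import torch.nn as nn

class MixerBlock(nn.Module):
    """One MLP-Mixer block: token mixing across patches, then channel mixing."""
    def __init__(self, num_patches, hidden_dim, tokens_mlp_dim, channels_mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        # Token-mixing MLP: mixes information across patches (dimension Ds)
        self.token_mlp = nn.Sequential(
            nn.Linear(num_patches, tokens_mlp_dim),
            nn.GELU(),
            nn.Linear(tokens_mlp_dim, num_patches),
        )
        self.norm2 = nn.LayerNorm(hidden_dim)
        # Channel-mixing MLP: mixes information across channels (dimension Dc)
        self.channel_mlp = nn.Sequential(
            nn.Linear(hidden_dim, channels_mlp_dim),
            nn.GELU(),
            nn.Linear(channels_mlp_dim, hidden_dim),
        )

    def forward(self, x):                  # x: (batch, num_patches, hidden_dim)
        y = self.norm1(x).transpose(1, 2)  # -> (batch, hidden_dim, num_patches)
        x = x + self.token_mlp(y).transpose(1, 2)  # token mixing + skip connection
        x = x + self.channel_mlp(self.norm2(x))    # channel mixing + skip connection
        return x

# Matches the CIFAR10 command above: 32x32 image, patch 4 -> (32 // 4) ** 2 = 64 patches
block = MixerBlock(num_patches=64, hidden_dim=256,
                   tokens_mlp_dim=1024, channels_mlp_dim=128)
out = block(torch.randn(4, 64, 256))  # batch size 4, same shape out
```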
## IV. Testing model with a new image
We provide a script for testing a model on a new image via the command line:
```bash
python predict.py --test-file-path ${test_file_path}
```
where `test_file_path` is the path to your test image.
Example:
```bash
python predict.py --test-file-path ./data/test/cat.2000.jpg
```
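The internals of `predict.py` are not shown here, but a rough sketch of what single-image inference typically looks like (the checkpoint path `model/mlp_mixer.pt` is hypothetical, and the resize must match your training `--image-size`) is:

```python
import torch
from PIL import Image
from torchvision import transforms

# Hypothetical inference sketch; the actual predict.py may differ.
model = torch.load("model/mlp_mixer.pt")  # assumed checkpoint path
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # assumed; match the training --image-size
    transforms.ToTensor(),
])

image = Image.open("./data/test/cat.2000.jpg").convert("RGB")
batch = transform(image).unsqueeze(0)  # add a batch dimension: (1, C, H, W)

with torch.no_grad():
    logits = model(batch)
    prediction = logits.argmax(dim=1).item()  # index of the predicted class
print(f"Predicted class index: {prediction}")
```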
## V. Contribute
If you run into any issues while using this repository, please let me know via the issues submission tab.