Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/goamegah/pytorch-simclr
Experiments with PyTorch SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
- Host: GitHub
- URL: https://github.com/goamegah/pytorch-simclr
- Owner: goamegah
- Created: 2024-04-15T10:11:25.000Z (8 months ago)
- Default Branch: main
- Last Pushed: 2024-09-13T20:33:37.000Z (3 months ago)
- Last Synced: 2024-09-14T10:52:00.002Z (3 months ago)
- Topics: contrastive-learning, deep-learning, image-classification, machine-learning, pytorch, representation-learning, simclr, wandb
- Language: Python
- Homepage:
- Size: 3.01 MB
- Stars: 4
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# PyTorch SimCLR experiments
### Original Paper: [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709)
Credit: Google SimCLR
## Data Constraints
We assume that only 100 labeled samples are available to train a deep learning architecture for an image classification task. Our dataset is MNIST handwritten digits, so we have 100 images and 100 labels.
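As a minimal sketch (illustrative code, not from the repository; the ```./data``` root is an assumption), such a labeled subset can be built with torchvision:

```python
import torch
from torchvision import datasets, transforms

# Illustrative only: build the 100-example labeled MNIST subset assumed above.
# The "./data" path is an assumption, not the repository's actual layout.
full_train = datasets.MNIST(root="./data", train=True, download=True,
                            transform=transforms.ToTensor())
labeled = torch.utils.data.Subset(full_train, range(100))  # 100 images, 100 labels
loader = torch.utils.data.DataLoader(labeled, batch_size=32, shuffle=True)
```

## How to Run Models
To run a model, use the commands below for that specific model.
### LeNet5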
```shell
$ python run_lenet.py --mode train --train-epochs 100
```
Let's break down the available flags:
- ```-m```, ```--mode```: which mode to use when running the model (```train``` or ```eval```)
- ```-data```: path where the dataset is stored or downloaded
- ```-dn```, ```--dataset-name```: which dataset to use (default ```MNIST```)
- ```-a```, ```--arch```: architecture to use as the baseline
- ```-b```, ```--batch-size```: training batch size
- ```-eval-batch-size```: evaluation batch size when in **eval mode**
- ```--lr```, ```--learning-rate```: learning rate
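For reference, a minimal LeNet-5 for MNIST looks roughly like the sketch below (the classic architecture, not necessarily the exact model in ```run_lenet.py```):

```python
import torch
from torch import nn

class LeNet5(nn.Module):
    """Classic LeNet-5 for 1x28x28 MNIST inputs (illustrative sketch)."""
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Tanh(),  # -> 6x28x28
            nn.AvgPool2d(2),                                       # -> 6x14x14
            nn.Conv2d(6, 16, kernel_size=5), nn.Tanh(),            # -> 16x10x10
            nn.AvgPool2d(2),                                       # -> 16x5x5
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120), nn.Tanh(),
            nn.Linear(120, 84), nn.Tanh(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))

print(LeNet5()(torch.randn(1, 1, 28, 28)).shape)  # torch.Size([1, 10])
```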
### ResNet-18
```shell
$ python run_resnet.py --mode train --train-epochs 100
```
Let's break down the available flags:
- ```-m```, ```--mode```: which mode to use when running the model (```train``` or ```eval```)
- ```-data```: path where the dataset is stored or downloaded
- ```-dn```, ```--dataset-name```: which dataset to use (default ```MNIST```)
- ```-a```, ```--arch```: architecture to use as the baseline
- ```-b```, ```--batch-size```: training batch size
- ```-eval-batch-size```: evaluation batch size when in **eval mode**
- ```--lr```, ```--learning-rate```: learning rate
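Since MNIST images are single-channel 28x28, a common adaptation is to swap the first convolution of torchvision's ResNet-18. The snippet below is an assumption for illustration, not necessarily what ```run_resnet.py``` does:

```python
import torch
from torch import nn
from torchvision.models import resnet18

# Assumed MNIST adaptation: 1-channel input, 10 output classes.
model = resnet18(num_classes=10)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

x = torch.randn(8, 1, 28, 28)   # a batch of MNIST-sized images
print(model(x).shape)           # torch.Size([8, 10])
```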
### SimCLR-ResNet18
```shell
$ python run.py --mode train --train-mode finetune --train-epochs 10
```
Let's break down the available flags:
- ```-m```, ```--mode```: which mode to use when running the model (```train``` or ```eval```)
- ```-tm```, ```--train-mode```: type of training (```finetune``` or ```pretrain```)
  - **pretrain** trains the contrastive projection head (see the NT-Xent sketch after this list)
  - **finetune** trains the classifier layer while freezing the pretrained backbone (ResNet-18) layers
- ```-j```, ```--workers```: number of data loading workers
- ```-te```, ```--train-epochs```: total number of training epochs
- ```-ee```, ```--eval-epochs```: total number of evaluation epochs
- ```-wd```, ```--weight-decay```: weight decay (default: 1e-4)
- ```-s```, ```--seed```: random seed
- ```--out-dim```: projection head output dimension
- ```--temperature```: softmax temperature of the contrastive loss
- ```--data```: path where the dataset is stored or downloaded
- ```-dn```, ```--dataset-name```: which dataset to use (default ```MNIST```)
- ```-a```, ```--arch```: architecture to use as the baseline
- ```-b```, ```--batch-size```: training batch size
- ```--eval-batch-size```: evaluation batch size when in **eval mode**
- ```-lr```, ```--learning-rate```: learning rate
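To make ```pretrain``` mode and the ```--temperature``` and ```--out-dim``` flags concrete, here is a minimal NT-Xent sketch under standard SimCLR assumptions (illustrative, not the repository's exact implementation):

```python
import torch
import torch.nn.functional as F

def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """NT-Xent loss: z1[i] and z2[i] are projections of two views of image i."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2]), dim=1)   # (2n, out_dim), unit norm
    sim = z @ z.t() / temperature                 # scaled cosine similarities
    sim.fill_diagonal_(float("-inf"))             # exclude self-similarity
    # The positive of row i is row i+n (and vice versa).
    targets = torch.cat([torch.arange(n) + n, torch.arange(n)]).to(z.device)
    return F.cross_entropy(sim, targets)

# Two projected views of a batch of 8 images (out_dim = 128 is an assumption),
# e.g. produced by the backbone followed by the projection head.
z1, z2 = torch.randn(8, 128), torch.randn(8, 128)
print(nt_xent(z1, z2, temperature=0.5))
```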
## Feature Evaluation
Feature evaluation is done using the linear evaluation protocol.
First, we learn features using SimCLR on the ```MNIST unsupervised``` set.
Then, we train a linear classifier on top of the frozen features from SimCLR.
The linear model is trained on features extracted from the ```MNIST train```
set and evaluated on the ```MNIST test``` set. Check the [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/goamegah/torchSimCLR/blob/main/demos/simclr_eval.ipynb) notebook for reproducibility.
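Concretely, the probing step looks roughly like the following sketch (illustrative only; ```linear_probe``` is a hypothetical helper, not the repository's code):

```python
import torch
from torch import nn

def linear_probe(backbone, feat_dim, train_loader, epochs=10, device="cpu"):
    """Train a linear classifier on features from a frozen backbone (hypothetical helper)."""
    backbone.to(device).eval()
    for p in backbone.parameters():              # freeze every backbone weight
        p.requires_grad = False
    head = nn.Linear(feat_dim, 10).to(device)    # 10 MNIST classes
    opt = torch.optim.Adam(head.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():                # features come from the frozen encoder
                feats = backbone(x)
            loss = loss_fn(head(feats), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return head
```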
| Method | Architecture |Accuracy |
|---------------------------------------|--------------|----------|
| Supervised baseline | LeNet-5 | `73.73` |
| Supervised baseline | ResNet-18 | `73.26` |
| SimCLR                                 | ResNet-18    | `93.84`  |

*Models are trained for* **100 epochs**.
## Tools and Libraries Used
![Workflow Illustration](./assets/workflow.png)
- numpy >= 1.24.3 (The fundamental package for scientific computing with Python)
- scipy >= 1.10.1 (Additional functions for NumPy)
- pandas >= 2.0.2 (A data frame library)
- matplotlib >= 3.7.1 (A plotting library)
- jupyterlab >= 4.0 (An application for running Jupyter notebooks)
- ipywidgets >= 8.0.6 (Fixes progress bar issues in Jupyter Lab)
- scikit-learn >= 1.2.2 (A general machine learning library)
- watermark >= 2.4.2 (An IPython/Jupyter extension for printing package information)
- torch >= 2.0.1 (The PyTorch deep learning library)
- torchvision >= 0.15.2 (PyTorch utilities for computer vision)
- torchmetrics >= 0.11.4 (Metrics for PyTorch)
- wandb >= 0.17.9 (Experiment tracking and model monitoring service)

[OPTIONAL PACKAGES]
- TensorboardX
- wandb
- boto3

To install these requirements most conveniently, use the `requirements.txt` file:
```
pip install -r requirements.txt
```

![install-requirements](assets/figures/install-requirements.png)
Then, after completing the installation, check that all packages are installed and up to date using
```
python python_environment_check.py
```

![check1](assets/figures/check1.png)
## More Installation (Optional)
Faster data-loading feedback in TensorBoard (source: https://github.com/tensorflow/tensorboard/issues/4784)
```shell
$ pip uninstall -y tensorboard tb-nightly
$ pip install tb-nightly  # must have at least tb-nightly==2.5.0a20210316
```