https://github.com/ahmdtaha/simsiam
PyTorch implementation of Exploring Simple Siamese Representation Learning
- Host: GitHub
- URL: https://github.com/ahmdtaha/simsiam
- Owner: ahmdtaha
- Created: 2021-01-08T12:56:42.000Z (almost 5 years ago)
- Default Branch: main
- Last Pushed: 2023-05-19T01:30:05.000Z (over 2 years ago)
- Last Synced: 2025-04-10T16:53:13.339Z (6 months ago)
- Topics: deep-learning, deep-neural-networks, python, python3, pytorch, pytorch-implementation, self-supervised, self-supervised-learning, simsiam
- Language: Python
- Homepage:
- Size: 361 KB
- Stars: 21
- Watchers: 1
- Forks: 8
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# SimSiam
This is an _Unofficial_ PyTorch implementation of **Exploring Simple Siamese Representation Learning [1]**. It is highly inspired by [PatrickHua's repository](https://github.com/PatrickHua/SimSiam). The key features of this repository are:

1. It supports DistributedDataParallel (DDP), not the old-fashioned DataParallel.
2. It separates the loss function (criterion) from the model. Thus, it is easier to modify the loss function (e.g., add a regularizer) without changing the model (see the sketch after this list).
3. It automatically switches between the CIFAR-ResNet and the official ResNet based on the training dataset.
4. The code supports both CIFAR10 and CIFAR100.
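To picture the criterion separation (feature 2 above), here is a minimal sketch of a standalone SimSiam criterion. The class name is hypothetical, not the repository's actual API: the loss is the symmetrized negative cosine similarity with a stop-gradient on the target branch, as in the paper [1].

```python
import torch.nn as nn
import torch.nn.functional as F

class NegativeCosineSimilarity(nn.Module):
    """Standalone SimSiam criterion (hypothetical name, for illustration).

    p1, p2 are predictor outputs and z1, z2 are projector outputs for the
    two augmented views; .detach() implements the paper's stop-gradient.
    """

    def forward(self, p1, p2, z1, z2):
        return -(F.cosine_similarity(p1, z2.detach(), dim=-1).mean() +
                 F.cosine_similarity(p2, z1.detach(), dim=-1).mean()) / 2
```

Because the criterion is a plain `nn.Module` separate from the model, a regularizer can be added by subclassing or wrapping it without touching the network code.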
## Requirements
* Python 3+ [Tested on 3.7]
* PyTorch 1.x [Tested on 1.7.1]

## Usage example
First, update `constants.py` with your dataset directory and checkpoints directory.
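The variable names inside `constants.py` are not documented here, so the following is only a hypothetical example of what the two entries might look like:

```python
# constants.py -- hypothetical contents; check the actual file for the real variable names
dataset_dir = '/path/to/datasets'         # where CIFAR10/CIFAR100 live (or get downloaded)
checkpoints_dir = '/path/to/checkpoints'  # where logs and model weights are written
```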
To train a model, run `python pretrain_main.py`.
To train a linear classifier, run `python classifier_main.py`. Make sure to update the path of the [pretrained model](https://github.com/ahmdtaha/simsiam/blob/1975eb6749cc06d1df2a781b4ed9c234ad008c40/classifier_main.py#L107) before running.
The CIFAR hyperparameters are already hard-coded in the Python script. However, these hyperparameters can be overridden by providing at least one parameter when running the script (e.g., `python pretrain_main.py --arch SimCLR`).
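The repository's exact override logic is not reproduced here, but one common way to implement "use the hard-coded recipe unless any flag is passed" looks like this (hyperparameter names and values below are illustrative):

```python
import sys
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--arch', default='SimSiam')
parser.add_argument('--lr', type=float, default=0.03)
parser.add_argument('--epochs', type=int, default=800)
args = parser.parse_args()

if len(sys.argv) == 1:
    # No CLI flags at all: fall back to the hard-coded CIFAR recipe.
    args.arch, args.lr, args.epochs = 'SimSiam', 0.03, 800
# Any flag on the command line (even just --arch) skips the hard-coded
# block, so the parsed values take effect.
```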
By default, the code leverages DistributedDataParallel. There are two key parameters that identify the training GPUs:
* `--world_size` denotes how many GPUs to use
* `--base_gpu` denotes the start GPU index (0-based).
For example, if a machine has four GPUs and the user sets `world_size=3` and `base_gpu=1`, the model will be trained on the 2nd, 3rd, and 4th GPUs.
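A sketch of how such a rank-to-GPU mapping can be implemented with `torch.multiprocessing` (the parameter names follow this README, but the repository's actual launch code may differ):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size, base_gpu):
    # Each spawned process (rank 0..world_size-1) owns GPU base_gpu + rank.
    gpu = base_gpu + rank
    torch.cuda.set_device(gpu)
    dist.init_process_group('nccl', init_method='tcp://localhost:23456',
                            rank=rank, world_size=world_size)
    # ... build the model on `gpu`, wrap it in DistributedDataParallel, train ...

if __name__ == '__main__':
    world_size, base_gpu = 3, 1  # trains on GPUs 1, 2, 3 of a 4-GPU machine
    mp.spawn(worker, args=(world_size, base_gpu), nprocs=world_size)
```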
The code logs the training progress in two files:
* `train_X.log`: Assuming `base_gpu=X`, this file contains a text log of the training progress. There are sample log files inside the ./sample_runs dir; some are complete runs and others are incomplete.
* `exp_name.csv`: This CSV file keeps track of the KNN accuracy at each `--test_interval`.

To train a SimCLR model, change the [arch parameter](https://github.com/ahmdtaha/simsiam/blob/cd44e3c73b4b70029202f4277266f3b15e7e54f2/pretrain_main.py#L101) to `simCLR`. This code achieves 89.56% KNN accuracy with SimCLR (check the ./sample_runs dir).
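"KNN accuracy" here is the usual self-supervised monitor: each test embedding is classified by its nearest training embeddings, with no fine-tuning. A minimal cosine-similarity version (a sketch; the repository's exact protocol, e.g. the value of k or distance weighting, may differ):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=200):
    """Majority vote among the k nearest training embeddings (illustrative)."""
    train_feats = F.normalize(train_feats, dim=1)
    test_feats = F.normalize(test_feats, dim=1)
    sims = test_feats @ train_feats.t()   # cosine similarity matrix
    _, idx = sims.topk(k, dim=1)          # indices of the k nearest neighbors
    votes = train_labels[idx]             # (num_test, k) neighbor labels
    preds = votes.mode(dim=1).values      # simple (unweighted) majority vote
    return (preds == test_labels).float().mean().item()
```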
### Quantitative Evaluation on CIFAR10

|                   | Paper [1] | Ours |
|-------------------|-----------|------|
| KNN Accuracy      |           |      |
| Linear Classifier | 91.8      | [90.27](https://github.com/ahmdtaha/simsiam/blob/main/sample_runs/complete_linear_classifier_run.txt) |

## Release History
* 1.0.0
    * First commit on 8 Jan 2021

### TODO LIST
* Revise the readme file
* ~~Test pretrain_main.py with SimCLR~~
* ~~Add the classifier_main.py~~

## References
[1] Chen, Xinlei and He, Kaiming. Exploring Simple Siamese Representation Learning. arXiv preprint arXiv:2011.10566.