https://github.com/amazon-science/contrastive_emc2
Code for the ICML 2024 paper: "EMC^2: Efficient MCMC Negative Sampling for Contrastive Learning with Global Convergence"
- Host: GitHub
- URL: https://github.com/amazon-science/contrastive_emc2
- Owner: amazon-science
- License: apache-2.0
- Created: 2024-05-22T21:31:05.000Z
- Default Branch: main
- Last Pushed: 2024-05-22T21:47:19.000Z
- Last Synced: 2024-11-12T23:08:58.847Z
- Topics: contrastive-learning, deep-neural-networks, machine-learning, machine-learning-algorithms, mcmc-sampling, multi-modal, multi-modal-learning
- Language: Python
- Homepage: https://arxiv.org/abs/2404.10575
- Size: 7.77 MB
- Stars: 1
- Watchers: 3
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
- Code of conduct: CODE_OF_CONDUCT.md
## EMC2: Efficient MCMC Negative Sampling for Contrastive Learning with Global Convergence
**Authors: Chung-Yiu Yau, Hoi-To Wai, Parameswaran Raman, Soumajyoti Sarkar, Mingyi Hong**
This repository contains the PyTorch code used to run the experiments in the [paper](https://arxiv.org/pdf/2404.10575).
See system_setup.md if your environment is not compatible with the code.
Use the following commands to reproduce the main experiment results.
## ResNet-18 on STL-10:
- SimCLR
`main_simclr.py --beta 14.28 --lr 1e-4 --disable_batchnorm --pos_batch_size 32 --compute_batch_size 256 --epoch 100 --num_workers 4 --eval_freq 5 --data_root /path/to/data/directory`
- Embedding Cache
`main_gumbel.py --beta 14.28 --lr 1e-4 --rho 0.01024 --neg_batch_size 1 --pos_batch_size 32 --compute_batch_size 512 --disable_batchnorm --disable_proj --sampler gumbel_max --transform_batch_size 1 --num_workers 2 --eval_freq 5 --data_root /path/to/data/directory`
- SogCLR
`main_sogclr.py --beta 14.28 --lr 1e-4 --disable_batchnorm --pos_batch_size 32 --compute_batch_size 256 --epoch 100 --num_workers 4 --eval_freq 5 --data_root /path/to/data/directory`
- EMC2
`main_mcmc.py --beta 14.28 --lr 1e-4 --disable_batchnorm --pos_batch_size 32 --compute_batch_size 256 --epoch 100 --sampler mcmc --mcmc_burn_in 0.5 --num_workers 4 --eval_freq 5 --data_root /path/to/data/directory`
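For intuition about what MCMC negative sampling means, the following is a minimal sketch of a Metropolis-Hastings sampler over a toy list of similarity scores. It is not the code in `main_mcmc.py`. It assumes that `--beta` acts as an inverse temperature and that negatives are drawn with probability proportional to `exp(beta * score)`. The names `mh_step`, `sample_negatives`, `softmax`, and `burn_in_frac` are hypothetical, and treating burn-in as a discarded fraction of the chain is an illustrative choice rather than a statement about how `--mcmc_burn_in` is interpreted by the repository.

```python
import math
import random


def mh_step(scores, beta, state, rng):
    """One Metropolis-Hastings step with a uniform (symmetric) proposal.

    The target is p(j) proportional to exp(beta * scores[j]); only the two
    scores being compared are needed, never the full normalizing sum.
    """
    proposal = rng.randrange(len(scores))
    log_ratio = beta * (scores[proposal] - scores[state])
    # Accept with probability min(1, exp(log_ratio)).
    if log_ratio >= 0 or rng.random() < math.exp(log_ratio):
        return proposal
    return state


def sample_negatives(scores, beta, n_steps, burn_in_frac=0.5, seed=0):
    """Run the chain for n_steps and keep the states visited after burn-in."""
    rng = random.Random(seed)
    state = rng.randrange(len(scores))
    burn_in = int(burn_in_frac * n_steps)  # assumption: fraction of steps discarded
    kept = []
    for t in range(n_steps):
        state = mh_step(scores, beta, state, rng)
        if t >= burn_in:
            kept.append(state)
    return kept


def softmax(scores, beta):
    """Exact target distribution, for comparison on small toy problems."""
    m = max(scores)
    weights = [math.exp(beta * (s - m)) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


if __name__ == "__main__":
    toy_scores = [0.9, 0.5, 0.1, -0.3]  # hypothetical similarities to one anchor
    draws = sample_negatives(toy_scores, beta=2.0, n_steps=100_000)
    freqs = [draws.count(j) / len(draws) for j in range(len(toy_scores))]
    print("empirical:", [round(f, 3) for f in freqs])
    print("exact:    ", [round(p, 3) for p in softmax(toy_scores, beta=2.0)])
```

On a toy problem like this the exact softmax is cheap, so the two printed vectors can be compared directly. The point of the accept/reject step is that it only ever compares two scores, which is what makes this style of sampling attractive when the candidate set is too large to normalize over.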
## ResNet-50 on ImageNet-100:
- SimCLR
`main_simclr.py --beta 14.28 --lr 1.2 --disable_batchnorm --dataset imagenet100 --model_name resnet50 --dim 128 --mlp_dim 2048 --projector_hidden_layers 1 --weight_decay 1e-6 --optimizer lars --epoch 800 --eval_freq 10 --num_workers 4 --data_root /path/to/data/directory`
- SogCLR
`main_sogclr.py --beta 14.28 --lr 1.2 --disable_batchnorm --dataset imagenet100 --model_name resnet50 --dim 128 --mlp_dim 2048 --projector_hidden_layers 1 --weight_decay 1e-6 --optimizer lars --epoch 800 --eval_freq 10 --num_workers 4 --data_root /path/to/data/directory`
- EMC2
`main_mcmc.py --beta 14.28 --lr 1.2 --disable_batchnorm --dataset imagenet100 --model_name resnet50 --dim 128 --mlp_dim 2048 --projector_hidden_layers 1 --weight_decay 1e-6 --optimizer lars --sampler mcmc --mcmc_burn_in 0.5 --epoch 800 --eval_freq 10 --num_workers 4 --data_root /path/to/data/directory`
## ResNet-18 on STL-10 Subset with SGD:
Use preaugmentation.py to generate the pre-augmented STL-10 with 2 augmentations per image.
- SimCLR
`main_simclr.py --beta 5 --lr 1e-3 --disable_batchnorm --pos_batch_size 4 --compute_batch_size 8 --epoch 10 --num_workers 1 --eval_freq 25 --check_gradient_error --finite_aug --n_aug 2 --optimizer sgd --eval_iteration --data_root /path/to/data/directory`
- Embedding Cache
`main_gumbel.py --beta 5 --lr 1e-3 --rho 0.1 --neg_batch_size 1 --pos_batch_size 4 --compute_batch_size 8 --disable_batchnorm --disable_proj --sampler gumbel_max --transform_batch_size 1 --epoch 10 --num_workers 1 --eval_freq 25 --check_gradient_error --finite_aug --n_aug 2 --optimizer sgd --eval_iteration --data_root /path/to/data/directory`
- SogCLR
`main_sogclr.py --beta 5 --lr 1e-3 --disable_batchnorm --pos_batch_size 4 --compute_batch_size 8 --epoch 10 --num_workers 1 --eval_freq 25 --check_gradient_error --finite_aug --n_aug 2 --optimizer sgd --eval_iteration --data_root /path/to/data/directory`
- EMC2
`main_mcmc.py --beta 5 --lr 1e-3 --disable_batchnorm --pos_batch_size 4 --compute_batch_size 8 --epoch 10 --sampler mcmc --mcmc_burn_in 0.5 --num_workers 1 --eval_freq 25 --check_gradient_error --finite_aug --n_aug 2 --optimizer sgd --eval_iteration --data_root /path/to/data/directory`
## Citation
Please consider citing our paper if you use our code:
```text
@misc{yau2024emc2,
title={EMC$^2$: Efficient MCMC Negative Sampling for Contrastive Learning with Global Convergence},
author={Chung-Yiu Yau and Hoi-To Wai and Parameswaran Raman and Soumajyoti Sarkar and Mingyi Hong},
year={2024},
eprint={2404.10575},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
## Security
See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
## License
This project is licensed under the Apache-2.0 License.