{"id":13577777,"url":"https://github.com/OATML/RHO-Loss","last_synced_at":"2025-04-05T15:31:27.255Z","repository":{"id":37242863,"uuid":"487016239","full_name":"OATML/RHO-Loss","owner":"OATML","description":null,"archived":false,"fork":false,"pushed_at":"2022-10-10T16:23:11.000Z","size":289,"stargazers_count":186,"open_issues_count":2,"forks_count":17,"subscribers_count":6,"default_branch":"main","last_synced_at":"2024-11-05T15:47:49.853Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/OATML.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2022-04-29T15:11:09.000Z","updated_at":"2024-10-25T06:35:49.000Z","dependencies_parsed_at":"2023-01-19T11:16:32.196Z","dependency_job_id":null,"html_url":"https://github.com/OATML/RHO-Loss","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/OATML%2FRHO-Loss","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/OATML%2FRHO-Loss/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/OATML%2FRHO-Loss/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/OATML%2FRHO-Loss/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/OATML","download_url":"https://codeload.github.com/OATML/RHO-Loss/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247358697,"owners_count":20926271,"icon_url":"https://github.com/github.png","version
":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-08-01T15:01:24.234Z","updated_at":"2025-04-05T15:31:25.015Z","avatar_url":"https://github.com/OATML.png","language":"Python","funding_links":[],"categories":["Python"],"sub_categories":[],"readme":"# Prioritized training on points that are learnable, worth learning, and not yet learned\nSören Mindermann*, Jan M Brauner*, Muhammed T Razzak*, Mrinank Sharma*, Andreas Kirsch, Winnie Xu, Benedikt Höltgen, Aidan N Gomez, Adrien Morisot, Sebastian Farquhar, Yarin Gal \n\n| **[Abstract](#abstract)**\n| **[Installation](#installation)**\n| **[Tutorial](#tutorial)**\n| **[Codebase](#codebase)**\n| **[Citation](#citation)**\n\n[![arXiv](https://img.shields.io/badge/arXiv-2206.07137-b31b1b.svg)](https://arxiv.org/abs/2206.07137)\n[![Python 3.9](https://img.shields.io/badge/python-3.9-blue.svg)](https://www.python.org/downloads/release/python-390/)\n[![Pytorch](https://img.shields.io/badge/Pytorch-1.9-red.svg)](https://shields.io/)\n[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)\n[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://GitHub.com/Naereen/StrapDown.js/graphs/commit-activity)\n\nThis is the code for the paper [\"Prioritized training on points that are learnable, worth learning, and not yet learned\"](https://proceedings.mlr.press/v162/mindermann22a.html).\n\nThe code uses PyTorch Lightning, Hydra for config file management, and Weights \u0026 Biases for logging. 
The codebase is adapted from this [great template](https://github.com/ashleve/lightning-hydra-template).\n\n## Abstract\nTraining on web-scale data can take months. But much computation and time is wasted on redundant and noisy points that are already learnt or not learnable. To accelerate training, we introduce Reducible Holdout Loss Selection (RHO-LOSS), a simple but principled technique which selects approximately those points for training that most reduce the model's generalization loss. As a result, RHO-LOSS mitigates the weaknesses of existing data selection methods: techniques from the optimization literature typically select \"hard\" (e.g. high loss) points, but such points are often noisy (not learnable) or less task-relevant. Conversely, curriculum learning prioritizes \"easy\" points, but such points need not be trained on once learned. In contrast, RHO-LOSS selects points that are learnable, worth learning, and not yet learnt. RHO-LOSS trains in far fewer steps than prior art, improves accuracy, and speeds up training on a wide range of datasets, hyperparameters, and architectures (MLPs, CNNs, and BERT). On the large web-scraped image dataset Clothing-1M, RHO-LOSS trains in 18x fewer steps and reaches 2% higher final accuracy than uniform data shuffling.\n\n## Installation\nConda: ```conda install --file my_environment.yaml```\n\nPoetry: ```poetry install```\n\nThe repository also contains a singularity container definition file that can be built and used to run the experiments. See the ```singularity``` folder.\n\n## Tutorial\n```tutorial.ipynb``` contains the full training pipeline (irreducible loss model training and target model training) on CIFAR-10. 
This is the best place to start if you want to understand the code or reproduce our results.\n\n## Codebase\nThe codebase contains the functionality for all the experiments in the paper (and more 😜).\n\n### Irreducible loss model training\nStart with ```run_irreducible.py``` (which then calls ```src/train_irreducible.py```). The base config file is ```configs/irreducible_training.yaml```.\n\n### Target model training\nStart with ```run.py``` (which then calls ```src/train.py```). The base config file is ```configs/config.yaml```. A key file is ```src/models/MultiModels.py```, the LightningModule that handles the training loop, including batch selection.\n\n### More about the code\nThe datamodules are implemented in ```src/datamodules/datamodules.py```, and the individual datasets in ```src/datamodules/dataset/sequence_datasets```. If you want to add your own dataset, note that ```__getitem__()``` needs to return the tuple ```(index, input, target)```, where ```index``` is the index of the datapoint with respect to the overall dataset (this is required so that we can match the irreducible losses to the correct datapoints).\n\nAll the selection methods mentioned in the paper (and more) are implemented in ```src/curricula/selection_methods.py```.\n\n### ALBERT fine-tuning\nAll ALBERT experiments are implemented in a separate branch, which is a bit less clean. Good luck :-)\n\n## Reproducibility\nThis repo can be used to reproduce all the experiments in the paper. Check out ```configs/experiment``` for some example experiment configs. 
The experiment files for the main results are: \n* CIFAR-10: ```cifar10_resnet18_irred.yaml``` and ```cifar10_resnet18_main.yaml```\n* CINIC-10: ```cinic10_resnet18_irred.yaml``` and ```cinic10_resnet18_main.yaml```\n* CIFAR-100: ```cifar100_resnet18_irred.yaml``` and ```cifar100_resnet18_main.yaml```\n* Clothing-1M: ```c1m_resnet18_irred.yaml``` and ```c1m_resnet50_main.yaml```\n\nNLP datasets, on a separate branch:\n* CoLA:\n  * Irreducible loss model training: ```python run_irreducible_nlp.py +experiment=nlp trainer.max_epochs=10 callbacks=val_loss datamodule.task_name=cola trainer.val_check_interval=0.05```\n  * Target model training: ```python run_nlp.py +experiment=nlp datamodule.task_name=cola trainer.max_epochs=100 irreducible_loss_generator.f=\\\"path/to/file\\\" selection_method_nlp=reducible_loss_selection```\n* SST2:\n  * Irreducible loss model training: ```python run_irreducible_nlp.py +experiment=nlp trainer.max_epochs=10 callbacks=val_loss datamodule.task_name=sst2 trainer.val_check_interval=0.05```\n  * Target model training: ```python run_nlp.py +experiment=nlp trainer.max_epochs=15 datamodule.task_name=sst2 +trainer.val_check_interval=0.2 irreducible_loss_generator.f=\\\"path/to/file\\\" selection_method_nlp=reducible_loss_selection```\n\n### Notes on using the importance sampling baseline:\nTo run the importance sampling experiments:\n\nImportance sampling on CINIC-10\n``` \npython3 run_simple.py datamodule.data_dir=$DATA_DIR +experiment=importance_sampling_baseline.yaml \n```\n\n## Citation\n\nIf you find this code helpful for your work, please cite our paper\n[Paper](https://proceedings.mlr.press/v162/mindermann22a.html) as\n\n```bibtex\n\n@InProceedings{2022PrioritizedTraining,\n  title = \t {Prioritized Training on Points that are Learnable, Worth Learning, and not yet Learnt},\n  author =       {Mindermann, S{\\\"o}ren and Brauner, Jan M and Razzak, Muhammed T and Sharma, Mrinank and Kirsch, Andreas and Xu, Winnie and H{\\\"o}ltgen, Benedikt and 
Gomez, Aidan N and Morisot, Adrien and Farquhar, Sebastian and Gal, Yarin},\n  booktitle = \t {Proceedings of the 39th International Conference on Machine Learning},\n  pages = \t {15630--15649},\n  year = \t {2022},\n  editor = \t {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},\n  volume = \t {162},\n  series = \t {Proceedings of Machine Learning Research},\n  month = \t {17--23 Jul},\n  publisher =    {PMLR},\n  pdf = \t {https://proceedings.mlr.press/v162/mindermann22a/mindermann22a.pdf},\n  url = \t {https://proceedings.mlr.press/v162/mindermann22a.html},}\n```\n\n## Let us know how it goes!\nIf you've tried RHO-LOSS, whether or not it worked well, or if you want us to give a presentation at your lab, we'd love to hear from you! Correspondence to 'soren.mindermann at cs.ox.ac.uk'.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FOATML%2FRHO-Loss","html_url":"https://awesome.ecosyste.ms/projects/github.com%2FOATML%2FRHO-Loss","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FOATML%2FRHO-Loss/lists"}