https://github.com/AlanChou/Truncated-Loss
PyTorch implementation of the paper "Generalized Cross Entropy Loss for Training Deep Neural Networks with Noisy Labels" in NIPS 2018
- Host: GitHub
- URL: https://github.com/AlanChou/Truncated-Loss
- Owner: AlanChou
- Created: 2019-11-11T14:22:56.000Z (about 5 years ago)
- Default Branch: master
- Last Pushed: 2019-11-12T14:05:22.000Z (about 5 years ago)
- Last Synced: 2024-08-02T15:35:53.787Z (3 months ago)
- Language: Python
- Size: 14.6 KB
- Stars: 121
- Watchers: 3
- Forks: 9
- Open Issues: 5
Metadata Files:
- Readme: README.md
README
# Truncated Loss (GCE)
This is the unofficial PyTorch implementation of the paper "Generalized Cross Entropy Loss for Training Deep Neural Networks with Noisy Labels" in NIPS 2018.
https://arxiv.org/abs/1805.07836

## Overview
The code doesn't include the experiment for the Fashion MNIST dataset. All hyperparameters are set to the values given in the paper. Two things differ from the original paper.

First, I didn't separate a validation set for the pruning step to obtain the optimal-epoch model. Instead, I used the model from the epoch with the best test accuracy to conduct the pruning step, which will of course result in better performance. If you want a fair comparison with other methods, you should separate a validation set.

Second, the loss is averaged instead of summed, which I found to be more stable.

I didn't spend time running different noise rate settings. I simply picked a noise rate of 0.4 to validate on CIFAR-10 and CIFAR-100.
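To illustrate the averaged-versus-summed point, here is a minimal sketch of the paper's L_q loss, L_q = (1 - p_y^q) / q, where p_y is the softmax probability assigned to the given label. It is plain Python so that it runs without PyTorch, and it is not the repository's code. The names `softmax` and `gce_loss`, the `reduction` argument, and the value q = 0.7 are assumptions chosen for illustration, and the truncation and pruning steps are not shown.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gce_loss(batch_logits, labels, q=0.7, reduction="mean"):
    """L_q loss, (1 - p_y^q) / q per sample, reduced over the batch.

    q = 0.7 is an illustrative default, not read from the repository.
    """
    if reduction not in ("mean", "sum"):
        raise ValueError("reduction must be 'mean' or 'sum'")
    per_sample = []
    for logits, y in zip(batch_logits, labels):
        p_y = softmax(logits)[y]  # probability of the (possibly noisy) label
        per_sample.append((1.0 - p_y ** q) / q)
    total = sum(per_sample)
    # "mean" is the averaged variant; "sum" is the summed variant
    return total / len(per_sample) if reduction == "mean" else total


# Hypothetical two-sample batch with three classes
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
labels = [0, 1]
print(gce_loss(logits, labels, reduction="mean"))
print(gce_loss(logits, labels, reduction="sum"))
```

For a fixed batch size the two reductions differ only by the constant factor of the batch size, so with plain SGD the summed variant behaves like the averaged one with a proportionally larger learning rate.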
## Dependencies
This code is based on Python 3.5, with the main dependencies being PyTorch==1.2.0 and torchvision==0.4.0. Additional dependencies for running experiments are: numpy, argparse, os, csv, sys, PIL.

Run the code with the following example commands:
### Uniform Noise with noise rate 0.4 on CIFAR-10
```
$ CUDA_VISIBLE_DEVICES=0 python3 main.py --dataset cifar10 --noise_type symmetric --noise_rate 0.4 --schedule 40 80 --start_prune 40 --epochs 120
```
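For readers unfamiliar with the term, uniform (symmetric) label noise can be sketched as follows. This assumes one common convention, in which each label is replaced with probability `noise_rate` by a different class drawn uniformly at random. The repository may define it differently (for example, by also allowing the original class to be drawn). The function name `add_symmetric_noise` and the `seed` argument are hypothetical.

```python
import random


def add_symmetric_noise(labels, num_classes, noise_rate, seed=0):
    """Flip each label to a uniformly chosen *other* class with prob. noise_rate.

    Assumed convention for illustration; not taken from the repository.
    """
    rng = random.Random(seed)
    noisy = []
    for y in labels:
        if rng.random() < noise_rate:
            # Draw from the num_classes - 1 classes that are not y
            other = rng.randrange(num_classes - 1)
            noisy.append(other if other < y else other + 1)
        else:
            noisy.append(y)
    return noisy


# Hypothetical clean labels for a 10-class problem
clean = [i % 10 for i in range(10000)]
noisy = add_symmetric_noise(clean, num_classes=10, noise_rate=0.4)
flipped = sum(a != b for a, b in zip(clean, noisy)) / len(clean)
print("fraction of labels changed:", flipped)
```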
### Class Dependent Noise with noise rate 0.4 on CIFAR-10
```
$ CUDA_VISIBLE_DEVICES=0 python3 main.py --dataset cifar10 --noise_type pairflip --noise_rate 0.4 --schedule 40 80 --start_prune 40 --epochs 120
```
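The `pairflip` noise type can be sketched in the same spirit. This assumes a common "pair flip" convention, in which each label of class i is moved to class (i + 1) mod K with probability `noise_rate`. Whether the repository uses exactly this mapping is an assumption, and the paper's class-dependent noise is defined over semantically similar classes instead. The function name `add_pairflip_noise` and the `seed` argument are hypothetical.

```python
import random


def add_pairflip_noise(labels, num_classes, noise_rate, seed=0):
    """Flip each label y to (y + 1) % num_classes with prob. noise_rate.

    Assumed convention for illustration; not taken from the repository.
    """
    rng = random.Random(seed)
    return [
        (y + 1) % num_classes if rng.random() < noise_rate else y
        for y in labels
    ]


# Hypothetical clean labels for a 10-class problem
clean = [i % 10 for i in range(10000)]
noisy = add_pairflip_noise(clean, num_classes=10, noise_rate=0.4)
flipped = sum(a != b for a, b in zip(clean, noisy)) / len(clean)
print("fraction of labels changed:", flipped)
```

Unlike the symmetric case, every corrupted label of a given class lands in the same wrong class, so with a noise rate below 0.5 the true class still remains the most frequent label for each class.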
### Uniform Noise with noise rate 0.4 on CIFAR-100
```
$ CUDA_VISIBLE_DEVICES=0 python3 main.py --dataset cifar100 --noise_type symmetric --noise_rate 0.4 --schedule 80 120 --start_prune 80 --epochs 150
```
### Class Dependent Noise with noise rate 0.4 on CIFAR-100
```
$ CUDA_VISIBLE_DEVICES=0 python3 main.py --dataset cifar100 --noise_type pairflip --noise_rate 0.4 --schedule 80 120 --start_prune 80 --epochs 150
```