https://github.com/fcakyon/balanced-loss
Easy to use class balanced cross entropy and focal loss implementation for Pytorch
- Host: GitHub
- URL: https://github.com/fcakyon/balanced-loss
- Owner: fcakyon
- License: mit
- Created: 2022-07-21T08:38:10.000Z (almost 3 years ago)
- Default Branch: main
- Last Pushed: 2024-12-17T11:24:31.000Z (5 months ago)
- Last Synced: 2025-03-28T16:05:44.735Z (about 2 months ago)
- Topics: balanced-loss, binary-crossentropy, class-balanced-loss, computer-vision, cross-entropy, cvpr, deep-learning, focal-loss, image-classification, loss-functions, machine-learning, pip, pypi, python, pytorch
- Language: Python
- Homepage: https://pypi.org/project/balanced-loss/
- Size: 30.3 KB
- Stars: 94
- Watchers: 2
- Forks: 8
- Open Issues: 1
Metadata Files:
- Readme: README.md
- Funding: .github/FUNDING.yml
- License: LICENSE
Easy-to-use class-balanced cross-entropy and focal loss implementation for PyTorch.

## Theory
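When the labels in a training dataset are imbalanced, one remedy is to balance the loss across classes.

- First, the effective number of samples is calculated for each class as:

  $$E_{n} = \frac{1 - \beta^{n}}{1 - \beta}$$

  where $n$ is the number of training samples in the class and $\beta \in [0, 1)$ is a hyperparameter.

- Then the class-balanced loss function is defined as:

  $$\mathrm{CB}(\mathbf{p}, y) = \frac{1}{E_{n_y}} \mathcal{L}(\mathbf{p}, y) = \frac{1 - \beta}{1 - \beta^{n_y}} \mathcal{L}(\mathbf{p}, y)$$

  where $n_y$ is the number of training samples in the ground-truth class $y$ and $\mathcal{L}$ is the underlying loss (cross-entropy, binary cross-entropy, or focal loss).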
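As a rough illustration of the two steps, the sketch below computes the effective number of samples, `(1 - beta**n) / (1 - beta)`, for each class and turns its inverse into a per-class weight, following the formulation in the referenced paper. The function names `effective_number` and `class_balanced_weights` are hypothetical and are not part of the `balanced_loss` API. Rescaling the weights so that they sum to the number of classes is an assumption made here for readability; the package may normalize differently.

```python
def effective_number(n: int, beta: float) -> float:
    """Effective number of samples: (1 - beta^n) / (1 - beta)."""
    return (1.0 - beta ** n) / (1.0 - beta)


def class_balanced_weights(samples_per_class, beta=0.999):
    """Per-class weights proportional to 1 / effective_number.

    Assumption: the weights are rescaled to sum to the number of classes.
    This is an illustrative convention, not necessarily what the package does.
    """
    raw = [1.0 / effective_number(n, beta) for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]


# 30, 100 and 25 samples for labels 0, 1 and 2, respectively
weights = class_balanced_weights([30, 100, 25], beta=0.999)
print(weights)  # the rarest class (label 2) receives the largest weight
```

With `beta = 0` every class has an effective number of 1, so all weights are equal and no rebalancing happens. As `beta` approaches 1, the effective number approaches the raw sample count and the weights approach inverse class frequency.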
## Installation
```bash
pip install balanced-loss
```

## Usage
- Standard losses:
```python
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]])  # batch of 1 sample, 3 classes
labels = torch.tensor([0])  # batch of 1 sample

# focal loss
focal_loss = Loss(loss_type="focal_loss")
loss = focal_loss(logits, labels)
```

```python
# cross-entropy loss
ce_loss = Loss(loss_type="cross_entropy")
loss = ce_loss(logits, labels)
```

```python
# binary cross-entropy loss
bce_loss = Loss(loss_type="binary_cross_entropy")
loss = bce_loss(logits, labels)
```

- Class-balanced losses:
```python
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]])  # batch of 1 sample, 3 classes
labels = torch.tensor([0])  # batch of 1 sample

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25]  # 30, 100, 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    samples_per_class=samples_per_class,
    class_balanced=True,
)
loss = focal_loss(logits, labels)
```

```python
# class-balanced cross-entropy loss
ce_loss = Loss(
    loss_type="cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True,
)
loss = ce_loss(logits, labels)
```

```python
# class-balanced binary cross-entropy loss
bce_loss = Loss(
    loss_type="binary_cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True,
)
loss = bce_loss(logits, labels)
```

- Customize parameters:
```python
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]])  # batch of 1 sample, 3 classes
labels = torch.tensor([0])  # batch of 1 sample

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25]  # 30, 100, 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    beta=0.999,  # class-balanced loss beta
    fl_gamma=2,  # focal loss gamma
    samples_per_class=samples_per_class,
    class_balanced=True,
)
loss = focal_loss(logits, labels)
```

## Improvements
What is the difference between this repo and vandit15's?
- This repo is a PyPI-installable package
- This repo implements loss functions as `torch.nn.Module`
- In addition to class-balanced losses, this repo also supports the standard versions of cross-entropy, focal loss, etc. through the same API
- All typos and errors in vandit15's source are fixed
- Continuously tested on PyTorch 1.13.1 and 2.5.1
- Automatically selects the loss module device based on the logits

## References
- https://arxiv.org/abs/1901.05555
- https://github.com/richardaecn/class-balanced-loss
- https://github.com/vandit15/Class-balanced-loss-pytorch