Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/AdeelH/pytorch-multi-class-focal-loss
An (unofficial) implementation of Focal Loss, as described in the RetinaNet paper, generalized to the multi-class case.
classification deep-learning imbalanced-classes implementation-of-research-paper loss-functions machine-learning multiclass-classification neural-network pytorch pytorch-implementation retinanet
Last synced: 3 months ago
- Host: GitHub
- URL: https://github.com/AdeelH/pytorch-multi-class-focal-loss
- Owner: AdeelH
- License: mit
- Created: 2020-09-03T09:08:36.000Z (about 4 years ago)
- Default Branch: master
- Last Pushed: 2024-01-22T19:03:41.000Z (10 months ago)
- Last Synced: 2024-04-20T17:00:32.817Z (7 months ago)
- Topics: classification, deep-learning, imbalanced-classes, implementation-of-research-paper, loss-functions, machine-learning, multiclass-classification, neural-network, pytorch, pytorch-implementation, retinanet
- Language: Python
- Homepage:
- Size: 27.3 KB
- Stars: 205
- Watchers: 4
- Forks: 24
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
[![DOI](https://zenodo.org/badge/292520399.svg)](https://zenodo.org/badge/latestdoi/292520399)
# Multi-class Focal Loss
An (unofficial) implementation of Focal Loss, as described in the RetinaNet paper, https://arxiv.org/abs/1708.02002, generalized to the multi-class case.
It is essentially an enhancement to cross-entropy loss and is useful for classification tasks with a large class imbalance. It down-weights easy (already well-classified) examples so that training focuses on the hard ones.
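To make the down-weighting concrete, here is a minimal scalar sketch in plain Python. It assumes the form given in the RetinaNet paper, `FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)`, where `p_t` is the predicted probability of the true class. The helper names `cross_entropy_scalar` and `focal_loss_scalar` are illustrative and are not part of this repo.

```python3
import math


def cross_entropy_scalar(p_t):
    """Cross-entropy for one example, given the true-class probability p_t."""
    return -math.log(p_t)


def focal_loss_scalar(p_t, gamma=2.0, alpha_t=1.0):
    """Focal loss for one example (illustrative helper, not the repo's API)."""
    # (1 - p_t) ** gamma shrinks towards 0 as the prediction gets more confident.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# An "easy" example (confidently correct) and a "hard" one (mostly wrong).
for name, p_t in [("easy", 0.9), ("hard", 0.1)]:
    ce = cross_entropy_scalar(p_t)
    fl = focal_loss_scalar(p_t)
    print(f"{name}: CE={ce:.4f}  FL={fl:.4f}  FL/CE={fl / ce:.2f}")
```

With `gamma=2`, the easy example keeps `(1 - 0.9)^2 = 0.01` of its cross-entropy loss, while the hard example keeps `(1 - 0.1)^2 = 0.81`. With `gamma=0` and `alpha_t=1`, the expression reduces to plain cross-entropy.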
# Usage
- `FocalLoss` is an `nn.Module` and behaves very much like `nn.CrossEntropyLoss()`, i.e. it
  - supports the `reduction` and `ignore_index` params, and
  - is able to work with 2D inputs of shape `(N, C)` as well as K-dimensional inputs of shape `(N, C, d1, d2, ..., dK)`.
- Example usage
```python3
focal_loss = FocalLoss(alpha, gamma)
...
inp, targets = batch
out = model(inp)
loss = focal_loss(out, targets)
```

# Loading through torch.hub
This repo supports importing modules through `torch.hub`. `FocalLoss` can be easily imported into your code via, for example:
```python3
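import torch

# Load the FocalLoss module class entry point; alpha is passed as a tensor.
focal_loss = torch.hub.load(
    'adeelh/pytorch-multi-class-focal-loss',
    model='FocalLoss',
    alpha=torch.tensor([.75, .25]),
    gamma=2,
    reduction='mean',
    force_reload=False
)

# Random 2-class logits of shape (N, C) and integer class labels of shape (N,).
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)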
```
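Or, via the `focal_loss` entry point, which in this example takes `alpha` as a plain list together with `device` and `dtype`: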
```python3
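import torch

# Load via the focal_loss entry point; alpha is a list, with device and dtype given.
focal_loss = torch.hub.load(
    'adeelh/pytorch-multi-class-focal-loss',
    model='focal_loss',
    alpha=[.75, .25],
    gamma=2,
    reduction='mean',
    device='cpu',
    dtype=torch.float32,
    force_reload=False
)

# Random 2-class logits of shape (N, C) and integer class labels of shape (N,).
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)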
```
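The shape and `ignore_index` conventions from the Usage section can be sketched as follows. This is a hypothetical re-implementation written for illustration only: it is not this repo's code, the name `focal_loss_sketch` is made up, and it assumes mean reduction over non-ignored positions and an `alpha` tensor of per-class weights.

```python3
import torch
import torch.nn.functional as F


def focal_loss_sketch(logits, targets, alpha=None, gamma=2.0, ignore_index=-100):
    """Illustrative multi-class focal loss with mean reduction (not the repo's code)."""
    # Move the class dimension last and flatten: (N, C, d1, ..., dK) -> (M, C).
    num_classes = logits.shape[1]
    logits = logits.movedim(1, -1).reshape(-1, num_classes)
    targets = targets.reshape(-1)

    # Drop positions whose label equals ignore_index.
    keep = targets != ignore_index
    logits, targets = logits[keep], targets[keep]
    if targets.numel() == 0:
        return logits.new_zeros(())

    # Log-probability and probability of the true class at each kept position.
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()

    # Cross-entropy scaled by the focal term (1 - pt) ** gamma.
    loss = -((1 - pt) ** gamma) * log_pt
    if alpha is not None:
        # Assumption: alpha is a 1D tensor holding one weight per class.
        loss = alpha[targets] * loss
    return loss.mean()


# K-dimensional input: logits (N, C, d1, d2), labels (N, d1, d2).
x = torch.randn(4, 3, 5, 5)
y = torch.randint(0, 3, (4, 5, 5))
y[0, 0, 0] = -100  # this position is excluded from the loss
print(focal_loss_sketch(x, y))
```

With `gamma=0` and no `alpha`, this sketch reduces to `F.cross_entropy` with the same `ignore_index`, which is a quick way to sanity-check it.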