https://github.com/borealisai/conr
Contrastive Regularizer
https://github.com/borealisai/conr
Last synced: about 1 year ago
JSON representation
Contrastive Regularizer
- Host: GitHub
- URL: https://github.com/borealisai/conr
- Owner: BorealisAI
- License: other
- Created: 2023-07-31T21:41:12.000Z (almost 3 years ago)
- Default Branch: main
- Last Pushed: 2024-02-15T05:42:59.000Z (over 2 years ago)
- Last Synced: 2025-03-22T14:12:06.085Z (about 1 year ago)
- Language: Python
- Size: 2.95 MB
- Stars: 9
- Watchers: 3
- Forks: 1
- Open Issues: 2
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# ConR: Contrastive Regularizer for Deep Imbalanced Regression
This repository contains the code for the ICLR 2024 paper:
[__ConR: Contrastive Regularizer for Deep Imbalanced Regression__](https://openreview.net/forum?id=RIuevDSK5V)
Mahsa Keramati, Lili Meng, R David Evans
ConR key insights. a) Without ConR, it
is common to have minority examples mixed with
majority examples. b) ConR adds additional loss
weight for minority, and mis-labelled examples,
resulting in better feature representations and c)
better prediction error.
## Quick Preview
ConR is complementary to conventional imbalanced learning techniques. The following code snippent shows the implementation of ConR for the task of Age estimation
```python
def ConR(features, targets, preds, w=1, weights=1, t=0.07, e=0.01):
q = torch.nn.functional.normalize(features, dim=1)
k = torch.nn.functional.normalize(features, dim=1)
l_k = targets.flatten()[None, :]
l_q = targets
p_k = preds.flatten()[None, :]
p_q = preds
l_dist = torch.abs(l_q - l_k)
p_dist = torch.abs(p_q - p_k)
pos_i = l_dist.le(w)
neg_i = ((~ (l_dist.le(w))) * (p_dist.le(w)))
for i in range(pos_i.shape[0]):
pos_i[i][i] = 0
prod = torch.einsum("nc,kc->nk", [q, k]) / t
pos = prod * pos_i
neg = prod * neg_i
pushing_w = weights * torch.exp(l_dist * e)
neg_exp_dot = (pushing_w * (torch.exp(neg)) * neg_i).sum(1)
# For each query sample, if there is no negative pair, zero-out the loss.
no_neg_flag = (neg_i).sum(1).bool()
# Loss = sum over all samples in the batch (sum over (positive dot product/(negative dot product+positive dot product)))
denom = pos_i.sum(1)
loss = ((-torch.log(
torch.div(torch.exp(pos), (torch.exp(pos).sum(1) + neg_exp_dot).unsqueeze(-1))) * (
pos_i)).sum(1) / denom)
loss = (weights * (loss * no_neg_flag).unsqueeze(-1)).mean()
return loss
```
## Usage
Please go into the sub-folder to run experiments for different datasets.
- [IMDB-WIKI-DIR](./imdb-wiki-dir)
- [AgeDB-DIR](./agedb-dir)
- [NYUD2-DIR](./nyud2-dir)
## Acknowledgment
The code is based on [Yang et al., Delving into Deep Imbalanced Regression, ICML 2021](https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir) and [Ren et al.,Balanced MSE for Imbalanced Visual Regression, CVPR 2022](https://github.com/jiawei-ren/BalancedMSE).