[ICML 2022] Source code for "A Closer Look at Smoothness in Domain Adversarial Training" (https://github.com/val-iisc/SDAT)

# Smooth Domain Adversarial Training

**Harsh Rangwani\*, Sumukh K Aithal\*, Mayank Mishra, Arihant Jain, R. Venkatesh Babu**

This is the official PyTorch implementation of our ICML'22 paper: **A Closer Look at Smoothness in Domain Adversarial Training** [[`Paper`](https://arxiv.org/abs/2206.08213)].

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-closer-look-at-smoothness-in-domain-1/domain-adaptation-on-office-home)](https://paperswithcode.com/sota/domain-adaptation-on-office-home?p=a-closer-look-at-smoothness-in-domain-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-closer-look-at-smoothness-in-domain-1/domain-adaptation-on-visda2017)](https://paperswithcode.com/sota/domain-adaptation-on-visda2017?p=a-closer-look-at-smoothness-in-domain-1)

## Introduction



Recently, methods that converge to smooth optima have been shown to improve generalization on supervised learning tasks such as classification. In this work, we analyze the effect of smoothness-enhancing formulations on domain adversarial training, whose objective combines a task loss (e.g., classification or regression) with adversarial terms. We find that converging to a smooth minimum with respect to (w.r.t.) the task loss stabilizes adversarial training, leading to better performance on the target domain. In contrast, our analysis shows that converging to a smooth minimum w.r.t. the adversarial loss leads to sub-optimal generalization on the target domain. Based on this analysis, we introduce the Smooth Domain Adversarial Training (SDAT) procedure, which effectively enhances the performance of existing domain adversarial methods for both classification and object detection tasks.
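
The distinction above (smoothness w.r.t. the task loss, but not w.r.t. the adversarial loss) can be written in the notation of sharpness-aware minimization (SAM). The symbols here are introduced only for illustration: $\theta$ are the parameters of the feature extractor and task classifier, $\eta$ is the learning rate, $\rho$ is the perturbation radius, $\mathcal{L}_{\text{task}}$ is the task loss and $\mathcal{L}_{\text{adv}}$ is the adversarial term as seen by $\theta$ (e.g., through a gradient reversal layer). Assuming SAM's standard first-order $\ell_2$ perturbation, the perturbation is built from the task loss alone, while the update uses the gradient of the full objective at the perturbed point:

```math
\hat{\epsilon}(\theta) = \rho \, \frac{\nabla_\theta \mathcal{L}_{\text{task}}(\theta)}{\lVert \nabla_\theta \mathcal{L}_{\text{task}}(\theta) \rVert_2},
\qquad
\theta \leftarrow \theta - \eta \, \nabla_\theta \big[ \mathcal{L}_{\text{task}}(\theta) + \mathcal{L}_{\text{adv}}(\theta) \big] \Big|_{\theta + \hat{\epsilon}(\theta)}
```

The domain classifier itself is still updated with a plain (non-smooth) gradient step.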

**TLDR:** Improve your adversarial domain adaptation algorithm with just a few lines of code by converting it to its smooth variant.

### Why use SDAT?
- Can be combined with any DAT algorithm.
- Easy to integrate with a few lines of code.
- Leads to a significant improvement in accuracy on the target domain.

#### DAT Based Method w/ SDAT
We describe the changes required to convert any DAT algorithm (e.g., CDAN, DANN, CDAN+MCC) to its Smooth DAT version.

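```python
import torch
from sam import SAM  # sam.py from https://github.com/davda54/sam

# `classifier` (feature extractor + task head), `domain_classifier`, `task_loss_fn`,
# `args` and the batch (`x`, `label`) are assumed to be defined by the training script.

# optimizer refers to the smooth (SAM) optimizer, which contains the parameters of the feature extractor and classifier.
optimizer = SAM(classifier.get_parameters(), torch.optim.SGD, rho=args.rho, adaptive=False,
                lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
# ad_optimizer refers to a standard SGD optimizer, which contains the parameters of the domain classifier.
ad_optimizer = torch.optim.SGD(domain_classifier.parameters(), lr=args.lr, momentum=args.momentum,
                               weight_decay=args.weight_decay, nesterov=True)

optimizer.zero_grad()
ad_optimizer.zero_grad()

# First pass: calculate the task loss only
class_prediction, feature = classifier(x)
task_loss = task_loss_fn(class_prediction, label)
task_loss.backward()

# Calculate ϵ̂(w) from the task-loss gradient, add it to the weights, and clear the gradients
optimizer.first_step(zero_grad=True)

# Second pass at the perturbed weights: calculate task loss and domain loss
class_prediction, feature = classifier(x)
task_loss = task_loss_fn(class_prediction, label)
domain_loss = domain_classifier(feature)  # returns the domain adversarial loss
loss = task_loss + domain_loss
loss.backward()

# Restore the original weights and apply the sharpness-aware update
# (SAM.step() would require a closure, so second_step() is used here)
optimizer.second_step(zero_grad=True)
# Update parameters of the domain classifier (standard, non-smooth SGD step)
ad_optimizer.step()
```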
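
To make the two-pass update above concrete without a GPU or the SAM package, here is a toy NumPy sketch of a single SDAT step on a plain parameter vector. The function `sdat_step` and its gradient callbacks are hypothetical illustrations, not part of this repository: `grad_adv` stands in for the adversarial gradient that reaches the feature extractor through the domain classifier, and the domain classifier's own SGD update is omitted. Note that only the task gradient shapes the perturbation.

```python
import numpy as np


def sdat_step(w, grad_task, grad_adv, lr=0.1, rho=0.05):
    """Toy, hypothetical sketch of one SDAT update for a parameter vector ``w``.

    grad_task(w): gradient of the task loss at w.
    grad_adv(w):  gradient of the adversarial (domain) loss w.r.t. w, i.e. what
                  the feature extractor receives from the domain classifier.
    """
    # First pass: the SAM perturbation is computed from the task loss only.
    g = grad_task(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)

    # Second pass: gradient of task + adversarial loss at the perturbed weights.
    w_perturbed = w + eps
    g_total = grad_task(w_perturbed) + grad_adv(w_perturbed)

    # The step is applied to the original (unperturbed) weights.
    return w - lr * g_total


# Example: quadratic task loss 0.5 * ||w||^2 (gradient w) and no adversarial term.
w = np.array([3.0, 4.0])
w_new = sdat_step(w, grad_task=lambda v: v, grad_adv=lambda v: np.zeros_like(v))
print(w_new)  # ≈ [2.697, 3.596]
```

With `rho=0` the perturbation vanishes and the step reduces to ordinary gradient descent on the combined loss, which is a quick way to sanity-check the smooth variant against its DAT baseline.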

## Getting started

* ### Requirements


  - pytorch 1.9.1
  - torchvision 0.10.1
  - wandb 0.12.2
  - timm 0.5.5
  - prettytable 2.2.0
  - scikit-learn


* ### Installation
```
git clone https://github.com/val-iisc/SDAT.git
cd SDAT
pip install -r requirements.txt
```
We use Weights and Biases ([wandb](https://wandb.ai/site)) to track our experiments and results. To track your own experiments with wandb, create a new project in your account and change the `project` and `entity` arguments of `wandb.init` accordingly. Logging to wandb is enabled by the `--log_results` flag used in the sample commands; omit the flag to disable wandb tracking.
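
A minimal sketch of where the `project` and `entity` values go, assuming `--log_results` is a store-true switch that turns wandb logging on (as its use in the sample commands suggests). The helper `maybe_init_wandb` and the placeholder names `your-project` / `your-entity` are hypothetical; in the repository you edit the existing `wandb.init` call instead.

```python
import argparse


def maybe_init_wandb(args, project="your-project", entity="your-entity"):
    """Hypothetical helper: start a wandb run only if --log_results was passed."""
    if not args.log_results:
        return None  # tracking disabled, wandb is never imported
    import wandb  # imported lazily so wandb is only required when logging

    # --log_name becomes the run name shown on wandb
    return wandb.init(project=project, entity=entity, name=args.log_name)


parser = argparse.ArgumentParser()
parser.add_argument("--log_results", action="store_true")
parser.add_argument("--log_name", default="Ar2Cl_cdan_mcc_sdat_vit")

args = parser.parse_args([])  # no --log_results: nothing is logged
run = maybe_init_wandb(args)
```

Importing wandb inside the branch keeps runs without `--log_results` free of any wandb dependency or login prompt.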

* ### Datasets
The datasets used in the repository can be downloaded from the following links:

  - [Office-Home](https://www.hemanthdv.org/officeHomeDataset.html)
  - [VisDA-2017](https://github.com/VisionLearningGroup/taskcv-2017-public) (classification track)
  - [DomainNet](http://ai.bu.edu/M3SDA/)


If a dataset is not already available, it is downloaded automatically to the `data/` folder.
## Training
We report our numbers primarily on two domain adaptation methods: CDAN w/ SDAT and CDAN+MCC w/ SDAT. The training scripts can be found under the `examples` subdirectory.

### Domain Adversarial Training (DAT)
To train using standard CDAN and CDAN+MCC, use the `cdan.py` and `cdan_mcc.py` files, respectively. A sample command to train CDAN+MCC with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below.
```
python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
```

### Smooth Domain Adversarial Training (SDAT)

To train using our proposed CDAN w/ SDAT and CDAN+MCC w/ SDAT, use the `cdan_sdat.py` and `cdan_mcc_sdat.py` files, respectively.

A sample command to run CDAN+MCC w/ SDAT with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below.
```
python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
```
Additional commands to reproduce the results can be found in `run_office_home.sh` and `run_visda.sh` under `examples`.
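
The shell scripts remain the authoritative source; as a hypothetical sketch, the 12 Office-Home transfer tasks can be enumerated from the sample SDAT command. It assumes the domain codes `Pr` (Product) and `Rw` (Real World), following the Transfer Learning Library convention (only `Ar` and `Cl` appear above), and it reuses the sample's hyperparameters (`--rho 0.02`, `--lr 0.002`) for every split, which the released scripts may not do. `office_home_commands` is an illustrative name, not part of the repository.

```python
from itertools import permutations

# Assumed Office-Home domain codes; only Ar and Cl appear in the sample command.
DOMAINS = ["Ar", "Cl", "Pr", "Rw"]  # Art, Clipart, Product, Real World


def office_home_commands(rho=0.02, lr=0.002, gpu=0):
    """Hypothetical helper: build CDAN+MCC w/ SDAT commands for all 12 transfer tasks."""
    commands = []
    for src, tgt in permutations(DOMAINS, 2):  # ordered (source, target) pairs, src != tgt
        task = f"{src}2{tgt}"
        commands.append(
            "python cdan_mcc_sdat.py data/office-home -d OfficeHome "
            f"-s {src} -t {tgt} -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool "
            f"--log logs/cdan_mcc_sdat_vit/OfficeHome_{task} --log_name {task}_cdan_mcc_sdat_vit "
            f"--gpu {gpu} --rho {rho} --lr {lr} --log_results"
        )
    return commands


for cmd in office_home_commands():
    print(cmd)
```

The first generated command is identical to the sample above. Each string is meant to be run from the `examples` directory, for example with `subprocess.run(shlex.split(cmd), check=True)`.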

### Results
The following table reports the target-domain accuracy (%) on the various splits of the Office-Home and VisDA-2017 datasets, using CDAN+MCC w/ SDAT with a ViT B-16 backbone. We also provide downloadable weights for the corresponding pretrained classifiers.



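| Dataset | Source | Target | Accuracy (%) | Checkpoints |
|---|---|---|---|---|
| Office-Home | Art | Clipart | 70.8 | ckpt |
| Office-Home | Art | Product | 80.7 | ckpt |
| Office-Home | Art | Real World | 90.5 | ckpt |
| Office-Home | Clipart | Art | 85.2 | ckpt |
| Office-Home | Clipart | Product | 87.3 | ckpt |
| Office-Home | Clipart | Real World | 89.7 | ckpt |
| Office-Home | Product | Art | 84.1 | ckpt |
| Office-Home | Product | Clipart | 70.7 | ckpt |
| Office-Home | Product | Real World | 90.6 | ckpt |
| Office-Home | Real World | Art | 88.3 | ckpt |
| Office-Home | Real World | Clipart | 75.5 | ckpt |
| Office-Home | Real World | Product | 92.1 | ckpt |
| VisDA-2017 | Synthetic | Real | 89.8 | ckpt |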
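
As a quick arithmetic check, the mean Office-Home accuracy over the 12 splits in the table can be computed as below. The values are copied from the table; the dictionary layout is only an illustration.

```python
# Office-Home target-domain accuracies (%) for CDAN+MCC w/ SDAT (ViT B-16), from the table
office_home = {
    ("Art", "Clipart"): 70.8,
    ("Art", "Product"): 80.7,
    ("Art", "Real World"): 90.5,
    ("Clipart", "Art"): 85.2,
    ("Clipart", "Product"): 87.3,
    ("Clipart", "Real World"): 89.7,
    ("Product", "Art"): 84.1,
    ("Product", "Clipart"): 70.7,
    ("Product", "Real World"): 90.6,
    ("Real World", "Art"): 88.3,
    ("Real World", "Clipart"): 75.5,
    ("Real World", "Product"): 92.1,
}

# Unweighted mean over the 12 (source, target) splits
mean_acc = sum(office_home.values()) / len(office_home)
print(f"Office-Home mean accuracy: {mean_acc:.1f}%")  # Office-Home mean accuracy: 83.8%
```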

### Evaluation
To evaluate a classifier with pretrained weights, use `eval.py` under `examples`. Set the `--weight_path` argument to the path of the weights to be evaluated.

A sample run to evaluate the pretrained ViT B-16 with CDAN+MCC w/ SDAT on Office-Home (with Art as source domain and Clipart as the target domain) is given below.
```
python eval.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 -b 24 --no-pool --weight_path path_to_weight.pth --log_name Ar2Cl_cdan_mcc_sdat_vit_eval --gpu 0 --phase test
```
A sample run to evaluate the pretrained ViT B-16 with CDAN+MCC w/ SDAT on VisDA-2017 (with Synthetic as source domain and Real as the target domain) is given below.

```
python eval.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --per-class-eval --train-resizing cen.crop --weight_path path_to_weight.pth --log_name visda_cdan_mcc_sdat_vit_eval --gpu 0 --no-pool --phase test
```

## Overview of the arguments
All scripts in the project generally accept the following flags:
- `-a`: Architecture of the backbone (resnet50|vit_base_patch16_224).
- `-d`: Dataset (OfficeHome|VisDA2017|DomainNet).
- `-s`: Source domain.
- `-t`: Target domain.
- `--epochs`: Number of epochs to train for.
- `--no-pool`: Use for all experiments with a ViT backbone.
- `--log_name`: Name of the run on wandb.
- `--gpu`: ID of the GPU to use.
- `--rho`: $\rho$ value (radius of the sharpness-aware perturbation) in SDAT; applicable only to SDAT runs.

## Acknowledgement
Our implementation is based on the [Transfer Learning Library](https://github.com/thuml/Transfer-Learning-Library). We use the PyTorch implementation of SAM from https://github.com/davda54/sam.
## Citation
If you find our paper or codebase useful, please consider citing us as:
```bibtex
@InProceedings{rangwani2022closer,
title={A Closer Look at Smoothness in Domain Adversarial Training},
author={Rangwani, Harsh and Aithal, Sumukh K and Mishra, Mayank and Jain, Arihant and Babu, R. Venkatesh},
booktitle={Proceedings of the 39th International Conference on Machine Learning},
year={2022}
}
```