{"id":13737738,"url":"https://github.com/val-iisc/SDAT","last_synced_at":"2025-05-08T15:31:06.305Z","repository":{"id":41579690,"uuid":"505057704","full_name":"val-iisc/SDAT","owner":"val-iisc","description":"[ICML 2022]Source code for \"A Closer Look at Smoothness in Domain Adversarial Training\", ","archived":false,"fork":false,"pushed_at":"2024-04-11T22:11:25.000Z","size":491,"stargazers_count":68,"open_issues_count":4,"forks_count":13,"subscribers_count":13,"default_branch":"main","last_synced_at":"2025-04-11T16:26:42.280Z","etag":null,"topics":["adversarial-training","dann","domain-adaptation","icml-2022","pytorch","sharpness-aware-minimization"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/val-iisc.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-06-19T08:54:15.000Z","updated_at":"2025-02-10T12:38:42.000Z","dependencies_parsed_at":"2024-01-07T06:01:22.263Z","dependency_job_id":"a5be2033-eacc-4930-b99e-21d30241fc6a","html_url":"https://github.com/val-iisc/SDAT","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/val-iisc%2FSDAT","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/val-iisc%2FSDAT/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/val-iisc%2FSDAT/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/val-iisc%2FSDAT/manifests","owner_url":"https://
repos.ecosyste.ms/api/v1/hosts/GitHub/owners/val-iisc","download_url":"https://codeload.github.com/val-iisc/SDAT/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":253095916,"owners_count":21853508,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["adversarial-training","dann","domain-adaptation","icml-2022","pytorch","sharpness-aware-minimization"],"created_at":"2024-08-03T03:01:58.976Z","updated_at":"2025-05-08T15:31:05.968Z","avatar_url":"https://github.com/val-iisc.png","language":"Python","funding_links":[],"categories":["Python"],"sub_categories":[],"readme":"# \u003cdiv align=\"center\"\u003eSmooth Domain Adversarial Training\u003c/div\u003e\n\n\u003cfont size = \"3\"\u003e**Harsh Rangwani\\*, Sumukh K Aithal\\*, Mayank Mishra, Arihant Jain, R. 
Venkatesh Babu**\u003c/font\u003e\n\n\n\nThis is the official PyTorch implementation for our ICML'22 paper: **A Closer Look at Smoothness in Domain Adversarial Training**. [[`Paper`](https://arxiv.org/abs/2206.08213)] \n\n\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-closer-look-at-smoothness-in-domain-1/domain-adaptation-on-office-home)](https://paperswithcode.com/sota/domain-adaptation-on-office-home?p=a-closer-look-at-smoothness-in-domain-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-closer-look-at-smoothness-in-domain-1/domain-adaptation-on-visda2017)](https://paperswithcode.com/sota/domain-adaptation-on-visda2017?p=a-closer-look-at-smoothness-in-domain-1)\n\n## Introduction\n\u003cdiv align=\"center\"\u003e\n  \u003cimg width=\"100%\" alt=\"Smooth Domain Adversarial Training\" src=\"assets/sdat.png\"\u003e\n\u003c/div\u003e\n\n\u003cp align=\"justify\"\u003eIn recent times, methods converging to smooth optima have shown improved generalization for supervised learning tasks like classification. In this work, we analyze the effect of smoothness-enhancing formulations on domain adversarial training, the objective of which is a combination of a task loss (e.g., classification, regression) and adversarial terms. We find that converging to a smooth minimum with respect to (w.r.t.) the task loss stabilizes adversarial training, leading to better performance on the target domain. In contrast, our analysis shows that converging to smooth minima w.r.t. the adversarial loss leads to sub-optimal generalization on the target domain. Based on this analysis, we introduce the Smooth Domain Adversarial Training (SDAT) procedure, which effectively enhances the performance of existing domain adversarial methods for both classification and object detection tasks. 
\u003c/p\u003e\n\n**TLDR:** Change just a few lines of code to convert your adversarial domain adaptation algorithm into its smooth variant and improve its performance. \n\n### Why use SDAT?\n- Can be combined with any DAT algorithm.\n- Easy to integrate with a few lines of code.\n- Significantly improves accuracy on the target domain.\n\u003c!-- #### DAT Based Method\n ```\n\n# optimizer refers to the standard SGD optimizer which contains parameters of the feature extractor and classifier.\noptimizer.zero_grad()\n# ad_optimizer refers to standard SGD optimizer which contains parameters of domain classifier.\nad_optimizer.zero_grad()\n\nclass_prediction, feature = model(x)\ntask_loss = task_loss_fn(class_prediction, label)\ndomain_loss = domain_classifier(feature)\nloss = task_loss + domain_loss\nloss.backward()\n\n# Update parameters of feature extractor and classifier\noptimizer.step()\n# Update parameters of domain classifier\nad_optimizer.step()\n``` --\u003e\n\n#### DAT Based Method w/ SDAT\nWe describe the changes required to convert any DAT algorithm (e.g., CDAN, DANN, CDAN+MCC) 
to its Smooth DAT version.\n\n```python\noptimizer = SAM(classifier.get_parameters(), torch.optim.SGD, rho=args.rho, adaptive=False,\n                    lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)\n# optimizer refers to the Smooth optimizer which contains parameters of the feature extractor and classifier.\noptimizer.zero_grad()\n# ad_optimizer refers to standard SGD optimizer which contains parameters of domain classifier.\nad_optimizer.zero_grad()\n\n# Calculate task loss\nclass_prediction, feature = model(x)\ntask_loss = task_loss_fn(class_prediction, label)\ntask_loss.backward()\n\n# Calculate ϵ̂ (w), add it to the weights, and clear the gradients of this first pass\noptimizer.first_step(zero_grad=True)\n\n# Calculate task loss and domain loss at the perturbed weights w + ϵ̂ (w)\nclass_prediction, feature = model(x)\ntask_loss = task_loss_fn(class_prediction, label)\ndomain_loss = domain_classifier(feature)\nloss = task_loss + domain_loss\nloss.backward()\n\n# Restore the original weights and apply the sharpness-aware update\noptimizer.second_step(zero_grad=True)\n# Update parameters of domain classifier\nad_optimizer.step()\n```\n\n## Getting started\n\n* ### Requirements\n\t\u003cul\u003e\n\t\u003cli\u003epytorch 1.9.1\u003c/li\u003e\n\t\u003cli\u003etorchvision 0.10.1\u003c/li\u003e\n\t\u003cli\u003ewandb 0.12.2\u003c/li\u003e\n\t\u003cli\u003etimm 0.5.5\u003c/li\u003e\n\t\u003cli\u003eprettytable 2.2.0\u003c/li\u003e\n\t\u003cli\u003e scikit-learn \u003c/li\u003e\n\t\u003c/ul\u003e\n* ### Installation\n```\ngit clone https://github.com/val-iisc/SDAT.git\ncd SDAT\npip install -r requirements.txt\n```\nWe use Weights and Biases ([wandb](https://wandb.ai/site)) to track our experiments and results. To track your experiments with wandb, create a new project with your account and change the `project` and `entity` arguments in `wandb.init` accordingly. Logging to wandb is enabled by the `--log_results` flag; omit this flag to disable wandb tracking. 
\n\n* ### Datasets\n   The datasets used in the repository can be downloaded from the following links:\n\t   \u003cul\u003e\n\t   \u003cli\u003e[Office-Home](https://www.hemanthdv.org/officeHomeDataset.html)\u003c/li\u003e\u003cli\u003e[VisDA-2017](https://github.com/VisionLearningGroup/taskcv-2017-public) (under classification track)\u003c/li\u003e\u003cli\u003e[DomainNet](http://ai.bu.edu/M3SDA/)\u003c/li\u003e\n\t   \u003c/ul\u003e\n\tThe datasets are automatically downloaded to the `data/` folder if they are not already present.\n\n## Training\nWe report our results primarily on two domain adaptation methods: CDAN w/ SDAT and CDAN+MCC w/ SDAT. The training scripts can be found under the `examples` subdirectory. \n\n### Domain Adversarial Training (DAT)\nTo train using standard CDAN and CDAN+MCC, use `cdan.py` and `cdan_mcc.py`, respectively. A sample command for training CDAN+MCC with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below. \n```\npython cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results\n```\n\n### Smooth Domain Adversarial Training (SDAT)\n\nTo train using our proposed CDAN w/ SDAT and CDAN+MCC w/ SDAT, use `cdan_sdat.py` and `cdan_mcc_sdat.py`, respectively. \n\nA sample command to run CDAN+MCC w/ SDAT with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below. 
\n```\npython cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results\n```\nAdditional commands to reproduce the results can be found in `run_office_home.sh` and `run_visda.sh` under `examples`.\n\n### Results\nThe following table reports the accuracy across the various splits of the Office-Home and VisDA-2017 datasets using CDAN+MCC w/ SDAT with a ViT B-16 backbone. We also provide downloadable weights for the corresponding trained classifiers. \n\u003cdiv align=\"center\"\u003e\n\u003ctable\u003e\n    \u003cthead\u003e\n        \u003ctr\u003e\n            \u003cth\u003eDataset\u003c/th\u003e\n            \u003cth\u003eSource\u003c/th\u003e\n            \u003cth\u003eTarget\u003c/th\u003e\n            \u003cth\u003eAccuracy\u003c/th\u003e\n            \u003cth\u003eCheckpoints\u003c/th\u003e\n        \u003c/tr\u003e\n    \u003c/thead\u003e\n    \u003ctbody\u003e\n        \u003ctr\u003e\n            \u003ctd rowspan=12\u003e Office-Home\u003c/td\u003e\n            \u003ctd\u003eArt\u003c/td\u003e\n            \u003ctd\u003eClipart\u003c/td\u003e\n            \u003ctd\u003e70.8\u003c/td\u003e\n            \u003ctd\u003e\u003ca href=\"https://drive.google.com/uc?export=download\u0026id=1KAXa9OpRAh_5pMDslbQdiZAm0rd98oY3\"\u003eckpt\u003c/a\u003e\n            \u003c/td\u003e\n        \u003c/tr\u003e\n        \u003ctr\u003e\n            \u003ctd\u003eArt\u003c/td\u003e\n            \u003ctd\u003eProduct\u003c/td\u003e\n            \u003ctd\u003e80.7\u003c/td\u003e\n            \u003ctd\u003e\u003ca href=\"https://drive.google.com/uc?export=download\u0026id=1uYewbS7T-MIJTAyEXnDGsTdDHqpRxhOB\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n        \u003ctr\u003e\n            \u003ctd\u003eArt\u003c/td\u003e\n            \u003ctd\u003eReal World\u003c/td\u003e\n     
       \u003ctd\u003e90.5\u003c/td\u003e\n            \u003ctd\u003e\u003ca href=\"https://drive.google.com/uc?export=download\u0026id=1fWwc7eFjdozn_5tbtJAFVfd7D7D6-4eJ\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n        \u003ctr\u003e\n            \u003ctd\u003eClipart\u003c/td\u003e\n            \u003ctd\u003eArt\u003c/td\u003e\n            \u003ctd\u003e85.2\u003c/td\u003e\n\t    \u003ctd\u003e\u003ca href=\"https://drive.google.com/uc?export=download\u0026id=1VjEmev3Q5itkjF2xaBEBkkZoJ-gHHX3q\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n        \u003ctr\u003e\n            \u003ctd\u003eClipart\u003c/td\u003e\n            \u003ctd\u003eProduct\u003c/td\u003e\n            \u003ctd\u003e87.3\u003c/td\u003e\n            \u003ctd\u003e\u003ca href=\"https://drive.google.com/uc?export=download\u0026id=1f19ilEM4DnN3_-n9nf0e6F0ViIB-qcmZ\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n        \u003ctr\u003e\n            \u003ctd\u003eClipart\u003c/td\u003e\n            \u003ctd\u003eReal World\u003c/td\u003e\n            \u003ctd\u003e89.7\u003c/td\u003e\n            \u003ctd\u003e\u003ca href=\"https://drive.google.com/uc?export=download\u0026id=1EZKBMj4LMrUZKV6_I4FPb7bD_fCxv4W-\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n        \u003ctr\u003e\n            \u003ctd\u003eProduct\u003c/td\u003e\n            \u003ctd\u003eArt\u003c/td\u003e\n            \u003ctd\u003e84.1\u003c/td\u003e\n            \u003ctd\u003e\u003ca href=\"https://drive.google.com/uc?export=download\u0026id=1woKuqUay_qSEOLLKF924zA-QYKFek214\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n        \u003ctr\u003e\n            \u003ctd\u003eProduct\u003c/td\u003e\n            \u003ctd\u003eClipart\u003c/td\u003e\n            \u003ctd\u003e70.7\u003c/td\u003e\n            \u003ctd\u003e\u003ca 
href=\"https://drive.google.com/uc?export=download\u0026id=16b-EEqDtEVPmuRd89QbQmLtkS3HO7H4T\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n         \u003ctr\u003e\n            \u003ctd\u003eProduct\u003c/td\u003e\n            \u003ctd\u003eReal World\u003c/td\u003e\n            \u003ctd\u003e90.6\u003c/td\u003e\n            \u003ctd\u003e\u003ca href=\"https://drive.google.com/uc?export=download\u0026id=1S7Dm0raEg8WtelOKI2I_xwGFBA0S5Rsc\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n        \u003ctr\u003e\n            \u003ctd\u003eReal World\u003c/td\u003e\n            \u003ctd\u003eArt\u003c/td\u003e\n            \u003ctd\u003e88.3\u003c/td\u003e\n            \u003ctd\u003e\u003ca href=\"https://drive.google.com/uc?export=download\u0026id=1oGOYcJ0SMQH6vXVeU7krs3uhXOIygeha\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n        \u003ctr\u003e\n            \u003ctd\u003eReal World\u003c/td\u003e\n            \u003ctd\u003eClipart\u003c/td\u003e\n            \u003ctd\u003e75.5\u003c/td\u003e\n            \u003ctd\u003e\u003ca href=\"https://drive.google.com/uc?export=download\u0026id=1pjf7V5RWG7kjtGj4bYJOwltDP9UwAJA2\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n         \u003ctr\u003e\n            \u003ctd\u003eReal World\u003c/td\u003e\n            \u003ctd\u003eProduct\u003c/td\u003e\n            \u003ctd\u003e92.1\u003c/td\u003e\n            \u003ctd\u003e\u003ca href=\"https://drive.google.com/uc?export=download\u0026id=1gI6TXw0V-9iXM30SGu2AyMaHp1NZqS8c\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n         \u003ctr\u003e\n            \u003ctd rowspan=1\u003eVisDA-2017\u003c/td\u003e\n            \u003ctd\u003eSynthetic\u003c/td\u003e\n            \u003ctd\u003eReal\u003c/td\u003e\n            \u003ctd\u003e89.8\u003c/td\u003e\n            \u003ctd\u003e\u003ca 
href=\"https://drive.google.com/uc?export=download\u0026id=1-jusx3NM510pC7aOjO5cCGSB-Ywvcgst\"\u003eckpt\u003c/a\u003e\u003c/td\u003e\n        \u003c/tr\u003e\n    \u003c/tbody\u003e\n\u003c/table\u003e\n\u003c/div\u003e\n\n### Evaluation\nTo evaluate a classifier with pretrained weights, use `eval.py` under `examples` and set the `--weight_path` argument to the path of the weights to be evaluated. \n\nA sample command to evaluate the pretrained ViT B-16 trained with CDAN+MCC w/ SDAT on Office-Home (with Art as the source domain and Clipart as the target domain) is given below.\n```\npython eval.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 -b 24 --no-pool --weight_path path_to_weight.pth --log_name Ar2Cl_cdan_mcc_sdat_vit_eval --gpu 0 --phase test\n```\nA sample command to evaluate the pretrained ViT B-16 trained with CDAN+MCC w/ SDAT on VisDA-2017 (with Synthetic as the source domain and Real as the target domain) is given below.\n\n```\npython eval.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --per-class-eval --train-resizing cen.crop --weight_path path_to_weight.pth --log_name visda_cdan_mcc_sdat_vit_eval --gpu 0 --no-pool --phase test\n```\n\n\n## Overview of the arguments\nAll scripts in the project generally accept the following flags:\n- `-a`: Architecture of the backbone (resnet50|vit_base_patch16_224).\n- `-d`: Dataset (OfficeHome|VisDA2017|DomainNet).\n- `-s`: Source domain.\n- `-t`: Target domain.\n- `--epochs`: Number of training epochs.\n- `--no-pool`: Use for all experiments with a ViT backbone.\n- `--log_name`: Name of the run on wandb.\n- `--gpu`: GPU id to use.\n- `--rho`: $\\rho$ value in SDAT (applicable only to SDAT runs).\n\n## Acknowledgement\nOur implementation is based on the [Transfer Learning Library](https://github.com/thuml/Transfer-Learning-Library). 
We use the PyTorch implementation of SAM from https://github.com/davda54/sam.\n## Citation\nIf you find our paper or codebase useful, please consider citing us as:\n```latex\n@InProceedings{rangwani2022closer,\n  title={A Closer Look at Smoothness in Domain Adversarial Training},\n  author={Rangwani, Harsh and Aithal, Sumukh K and Mishra, Mayank and Jain, Arihant and Babu, R. Venkatesh},\n booktitle={Proceedings of the 39th International Conference on Machine Learning},\n  year={2022}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fval-iisc%2FSDAT","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fval-iisc%2FSDAT","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fval-iisc%2FSDAT/lists"}