{"id":19839809,"url":"https://github.com/qdata/textattack-a2t","last_synced_at":"2025-05-01T19:30:24.958Z","repository":{"id":69609556,"uuid":"402605904","full_name":"QData/TextAttack-A2T","owner":"QData","description":"A2T: Towards Improving Adversarial Training of NLP Models (EMNLP 2021 Findings)","archived":false,"fork":false,"pushed_at":"2021-09-12T13:39:14.000Z","size":145372,"stargazers_count":26,"open_issues_count":3,"forks_count":2,"subscribers_count":3,"default_branch":"main","last_synced_at":"2025-04-06T17:05:26.492Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/QData.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2021-09-03T01:05:11.000Z","updated_at":"2024-12-26T07:35:08.000Z","dependencies_parsed_at":"2023-03-11T06:34:49.044Z","dependency_job_id":null,"html_url":"https://github.com/QData/TextAttack-A2T","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/QData%2FTextAttack-A2T","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/QData%2FTextAttack-A2T/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/QData%2FTextAttack-A2T/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/QData%2FTextAttack-A2T/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/QData","download_url":"https://codeload.github.com/QData/Te
xtAttack-A2T/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":251932522,"owners_count":21667158,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-11-12T12:24:25.948Z","updated_at":"2025-05-01T19:30:24.946Z","avatar_url":"https://github.com/QData.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# A2T: Towards Improving Adversarial Training of NLP Models\n\nThis is the source code for the EMNLP 2021 (Findings) paper [\"Towards Improving Adversarial Training of NLP Models\"](https://arxiv.org/abs/2109.00544).\n\nIf you use the code, please cite the paper:\n```\n@misc{yoo2021improving,\n      title={Towards Improving Adversarial Training of NLP Models}, \n      author={Jin Yong Yoo and Yanjun Qi},\n      year={2021},\n      eprint={2109.00544},\n      archivePrefix={arXiv},\n      primaryClass={cs.CL}\n}\n```\n\n## Prerequisites\nThis work relies heavily on the [TextAttack](https://github.com/QData/TextAttack) package. In fact, the main training code is implemented in the TextAttack package.\n\nRequired packages are listed in the `requirements.txt` file.\n```\npip install -r requirements.txt\n```\n\n## Data\nAll of the data used for the paper are available from HuggingFace's [Datasets](https://huggingface.co/datasets).\n\nFor the IMDB and Yelp datasets, because there are no official validation splits, we randomly sampled 5k and 10k examples, respectively, from the training set and used them as validation splits. 
We provide the splits in this Google Drive [folder](https://drive.google.com/drive/folders/1-vvSXUzl1PzMzdyZzAWq2dB--m7tEERK?usp=sharing). To use them with the provided code, place each folder (e.g. `imdb`, `yelp`, `augmented_data`) inside `./data` (run `mkdir data` first).\n\nAlso, augmented training data generated using SSMBA and back-translation are available in the same folder.\n\n## Training\nTo train a BERT model on the IMDB dataset with the A2T attack for 4 epochs and 1 clean epoch with a gamma of 0.2 (set through `--num-adv-examples`):\n```\npython train.py \\\n    --train imdb \\\n    --eval imdb \\\n    --model-type bert \\\n    --model-save-path ./example \\\n    --num-epochs 4 \\\n    --num-clean-epochs 1 \\\n    --num-adv-examples 0.2 \\\n    --attack-epoch-interval 1 \\\n    --attack a2t \\\n    --learning-rate 5e-5 \\\n    --num-warmup-steps 100 \\\n    --grad-accumu-steps 1 \\\n    --checkpoint-interval-epochs 1 \\\n    --seed 42\n```\n\nYou can also pass `roberta` to `--model-type` to train a RoBERTa model instead of a BERT model. To select other datasets from the paper, pass `rt` (MR), `yelp`, or `snli` for `--train` and `--eval`.\n\nThis script is a thin wrapper that runs the `Trainer` class from the TextAttack package. To see how training is performed, check out the `Trainer` [class](https://github.com/QData/TextAttack/blob/master/textattack/trainer.py).\n\n## Evaluation\nTo evaluate the accuracy, robustness, and interpretability of our trained model from above, run:\n```\npython evaluate.py \\\n    --dataset imdb \\\n    --model-type bert \\\n    --checkpoint-paths ./example_run \\\n    --epoch 4 \\\n    --save-log \\\n    --accuracy \\\n    --robustness \\\n    --attacks a2t a2t_mlm textfooler bae pwws pso \\\n    --interpretability \n```\n\nThis takes the last checkpoint model (`--epoch 4`) and evaluates its accuracy on both the IMDB and Yelp datasets (for cross-domain accuracy). It also evaluates the model's robustness against A2T, A2T-MLM, TextFooler, BAE, PWWS, and PSO attacks. 
Lastly, with the `--interpretability` flag, AOPC scores are calculated. \n\nNote that you will have to run `--robustness` and `--interpretability` with `--accuracy` (or after you separately evaluate accuracy), since both the robustness and interpretability evaluations rely on the accuracy evaluation to know which samples the model was able to predict correctly.\nBy default, 1000 samples are attacked to evaluate robustness. Likewise, 1000 samples are used to calculate the AOPC score for interpretability.\n\nIf you're evaluating multiple models for comparison, it's also advised that you provide all the checkpoint paths together to `--checkpoint-paths`. This is because the samples that are correctly predicted by each model will differ, so we first need to identify the intersection of all correct predictions before using them to evaluate robustness for all the models. This allows a fairer comparison of the models' robustness than attacking different samples for each model.\n\n## Data Augmentation\nLastly, we also provide `augment.py`, which we used to perform data augmentation with methods such as SSMBA and back-translation.\n\nThe following is an example command for augmenting the IMDB dataset with the SSMBA method.\n```\npython augment.py \\\n    --dataset imdb \\\n    --augmentation ssmba \\\n    --output-path ./augmented_data \\\n    --seed 42 \n```\n\nYou can also pass `backtranslation` to `--augmentation`.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fqdata%2Ftextattack-a2t","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fqdata%2Ftextattack-a2t","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fqdata%2Ftextattack-a2t/lists"}