{"id":19839814,"url":"https://github.com/qdata/wigraph","last_synced_at":"2025-05-01T19:30:25.379Z","repository":{"id":83652195,"uuid":"593744170","full_name":"QData/WIGRAPH","owner":"QData","description":"Code for paper \"Improving Interpretability via Explicit Word Interaction Graph Layer\" ","archived":false,"fork":false,"pushed_at":"2023-03-13T20:13:10.000Z","size":2045,"stargazers_count":3,"open_issues_count":0,"forks_count":2,"subscribers_count":2,"default_branch":"main","last_synced_at":"2025-04-06T17:05:19.019Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/QData.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2023-01-26T18:38:17.000Z","updated_at":"2024-03-30T00:49:41.000Z","dependencies_parsed_at":"2024-11-12T12:38:34.877Z","dependency_job_id":null,"html_url":"https://github.com/QData/WIGRAPH","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/QData%2FWIGRAPH","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/QData%2FWIGRAPH/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/QData%2FWIGRAPH/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/QData%2FWIGRAPH/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/QData","download_url":"https://codeload.github.com/QData/WIGRAPH/tar.gz/refs/heads/main","host":{"na
me":"GitHub","url":"https://github.com","kind":"github","repositories_count":251932522,"owners_count":21667158,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-11-12T12:24:27.231Z","updated_at":"2025-05-01T19:30:25.373Z","avatar_url":"https://github.com/QData.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# WIGRAPH\nThis repo contains code for the paper ``Improving Interpretability via Explicit Word Interaction Graph Layer`` published in AAAI 2023. \n\nhttps://arxiv.org/abs/2302.02016 \n\nRecent NLP literature has seen growing interest in improving model interpretability. Along this direction, we propose a trainable neural network layer that learns a global interaction graph between words and then selects more informative words using the learned word interactions. Our layer, which we call WIGRAPH, can plug into any neural network-based NLP text classifier right after its word embedding layer. Adding the WIGRAPH layer substantially improves NLP models' interpretability and enhances models' prediction performance at the same time.\n\n``\n@article{sekhon2023improving,\n  title={Improving Interpretability via Explicit Word Interaction Graph Layer},\n  author={Sekhon, Arshdeep and Chen, Hanjie and Shrivastava, Aman and Wang, Zhe and Ji, Yangfeng and Qi, Yanjun},\n  journal={arXiv preprint arXiv:2302.02016},\n  year={2023}\n}\n``\n\n## Requirements \n- Pytorch==1.5.1\n- Transformers==3.3.0\n- datasets==1.1.2\n## Prerequisites\n\nThe `data` folder contains the datasets that are used for WIGRAPH. 
\nThe datasets used can be downloaded from: [datasets link](https://drive.google.com/file/d/1id1F7N9vXbpL3Y8Omhq2zosT_MTxOT8A/view?usp=share_link). To run the code, save these in the `data` folder.\n\nThe `metadata` folder contains precomputed counts of words and interactions, which are used to extract the subset of words and interactions used to learn the model. These can be downloaded from [metadata link (precomputed counts)](https://drive.google.com/file/d/1CDQUYJZ7CV_33OU9Or-Q17E2wiCh6o0z/view?usp=share_link). After downloading, save these in `metadata`. \n\nThe `models` used here as classifiers can be downloaded from: [finetuned models](https://drive.google.com/file/d/1id1F7N9vXbpL3Y8Omhq2zosT_MTxOT8A/view?usp=share_link). After downloading, save these in `finedtuned_models/`\n\n## Training \n\nThe WIGRAPH layer, which can be plugged in front of any NLP classifier, can be found in `imask.py`. \n\nTo train a model using WIGRAPH:\n``\nmain.py --per_gpu_train_batch_size 32 --per_gpu_eval_batch_size 64 --task-name sst2 --learning_rate 1e-05 --factor 1000.0 --mask-hidden-dim 32 --backbone distilbert --save-dir results/distilbert/ --init-mode static --beta-i 0.0 --beta-g 1.0 --beta-s 1.0 --imask-dropout 0.3 --project WIGRAPH --non-linearity gelu --num_train_epochs 10 --seed 42 --onlyA --max_sent_len 56 --anneal\n``\n\nThe relevant hyperparameters are:\n1. task-name : [sst2, sst1, trec, subj, imdb, agnews]\n2. backbone : [distilbert, roberta, bert] \n3. beta-i : 0.0 for WIGRAPH-A, 1.0 for WIGRAPH-A-R\n4. beta-s : sparsity weight\n5. onlyA : does not train R, enable for WIGRAPH-A\n6. max_sent_len : [sst2(56), sst1(56), trec(32), subj(128), imdb(512), agnews(128)]\n\n## Interpretation\n\n- We provide interpretation using SHAP as well as LIME. 
A WIGRAPH model can be interpreted using the following:\n\n ``\ninterpret.py --per_gpu_train_batch_size 32 --per_gpu_eval_batch_size 64 --task-name sst2 --learning_rate 1e-05 --factor 1000.0 --mask-hidden-dim 32 --backbone distilbert  --save-dir results/distilbert/ --init-mode static --beta-i 0.0 --beta-g 1.0 --beta-s 1.0 --imask-dropout 0.3 --project WIGRAPH --non-linearity gelu --num_train_epochs 10 --seed 42 --onlyA --max_sent_len 56 --no-save --anneal\n``\n\n- The AOPC (after the interpretation files are generated) can be measured using:\n\n``\naopc.py --per_gpu_train_batch_size 32 --per_gpu_eval_batch_size 64 --task-name sst2 --learning_rate 1e-05 --factor 1000.0 --mask-hidden-dim 32 --backbone distilbert  --save-dir results/distilbert/ --init-mode static --beta-i 0.0 --beta-g 1.0 --beta-s 1.0 --imask-dropout 0.3 --project WIGRAPH --non-linearity gelu --num_train_epochs 10 --seed 42 --onlyA --max_sent_len 56 --no-save --anneal\n``\n\n- The hyperparameters for both commands above are the same as those used to train the model.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fqdata%2Fwigraph","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fqdata%2Fwigraph","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fqdata%2Fwigraph/lists"}