{"id":22229738,"url":"https://github.com/graph-com/gsat","last_synced_at":"2025-07-27T19:31:41.287Z","repository":{"id":39829668,"uuid":"453428804","full_name":"Graph-COM/GSAT","owner":"Graph-COM","description":"[ICML 2022] Graph Stochastic Attention (GSAT) for interpretable and generalizable graph learning.","archived":false,"fork":false,"pushed_at":"2024-02-19T17:31:48.000Z","size":1643,"stargazers_count":161,"open_issues_count":0,"forks_count":21,"subscribers_count":5,"default_branch":"main","last_synced_at":"2024-11-28T02:35:14.403Z","etag":null,"topics":["deep-learning","graph-neural-networks","interpretability","interpretable-machine-learning","pytorch","xai"],"latest_commit_sha":null,"homepage":"https://arxiv.org/abs/2201.12987","language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/Graph-COM.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-01-29T14:58:30.000Z","updated_at":"2024-11-27T16:19:48.000Z","dependencies_parsed_at":"2024-11-28T02:32:44.414Z","dependency_job_id":"d0295212-0769-4098-8cde-73a211926b8f","html_url":"https://github.com/Graph-COM/GSAT","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Graph-COM%2FGSAT","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Graph-COM%2FGSAT/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Graph-COM%2FGSAT/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub
/repositories/Graph-COM%2FGSAT/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/Graph-COM","download_url":"https://codeload.github.com/Graph-COM/GSAT/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":227830912,"owners_count":17826154,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","graph-neural-networks","interpretability","interpretable-machine-learning","pytorch","xai"],"created_at":"2024-12-03T01:12:09.242Z","updated_at":"2024-12-03T01:12:09.938Z","avatar_url":"https://github.com/Graph-COM.png","language":"Jupyter Notebook","funding_links":[],"categories":[],"sub_categories":[],"readme":"\u003ch1 align=\"center\"\u003eGraph Stochastic Attention (GSAT)\u003c/h1\u003e\n\u003cp align=\"center\"\u003e\n    \u003ca href=\"https://arxiv.org/abs/2201.12987\"\u003e\u003cimg src=\"https://img.shields.io/badge/-arXiv-grey?logo=gitbook\u0026logoColor=white\" alt=\"arXiv\"\u003e\u003c/a\u003e\n    \u003ca href=\"https://github.com/Graph-COM/GSAT\"\u003e\u003cimg src=\"https://img.shields.io/badge/-Github-grey?logo=github\" alt=\"Github\"\u003e\u003c/a\u003e\n    \u003ca href=\"https://proceedings.mlr.press/v162/miao22a.html\"\u003e \u003cimg alt=\"License\" src=\"https://img.shields.io/static/v1?label=Pub\u0026message=ICML%2722\u0026color=blue\"\u003e\u003c/a\u003e\n    \u003ca href=\"https://colab.research.google.com/drive/1t0_4BxEJ0XncyYvn_VyEQhxwNMvtSUNx?usp=sharing\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" 
alt=\"Colab\"\u003e\u003c/a\u003e\n    \u003c!-- \u003ca href=\"https://github.com/Graph-COM/GSAT/blob/main/LICENSE\"\u003e \u003cimg alt=\"License\" src=\"https://img.shields.io/github/license/Graph-Com/GSAT?color=blue\"\u003e\u003c/a\u003e\n    \u003ca href=\"https://icml.cc/virtual/2022/spotlight/17430\"\u003e \u003cimg src=\"https://img.shields.io/badge/Video-grey?logo=Kuaishou\u0026logoColor=white\" alt=\"Video\"\u003e\u003c/a\u003e\n    \u003ca href=\"https://icml.cc/media/icml-2022/Slides/17430.pdf\"\u003e \u003cimg src=\"https://img.shields.io/badge/Slides-grey?\u0026logo=MicrosoftPowerPoint\u0026logoColor=white\" alt=\"Slides\"\u003e\u003c/a\u003e\n    \u003ca href=\"https://icml.cc/media/PosterPDFs/ICML%202022/a8acc28734d4fe90ea24353d901ae678.png\"\u003e \u003cimg src=\"https://img.shields.io/badge/Poster-grey?logo=airplayvideo\u0026logoColor=white\" alt=\"Poster\"\u003e\u003c/a\u003e --\u003e\n\u003c/p\u003e\n\n**Blogs ([English](https://towardsdatascience.com/graph-machine-learning-icml-2022-252f39865c70#be75:~:text=and%20inductive%20settings.-,%E2%9E%A1%EF%B8%8F%20Miao%20et%20al,-take%20another%20perspective) - [Chinese](https://mp.weixin.qq.com/s/aP-XBqFLV0x8h9rtOKU_yg))** |\n**[Slides](https://icml.cc/media/icml-2022/Slides/17430.pdf)** |\n**[Poster](https://icml.cc/media/PosterPDFs/ICML%202022/a8acc28734d4fe90ea24353d901ae678.png)**\n\nThis repository contains the official implementation of GSAT as described in the paper: [Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism](https://arxiv.org/abs/2201.12987) (ICML 2022) by Siqi Miao, Mia Liu, and Pan Li.\n\n## News\n- Mar. 15, 2023: Check out GSAT on the [GOOD](https://github.com/divelab/GOOD) benchmark, with the leaderboard [here](https://good.readthedocs.io/en/latest/leaderboard.html). GSAT (again) achieves multiple SOTA results on out-of-distribution generalization on this recent benchmark, while being highly interpretable!\n- Jan. 
21, 2023: Check out our latest paper [Learnable Randomness Injection (LRI)](https://openreview.net/forum?id=6u7mf9s2A9) with code [here](https://github.com/Graph-COM/LRI), which was recently accepted to ICLR 2023! In LRI, we further generalize the idea of GSAT and propose four datasets with ground-truth interpretation labels from real-world scientific applications (instead of synthetic motif datasets to evaluate interpretability!).\n- Nov. 16, 2022: A bug was reported in the code that averages edge attention weights for undirected graphs, as pointed out in this [issue](https://github.com/Graph-COM/GSAT/issues/5). We have fixed this bug in the latest version of the code in this [PR](https://github.com/Graph-COM/GSAT/pull/8).\n\n\n## Introduction\nCommonly used attention mechanisms have been shown to be unable to provide reliable interpretation for graph neural networks (GNNs). So, most previous works focus on developing post-hoc interpretation methods for GNNs.\n\nThis work shows that post-hoc methods suffer from several fundamental issues, such as underfitting the subgraph $G_S$ and overfitting the original input graph $G$. Thus, they are essentially good at checking feature sensitivity but can hardly provide trustworthy interpretation for GNNs if the goal is to extract effective patterns from the data (which should have been the most interesting goal).\n\nThis work addresses those issues by designing an inherently interpretable model. The key idea is to jointly train both the predictor and the explainer with a carefully designed **Graph Stochastic Attention (GSAT)** mechanism. Under certain assumptions, GSAT can provide guaranteed out-of-distribution generalizability and guaranteed inherent interpretability, which ensures that GSAT doesn't suffer from those issues. Fig. 
1 shows the architecture of GSAT.\n\n\u003cp align=\"center\"\u003e\u003cimg src=\"./data/arch.png\" width=85% height=85%\u003e\u003c/p\u003e\n\u003cp align=\"center\"\u003e\u003cem\u003eFigure 1.\u003c/em\u003e The architecture of GSAT.\u003c/p\u003e\n\n## Rationale of GSAT\nThe rationale of GSAT is to inject stochasticity when learning attention. For example, Fig. 2 shows a task to detect whether a five-node cycle exists in the input graph, so edges with pink end nodes are the critical edges for this task. The main idea of GSAT is the following:\n1. **\u003cins\u003eA regularizer\u003c/ins\u003e** is used to encourage high randomness, i.e. low sampling probability, say `0.7`.\n    - In this case, every critical edge may be dropped `30%` of the time.\n    - Whenever a critical edge is dropped, it may flip model predictions and incur a huge classification loss.\n2. Driven by the **\u003cins\u003eclassification loss\u003c/ins\u003e**, critical edges learn to have low randomness, i.e. high sampling probability.\n    - With high sampling probabilities (e.g. `1.0`), the critical edges are more likely to be kept during training.\n3. The part with **\u003cins\u003eless randomness\u003c/ins\u003e** corresponds to the underlying critical data patterns captured by GSAT.\n\nTo implement the above mechanism, a proper regularizer is needed. As the goal is to control randomness, from an information-theoretic point of view this amounts to controlling the amount of information in $G$. So, the information bottleneck (IB) principle can be utilized, which helps to provide guaranteed OOD generalizability and interpretability, see `Theorem 4.1` in the paper.\n\n\u003cp align=\"center\"\u003e\u003cimg src=\"./data/rationale.png\" width=85% height=85%\u003e\u003c/p\u003e\n\u003cp align=\"center\"\u003e\u003cem\u003eFigure 2.\u003c/em\u003e The rationale of GSAT.\u003c/p\u003e\n\n## Installation\nWe have tested our code on `Python 3.9` with `PyTorch 1.10.0`, `PyG 2.0.3` and `CUDA 11.3`. 
Please follow the steps below to create a virtual environment and install the required packages.\n\nClone the repository:\n```\ngit clone https://github.com/Graph-COM/GSAT.git\ncd GSAT\n```\n\nCreate a virtual environment:\n```\nconda create --name gsat python=3.9 -y\nconda activate gsat\n```\n\nInstall dependencies:\n```\nconda install -y pytorch==1.10.0 torchvision cudatoolkit=11.3 -c pytorch\npip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\npip install -r requirements.txt\n```\n\nIn case a lower CUDA version is required, please use the following commands to install dependencies:\n```\nconda install -y pytorch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 cudatoolkit=10.2 -c pytorch\npip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.9.0+cu102.html\npip install -r requirements.txt\n```\n\n\n## Run Examples\nWe provide examples with minimal code to run GSAT in `./example/example.ipynb`. We have tested the provided examples on `Ba-2Motifs (GIN)`, `Mutag (GIN)` and `OGBG-Molhiv (PNA)`. Note that to implement GSAT*, one needs to load a pre-trained model first in the provided example. Also try \u003ca href=\"https://colab.research.google.com/drive/1t0_4BxEJ0XncyYvn_VyEQhxwNMvtSUNx?usp=sharing\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Colab\"\u003e\u003c/a\u003e to play with `example.ipynb` in Colab.\n\nThe example should also run on other datasets, but some hard-coded hyperparameters might need to be changed accordingly; see `./src/configs` for all hyperparameter settings. To directly reproduce results for other datasets, please follow the instructions in the following section.\n\n## Reproduce Results\nWe provide the source code to reproduce the results in our paper. 
The results of GSAT can be reproduced by running `run_gsat.py`. To reproduce GSAT*, one needs to first change the configuration file accordingly (`from_scratch: false`).\n\nTo train GSAT or GSAT*:\n```\ncd ./src\npython run_gsat.py --dataset [dataset_name] --backbone [model_name] --cuda [GPU_id]\n```\n\n`dataset_name` can be chosen from `ba_2motifs`, `mutag`, `mnist`, `Graph-SST2`, `spmotif_0.5`, `spmotif_0.7`, `spmotif_0.9`, `ogbg_molhiv`, `ogbg_moltox21`, `ogbg_molbace`, `ogbg_molbbbp`, `ogbg_molclintox`, `ogbg_molsider`.\n\n`model_name` can be chosen from `GIN`, `PNA`.\n\n`GPU_id` is the id of the GPU to use. To use CPU, please set it to `-1`.\n\n\n### Training Logs\nStandard output provides basic training logs, while more detailed logs and interpretation visualizations can be found on tensorboard:\n```\ntensorboard --logdir=./data/[dataset_name]/logs\n```\n\n### Hyperparameter Settings\nAll settings can be found in `./src/configs`.\n\n\n## Instructions on Acquiring Datasets\n- Ba_2Motifs\n    - Raw data files can be downloaded automatically, provided by [PGExplainer](https://arxiv.org/abs/2011.04573) and [DIG](https://github.com/divelab/DIG).\n\n- Spurious-Motif\n    - Raw data files can be generated automatically, provided by [DIR](https://openreview.net/forum?id=hGXij5rfiHw).\n\n- OGBG-Mol\n    - Raw data files can be downloaded automatically, provided by [OGBG](https://ogb.stanford.edu/).\n\n- Mutag\n    - Raw data files need to be downloaded [here](https://github.com/flyingdoog/PGExplainer/tree/master/dataset), provided by [PGExplainer](https://arxiv.org/abs/2011.04573).\n    - Unzip `Mutagenicity.zip` and `Mutagenicity.pkl.zip`.\n    - Put the raw data files in `./data/mutag/raw`.\n\n- Graph-SST2\n    - Raw data files need to be downloaded [here](https://drive.google.com/drive/folders/1dt0aGMBvCEUYzaG00TYu1D03GPO7305z), provided by [DIG](https://github.com/divelab/DIG).\n    - Unzip the downloaded `Graph-SST2.zip`.\n    - Put the raw data files in 
`./data/Graph-SST2/raw`.\n\n- MNIST-75sp\n    - Raw data files need to be generated following the instructions [here](https://github.com/bknyaz/graph_attention_pool/blob/master/scripts/mnist_75sp.sh).\n    - Put the generated files in `./data/mnist/raw`.\n\n## FAQ\n#### Does GSAT encourage sparsity?\nNo, GSAT doesn't encourage generating sparse subgraphs. We find `r = 0.7` (Eq.(9) in our paper) can generally work well for all datasets in our experiments, which means during training roughly `70%` of edges will be kept (still a large fraction). This is because GSAT doesn't try to provide interpretability by finding a small/sparse subgraph of the original input graph, which is what previous works normally do and which hurts performance significantly for inherently interpretable models (as shown in Fig. 7 in the paper). By contrast, GSAT provides interpretability by pushing the critical edges to have relatively lower stochasticity during training.\n\n#### How to tune the hyperparameters of GSAT?\nWe recommend tuning `r` in `{0.5, 0.7}` and `info_loss_coef` in `{1.0, 0.1, 0.01}` based on validation classification performance. `r = 0.7` and `info_loss_coef = 1.0` are a good starting point.\nNote that in practice we decay the value of `r` gradually during training from `0.9` to the chosen value. Given our empirical observation, the classification performance of GSAT should always be no worse than that yielded by ERM (Empirical Risk Minimization) training, when its hyperparameters are tuned properly.\n\n#### `p` or `α` to implement Eq. (9)?\nRecall in Fig. 1, `p` is the probability of keeping (sampling) an edge, while `α` is the sampled result from `Bern(p)`. In our provided implementation, as an empirical choice, `α` is used to implement Eq.(9) (the Gumbel-softmax trick makes `α` essentially continuous in practice). We find that when `α` is used it may provide more regularization and make the model more robust to hyperparameters. 
Nonetheless, using `p` can achieve the same performance.\n\n#### How to sample $G_S$?\nIn practice, we don't yield $G_S$ by doing $\\alpha \\odot A$ in Fig. 1, because based on the gumbel-softmax trick it's non-trivial to make this operation differentiable for message-passing-based neural networks (MPNNs). Instead, the learned attention will act on the message of the corresponding edge. Once the message of an edge is dropped, one can (roughly) believe that the corresponding edge is dropped in MPNNs, and this is like an approximation of $\\alpha \\odot A$.\n\n\u003c!-- #### Can you show an example of how GSAT works?\nBelow we show an example from the `ba_2motifs` dataset, which is to distinguish five-node cycle motifs (left) and house motifs (right).\nTo make good predictions (minimize the cross-entropy loss), GSAT will push the attention weights of those critical edges to be relatively large (ideally close to `1`). Otherwise, those critical edges may be dropped too frequently and thus result in a large cross-entropy loss. Meanwhile, to minimize the regularization loss (the KL divergence term in Eq.(9) of the paper), GSAT will push the attention weights of other non-critical edges to be close to `r`, which is set to be `0.7` in the example. This mechanism of injecting stochasticity makes the learned attention weights from GSAT directly interpretable, since the more critical an edge is, the larger its attention weight will be (the less likely it can be dropped). Note that `ba_2motifs` satisfies our Thm. 
4.1 with no noise, and GSAT achieves perfect interpretation performance on it.\n\n\u003cp align=\"center\"\u003e\u003cimg src=\"./data/example.png\" width=85% height=85%\u003e\u003c/p\u003e\n\u003cp align=\"center\"\u003e\u003cem\u003eFigure 2.\u003c/em\u003e An example of the learned attention weights.\u003c/p\u003e --\u003e\n\n\n## Reference\n\nIf you find our paper and repo useful, please cite our paper:\n```bibtex\n@article{miao2022interpretable,\n  title       = {Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism},\n  author      = {Miao, Siqi and Liu, Mia and Li, Pan},\n  journal     = {International Conference on Machine Learning},\n  year        = {2022}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgraph-com%2Fgsat","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fgraph-com%2Fgsat","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgraph-com%2Fgsat/lists"}