{"id":21280616,"url":"https://github.com/borgwardtlab/sat","last_synced_at":"2025-04-09T18:18:58.219Z","repository":{"id":48157631,"uuid":"497849629","full_name":"BorgwardtLab/SAT","owner":"BorgwardtLab","description":"Official Pytorch code for Structure-Aware Transformer.","archived":false,"fork":false,"pushed_at":"2023-02-21T18:49:42.000Z","size":2711,"stargazers_count":257,"open_issues_count":2,"forks_count":40,"subscribers_count":5,"default_branch":"main","last_synced_at":"2025-04-09T18:18:54.120Z","etag":null,"topics":["graph-neural-networks","graph-representation-learning","graph-transformer","icml-2022"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"bsd-3-clause","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/BorgwardtLab.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-05-30T08:13:58.000Z","updated_at":"2025-04-01T11:46:02.000Z","dependencies_parsed_at":"2024-11-21T13:02:45.313Z","dependency_job_id":null,"html_url":"https://github.com/BorgwardtLab/SAT","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BorgwardtLab%2FSAT","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BorgwardtLab%2FSAT/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BorgwardtLab%2FSAT/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BorgwardtLab%2FSAT/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/Bo
rgwardtLab","download_url":"https://codeload.github.com/BorgwardtLab/SAT/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248085325,"owners_count":21045139,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["graph-neural-networks","graph-representation-learning","graph-transformer","icml-2022"],"created_at":"2024-11-21T10:37:20.499Z","updated_at":"2025-04-09T18:18:58.191Z","avatar_url":"https://github.com/BorgwardtLab.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Structure-Aware Transformer\n\n__Updates: We have added the script for [model visualization](#model-visualization) (Figure 4 in our paper)!__\n\nThis repository implements, in PyTorch Geometric, the Structure-Aware Transformer (SAT) described in the following paper:\n\n\u003eDexiong Chen*, Leslie O'Bray*, and Karsten Borgwardt.\n[Structure-Aware Transformer for Graph Representation Learning][1]. ICML 2022.\n\u003cbr/\u003e*Equal contribution\n\n**TL;DR**: A class of simple and flexible graph transformers built upon a new self-attention mechanism, which incorporates structural information into the original self-attention by extracting a subgraph representation rooted at each node before computing the attention. 
Our structure-aware framework can leverage any existing GNN to extract the subgraph representation and systematically improve the performance relative to the base GNN.\n\n## Citation\n\nPlease use the following to cite our work:\n\n```bibtex\n@InProceedings{Chen22a,\n\tauthor = {Dexiong Chen and Leslie O'Bray and Karsten Borgwardt},\n\ttitle = {Structure-Aware Transformer for Graph Representation Learning},\n\tyear = {2022},\n\tbooktitle = {Proceedings of the 39th International Conference on Machine Learning~(ICML)},\n\tseries = {Proceedings of Machine Learning Research}\n}\n```\n\n\n## A short description of SAT\n\n### SAT vs the vanilla Transformer\n\n![SAT vs Transformer](images/sat_vs_transformer.png)\n\nThe figure above compares the SAT architecture with the vanilla Transformer architecture. We make the self-attention calculation in each transformer layer *structure-aware* by leveraging structure-aware node embeddings. We generate these embeddings using a structure extractor (for example, any GNN) on the $k$-hop subgraphs centered at each node of interest. Then, the updated node embeddings are used to compute the query ($\\mathbf{Q}$) and key ($\\mathbf{K}$) matrices. We provide example structure extractors in the next figure.\n\n### Example structure extractors\n\n![Overview figure](images/structure_extractor.png)\n\nThe figure above shows the two example structure extractors used in our paper ($k$-subtree and $k$-subgraph). In the $k$-subtree GNN extractor, structure-aware node representations are generated by applying a GNN to the $k$-hop subtree centered at each node (here, $k=1$). Explicitly extracting the subtree as an initial step is not strictly necessary, since a $k$-layer GNN by nature uses the $k$-hop subtree of each node to generate updated node embeddings. 
For the $k$-subgraph GNN extractor, we first extract the $k$-hop subgraph centered at each node and then apply a GNN to each subgraph, so that each node representation is computed from the full subgraph information. The updated node embeddings are then used to compute the query ($\\mathbf{Q}$) and key ($\\mathbf{K}$) matrices shown in the first figure.\n\n### A quick-start example\n\nBelow is a quick-start example on the ZINC dataset; see `./experiments/train_zinc.py` for more details.\n\n\u003cdetails\u003e\u003csummary\u003eclick to see the example:\u003c/summary\u003e\n\n```python\nimport torch\nfrom torch_geometric import datasets\nfrom torch_geometric.loader import DataLoader\nfrom sat.data import GraphDataset\nfrom sat import GraphTransformer\n\n# Load the ZINC dataset using our wrapper GraphDataset,\n# which automatically creates the fully connected graph.\n# For datasets with large graphs, we recommend setting return_complete_index=False,\n# which leads to faster computation.\ndset = datasets.ZINC('./datasets/ZINC', subset=True, split='train')\ndset = GraphDataset(dset)\n\n# Create a PyG data loader\ntrain_loader = DataLoader(dset, batch_size=16, shuffle=True)\n\n# Create a SAT model\ndim_hidden = 16\ngnn_type = 'gcn' # use GCN as the structure extractor\nk_hop = 2 # use a 2-layer GCN\n\nmodel = GraphTransformer(\n    in_size=28, # number of node labels for ZINC\n    num_class=1, # regression task\n    d_model=dim_hidden,\n    dim_feedforward=2 * dim_hidden,\n    num_layers=2,\n    batch_norm=True,\n    gnn_type=gnn_type, # use GCN as the structure extractor\n    use_edge_attr=True,\n    num_edge_features=4, # number of edge labels\n    edge_dim=dim_hidden,\n    k_hop=k_hop,\n    se='gnn', # we use the k-subtree structure extractor\n    global_pool='add'\n)\n\nfor data in train_loader:\n    output = model(data) # batch_size x 1\n    break\n```\n\u003c/details\u003e\n\n## Installation\n\nThe dependencies are managed by 
[miniconda][2]\n\n```\npython=3.9\nnumpy\nscipy\npytorch=1.9.1\npytorch-geometric=2.0.2\neinops\nogb\n```\n\nOnce you have activated the environment and installed all dependencies, run:\n\n```bash\nsource s\n```\n\nDatasets will be downloaded via the PyTorch Geometric and OGB packages.\n\n## Train SAT on graph and node prediction datasets\n\nAll our experiment scripts are in the `experiments` folder. To get started, run `source s` and then `cd experiments`. The hyperparameters used below are the ones we selected as optimal.\n\n#### Graph regression on ZINC dataset\n\nTrain a k-subtree SAT with PNA:\n```bash\npython train_zinc.py --abs-pe rw --se gnn --gnn-type pna2 --dropout 0.3 --k-hop 3 --use-edge-attr\n```\n\nTrain a k-subgraph SAT with PNA:\n```bash\npython train_zinc.py --abs-pe rw --se khopgnn --gnn-type pna2 --dropout 0.2 --k-hop 3 --use-edge-attr\n```\n\n#### Node classification on PATTERN and CLUSTER datasets\n\nTrain a k-subtree SAT on PATTERN:\n```bash\npython train_SBMs.py --dataset PATTERN --weight-class --abs-pe rw --abs-pe-dim 7 --se gnn --gnn-type pna3 --dropout 0.2 --k-hop 3 --num-layers 6 --lr 0.0003\n```\n\nand on CLUSTER:\n```bash\npython train_SBMs.py --dataset CLUSTER --weight-class --abs-pe rw --abs-pe-dim 3 --se gnn --gnn-type pna2 --dropout 0.4 --k-hop 3 --num-layers 16 --dim-hidden 48 --lr 0.0005\n```\n\n#### Graph classification on OGB datasets\n\n`--gnn-type` can be `gcn`, `gine` or `pna`, where `pna` obtains the best performance.\n\n```bash\n# Train SAT on OGBG-PPA\npython train_ppa.py --gnn-type gcn --use-edge-attr\n```\n\n```bash\n# Train SAT on OGBG-CODE2\npython train_code2.py --gnn-type gcn --use-edge-attr\n```\n\n## Model visualization\n\nHere we show how to visualize the attention weights of the [CLS] node learned by SAT and by a vanilla Transformer with random walk positional encoding. We provide pre-trained models on the Mutagenicity dataset. 
To visualize the pre-trained models, install the [`networkx`](https://networkx.org/) and `matplotlib` packages, then run:\n\n```bash\npython model_visu.py --graph-idx 2003\n```\n\nThis will generate the following image, which is the same as Figure 4 in our paper:\n\n![Model_interpretation](images/graph2003.png)\n\n\n[1]: https://arxiv.org/abs/2202.03036\n[2]: https://conda.io/miniconda.html\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fborgwardtlab%2Fsat","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fborgwardtlab%2Fsat","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fborgwardtlab%2Fsat/lists"}