# Explainability for Vision Transformers (in PyTorch)

This repository implements methods for explainability in Vision Transformers.

See also https://jacobgil.github.io/deeplearning/vision-transformer-explainability

## Currently implemented
- Attention Rollout.
- Gradient Attention Rollout for class-specific explainability.
  *This is our attempt to further build upon and improve Attention Rollout.*
- Attention Flow: TBD, work in progress.

Includes some tweaks and tricks to get it working:
- Different attention head fusion methods.
- Removing the lowest attentions.

## Usage

- From code:
``` python
import torch
from vit_grad_rollout import VITAttentionGradRollout

model = torch.hub.load('facebookresearch/deit:main',
    'deit_tiny_patch16_224', pretrained=True)
model.eval()
input_tensor = torch.randn(1, 3, 224, 224)  # placeholder for a preprocessed 224x224 image
# head_fusion applies to plain Attention Rollout; gradient rollout averages the heads
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9)
mask = grad_rollout(input_tensor, category_index=243)
```

- From the command line:

```
python vit_explain.py --image_path <image path> --head_fusion <mean, min or max> --discard_ratio <number between 0 and 1> --category_index <category_index>
```
If `--category_index` isn't specified, Attention Rollout is used; otherwise, Gradient Attention Rollout is used.

By default, this uses the DeiT 'Tiny' model from [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877), hosted on Torch Hub.

## Where did the Transformer pay attention in this image?

| Image | Vanilla Attention Rollout | With discard_ratio + max fusion |
| ------------------------- | ------------------------- | ------------------------- |
| ![](examples/both.png) | ![](examples/both_attention_rollout_0.000_mean.png) | ![](examples/both_attention_rollout_0.990_max.png) |
| ![](examples/plane.png) | ![](examples/plane_attention_rollout_0.000_mean.png) | ![](examples/plane_attention_rollout_0.900_max.png) |
| ![](examples/dogbird.png) | ![](examples/dogbird_attention_rollout_0.000_mean.png) | ![](examples/dogbird_attention_rollout_0.900_max.png) |
| ![](examples/plane2.png) | ![](examples/plane2_attention_rollout_0.000_mean.png) | ![](examples/plane2_attention_rollout_0.900_max.png) |

## Gradient Attention Rollout for class-specific explainability

The attention that flows through the transformer carries information belonging to different classes.
Attention Rollout lets us see which locations the network paid attention to,
but it tells us nothing about whether the network ended up using those locations for the final classification.

We can multiply the attention by the gradient of the target class output, average across the attention heads, and then mask out negative values, keeping only the attention that contributes to the target category (or categories).

### Where does the Transformer see a Dog (category 243) and a Cat (category 282)?
![](examples/both_grad_rollout_243_0.900_max.png) ![](examples/both_grad_rollout_282_0.900_max.png)

### Where does the Transformer see a Musket dog (category 161) and a Parrot (category 87)?
![](examples/dogbird_grad_rollout_161_0.900_max.png) ![](examples/dogbird_grad_rollout_87_0.900_max.png)

## Tricks and Tweaks to get this working

### Filtering the lowest attentions in every layer

`--discard_ratio <value between 0 and 1>`

Removes noise by zeroing out that fraction of the lowest attention values in every layer, keeping only the strongest attentions.

Results for different values:

![](examples/both_discard_ratio.gif) ![](examples/plane_discard_ratio.gif)

### Different Attention Head Fusions

The Attention Rollout method suggests taking the average attention across the attention heads,
but empirically, taking the minimum value, or the maximum value combined with `--discard_ratio`, seems to work better.

`--head_fusion <mean, min or max>`

| Image | Mean Fusion | Min Fusion |
| ------------------------- | ------------------------- | ------------------------- |
| ![](examples/both.png) | ![](examples/both_attention_rollout_0.000_mean.png) | ![](examples/both_attention_rollout_0.000_min.png) |

## References
- [Quantifying Attention Flow in Transformers](https://arxiv.org/abs/2005.00928)
- [timm: a great collection of models in PyTorch](https://github.com/rwightman/pytorch-image-models),
and especially [the vision transformer implementation](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)
- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)
- Credit to https://github.com/jeonsworld/ViT-pytorch for being a good starting point.

## Requirements
`pip install timm`
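The Attention Rollout computation described in this README (fusing the attention heads, discarding the lowest attentions, and multiplying the per-layer attention matrices) can be sketched in a few lines that need only NumPy. This is a minimal illustration under stated assumptions, not the repository's PyTorch implementation: `fuse_heads` and `attention_rollout` are hypothetical names, the per-layer attentions are assumed to be already extracted as `(heads, tokens, tokens)` arrays with the class token at index 0, and residual connections are modelled by averaging with the identity (as in the Attention Rollout paper) followed by a row re-normalization, which is assumed here because min/max fusion and discarding break the row sums.

```python
import numpy as np


def fuse_heads(attention, head_fusion="mean"):
    """Hypothetical helper: fuse (heads, tokens, tokens) into (tokens, tokens)."""
    if head_fusion == "mean":
        return attention.mean(axis=0)
    if head_fusion == "max":
        return attention.max(axis=0)
    if head_fusion == "min":
        return attention.min(axis=0)
    raise ValueError(f"unknown head_fusion: {head_fusion!r}")


def attention_rollout(attentions, head_fusion="mean", discard_ratio=0.0):
    """Hypothetical helper, not the repository's API.

    attentions: list with one (heads, tokens, tokens) array per layer;
    token 0 is the class token, tokens 1.. are the image patches.
    Returns a (width, width) mask scaled so that its maximum is 1.
    """
    num_tokens = attentions[0].shape[-1]
    eye = np.eye(num_tokens)
    result = eye.copy()
    for attention in attentions:
        # flatten() returns a copy, so the caller's arrays are never modified.
        fused = fuse_heads(attention, head_fusion).flatten()
        # Zero out the lowest `discard_ratio` fraction of attention values.
        num_discard = int(fused.size * discard_ratio)
        fused[np.argsort(fused)[:num_discard]] = 0.0
        fused = fused.reshape(num_tokens, num_tokens)
        # Model the residual connection, then re-normalize every row.
        a = (fused + eye) / 2
        a = a / a.sum(axis=-1, keepdims=True)
        # Compose with the layers processed so far.
        result = a @ result
    # Attention flowing from the class token to each image patch.
    mask = result[0, 1:]
    width = int(round(np.sqrt(mask.size)))
    mask = mask.reshape(width, width)
    return mask / mask.max()
```

For DeiT-Tiny on a 224x224 input, each layer's attention would have shape `(3, 197, 197)` (3 heads, one class token plus 14x14 patches), so the sketch would return a 14x14 mask that can then be resized to overlay the image.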
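The gradient-weighting step described in the Gradient Attention Rollout section (multiply the attention by the gradient of the target class output, average across heads, mask out negative values) can be sketched as a drop-in replacement for the mean/min/max head fusion. This is a NumPy sketch under assumptions: `grad_weighted_fusion` is a hypothetical name, and the gradient of the target class score with respect to each layer's attention is assumed to have been captured already (for example with PyTorch backward hooks, not shown here).

```python
import numpy as np


def grad_weighted_fusion(attention, gradient):
    """Hypothetical helper: fuse heads for Gradient Attention Rollout.

    attention, gradient: (heads, tokens, tokens) arrays, where `gradient` is the
    gradient of the target class score with respect to `attention`.
    Returns a (tokens, tokens) matrix with negative values set to zero.
    """
    # Weight each attention value by its gradient and average across heads.
    fused = (attention * gradient).mean(axis=0)
    # Keep only the attention that pushes the target class score up.
    return np.maximum(fused, 0.0)


attention = np.full((2, 2, 2), 0.5)
gradient = np.array([[[2.0, -2.0], [4.0, 0.0]],
                     [[6.0, -2.0], [0.0, 0.0]]])
fused = grad_weighted_fusion(attention, gradient)  # values [[2, 0], [1, 0]]
```

In a full rollout, this fused matrix would then go through the same steps as plain Attention Rollout: add the identity for the residual connection, re-normalize the rows, and multiply across layers.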