{"id":13738041,"url":"https://github.com/xxxnell/how-do-vits-work","last_synced_at":"2026-01-16T06:03:09.468Z","repository":{"id":41331086,"uuid":"345953439","full_name":"xxxnell/how-do-vits-work","owner":"xxxnell","description":"(ICLR 2022 Spotlight) Official PyTorch implementation of \"How Do Vision Transformers Work?\"","archived":false,"fork":false,"pushed_at":"2022-07-14T18:53:23.000Z","size":19155,"stargazers_count":815,"open_issues_count":5,"forks_count":79,"subscribers_count":7,"default_branch":"transformer","last_synced_at":"2025-05-08T15:42:58.108Z","etag":null,"topics":["loss-landscape","pytorch","self-attention","transformer","vision-transformer"],"latest_commit_sha":null,"homepage":"https://arxiv.org/abs/2202.06709","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/xxxnell.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2021-03-09T09:33:48.000Z","updated_at":"2025-05-04T21:25:21.000Z","dependencies_parsed_at":"2022-08-10T01:54:26.799Z","dependency_job_id":null,"html_url":"https://github.com/xxxnell/how-do-vits-work","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/xxxnell/how-do-vits-work","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/xxxnell%2Fhow-do-vits-work","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/xxxnell%2Fhow-do-vits-work/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/xxxnell%2Fhow-do-vits-work/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/xxxnell%2Fhow-do-vits-work/manifests","o
wner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/xxxnell","download_url":"https://codeload.github.com/xxxnell/how-do-vits-work/tar.gz/refs/heads/transformer","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/xxxnell%2Fhow-do-vits-work/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":286080680,"owners_count":28477598,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2026-01-16T03:13:13.607Z","status":"ssl_error","status_checked_at":"2026-01-16T03:11:47.863Z","response_time":107,"last_error":"SSL_read: unexpected eof while reading","robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":false,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["loss-landscape","pytorch","self-attention","transformer","vision-transformer"],"created_at":"2024-08-03T03:02:09.597Z","updated_at":"2026-01-16T06:03:09.424Z","avatar_url":"https://github.com/xxxnell.png","language":"Python","funding_links":[],"categories":["Python","其他_机器视觉"],"sub_categories":["网络服务_其他"],"readme":"\n\n# How Do Vision Transformers Work?\n\n[[arxiv](https://arxiv.org/abs/2202.06709), [poster](https://github.com/xxxnell/how-do-vits-work-storage/blob/master/resources/how_do_vits_work_poster_iclr2022.pdf), [slides](https://github.com/xxxnell/how-do-vits-work-storage/blob/master/resources/how_do_vits_work_talk.pdf)]\n\nThis repository provides a PyTorch implementation of [\"How Do Vision Transformers Work? 
(ICLR 2022 Spotlight)\"](https://arxiv.org/abs/2202.06709). In the paper, we show that the success of multi-head self-attentions (MSAs) for computer vision ***does NOT lie in their weak inductive bias and the capturing of long-range dependencies***. MSAs are not merely generalized Convs, but rather generalized spatial smoothings that *complement* Convs.\nIn particular, we address the following three key questions about MSAs and Vision Transformers (ViTs): \n\n***Q1. What properties of MSAs do we need to better optimize NNs?***  \n\nA1. MSAs have their pros and cons. MSAs improve NNs by flattening the loss landscapes. A key feature is their data specificity (data dependency), not long-range dependency. On the other hand, ViTs suffer from non-convex losses.\n\n\n***Q2. Do MSAs act like Convs?***  \n\nA2. MSAs and Convs exhibit opposite behaviors; e.g., MSAs are low-pass filters, but Convs are high-pass filters. This suggests that MSAs are shape-biased, whereas Convs are texture-biased. Therefore, MSAs and Convs are complementary.\n\n\n***Q3. How can we harmonize MSAs with Convs?***  \n\nA3. MSAs at the end of a stage (not at the end of a model) significantly improve the accuracy. Based on this, we introduce *AlterNet* by replacing Convs at the end of a stage with MSAs. AlterNet outperforms CNNs not only in large data regimes but also in small data regimes.\n\n\n👇 Let's find the detailed answers below!\n\n\n### I. What Properties of MSAs Do We Need to Improve Optimization?\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"resources/vit/loss-landscape.png\" style=\"width:83%;\"\u003e\n\u003c/p\u003e\n\nMSAs improve not only accuracy but also generalization by flattening the loss landscapes (reducing the magnitude of Hessian eigenvalues). ***Such improvement is primarily attributable to their data specificity, NOT long-range dependency*** 😱 On the other hand, ViTs suffer from non-convex losses (negative Hessian eigenvalues). 
Their weak inductive bias and long-range dependency produce negative Hessian eigenvalues in small data regimes, and these non-convex points disrupt NN training. Large datasets and loss landscape smoothing methods alleviate this problem.\n\n\n### II. Do MSAs Act Like Convs?\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"resources/vit/fourier.png\" style=\"width:83%;\"\u003e\n\u003c/p\u003e\n\nMSAs and Convs exhibit opposite behaviors. Therefore, MSAs and Convs are complementary. For example, MSAs are low-pass filters, but Convs are high-pass filters. Likewise, Convs are vulnerable to high-frequency noise, whereas MSAs are vulnerable to low-frequency noise, which suggests that MSAs are shape-biased, whereas Convs are texture-biased. In addition, Convs transform feature maps and MSAs aggregate transformed feature map predictions. Thus, it is effective to place MSAs after Convs.\n\n\n### III. How Can We Harmonize MSAs With Convs?\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"resources/vit/architecture.png\" style=\"width:83%;\"\u003e\n\u003c/p\u003e\n\nMulti-stage neural networks behave like a series connection of small individual models. In addition, MSAs at the end of a stage (not the end of a model) play a key role in prediction. Considering these insights, we propose design rules to harmonize MSAs with Convs. NN stages using this design pattern consist of a number of CNN blocks and one (or a few) MSA blocks. The design pattern naturally leads to the structure of the canonical Transformer, which has one MLP block for one MSA block.\n\nBased on these design rules, we introduce AlterNet ([code](https://github.com/xxxnell/how-do-vits-work/blob/transformer/models/alternet.py)) by replacing Conv blocks at the end of a stage with MSA blocks. ***Surprisingly, AlterNet outperforms CNNs not only in large data regimes but also in small data regimes***, e.g., CIFAR. This contrasts with canonical ViTs, models that perform poorly on small amounts of data. 
For more details, see the [\"How to Apply MSA to Your Own Model\"](#how-to-apply-msa-to-your-own-model) section below.\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"resources/vit/summary.png\" style=\"width:70%;\"\u003e\n\u003c/p\u003e\n\nBut why do Vision Transformers work that way? Our recent paper, [\"Blurs Behave Like Ensembles: Spatial Smoothings to Improve Accuracy, Uncertainty, and Robustness (ICML 2022)\"](https://arxiv.org/abs/2105.12639) ([code and summary](https://github.com/xxxnell/spatial-smoothing) :octocat:, [poster](https://github.com/xxxnell/spatial-smoothing-storage/blob/master/resources/blurs_behave_like_ensembles_poster_icml2022.pdf)), shows that even a simple (non-trainable) 2 ✕ 2 box blur filter has the same properties. Spatial smoothings improve accuracy, uncertainty, and robustness simultaneously by *ensembling* spatially nearby feature maps of CNNs and flattening loss landscapes, and self-attentions can be regarded as trainable importance-weighted ensembles of feature maps. In conclusion, MSA is not simply a generalized Conv, but rather a generalized (trainable) blur filter that complements Conv. Please check it out!\n\n\n\n\n## Getting Started \n\nThe following packages are required:\n\n* pytorch\n* matplotlib\n* notebook\n* ipywidgets\n* timm\n* einops\n* tensorboard\n* seaborn (optional)\n\nWe mainly use the Docker image `pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime` for the code. \n\nSee [```classification.ipynb```](classification.ipynb) ([Colab notebook](https://colab.research.google.com/github/xxxnell/how-do-vits-work/blob/transformer/classification.ipynb)) for image classification. Run all cells to train and test models on CIFAR-10, CIFAR-100, and ImageNet. 
\n\n**Metrics.** We provide several metrics for measuring accuracy and uncertainty: Accuracy (Acc, ↑) and Acc for 90% certain results (Acc-90, ↑), negative log-likelihood (NLL, ↓), Expected Calibration Error (ECE, ↓), Intersection-over-Union (IoU, ↑) and IoU for certain results (IoU-90, ↑), Unconfidence (Unc-90, ↑), and Frequency for certain results (Freq-90, ↑). We also define a method to plot a reliability diagram for visualization.\n\n**Models.** We provide AlexNet, VGG, pre-activation VGG, ResNet, pre-activation ResNet, ResNeXt, WideResNet, ViT, PiT, Swin, MLP-Mixer, and Alter-ResNet by default. timm implementations can also be used.\n\n\n\n\n\n\n\u003cdetails\u003e\n\u003csummary\u003e\n  Pretrained models for CIFAR-100 are also provided: \u003ca href=\"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/resnet_50_cifar100_691cc9a9e4.pth.tar\"\u003eResNet-50\u003c/a\u003e, \u003ca href=\"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/vit_ti_cifar100_9857b21357.pth.tar\"\u003eViT-Ti\u003c/a\u003e, \u003ca href=\"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/pit_ti_cifar100_0645889efb.pth.tar\"\u003ePiT-Ti\u003c/a\u003e, and \u003ca href=\"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/swin_ti_cifar100_ec2894492b.pth.tar\"\u003eSwin-Ti\u003c/a\u003e. We recommend using \u003ca href=\"https://github.com/rwightman/pytorch-image-models\"\u003etimm\u003c/a\u003e for ImageNet-1K for the sake of simplicity (e.g., please refer to \u003ccode\u003e\u003ca href=\"https://github.com/xxxnell/how-do-vits-work/blob/transformer/fourier_analysis.ipynb\"\u003efourier_analysis.ipynb\u003c/a\u003e\u003c/code\u003e).\n  \u003c/summary\u003e\n\u003cbr/\u003e\nThe code below provides snippets for (a) loading pretrained models and (b) converting them into block sequences.\n  \u003cbr/\u003e\n\n```python\n# ResNet-50\nimport torch\nimport models\n  \n# a. 
download and load a pretrained model for CIFAR-100\nurl = \"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/resnet_50_cifar100_691cc9a9e4.pth.tar\"\npath = \"checkpoints/resnet_50_cifar100_691cc9a9e4.pth.tar\"\nmodels.download(url=url, path=path)\n\nname = \"resnet_50\"\nmodel = models.get_model(name, num_classes=100,  # timm does not provide a ResNet for CIFAR\n                         stem=False)  # fallback value of the original `model_args.get(\"stem\", False)`; `model_args` is not defined in this snippet\nmap_location = \"cuda\" if torch.cuda.is_available() else \"cpu\"\ncheckpoint = torch.load(path, map_location=map_location)\nmodel.load_state_dict(checkpoint[\"state_dict\"])\n\n# b. model → blocks. `blocks` is a sequence of blocks\nblocks = [\n    model.layer0,\n    *model.layer1,\n    *model.layer2,\n    *model.layer3,\n    *model.layer4,\n    model.classifier,\n]\n```\n\n```python\n# ViT-Ti\nimport copy\nimport timm\nimport torch\nimport torch.nn as nn\nimport models\n\n# a. download and load a pretrained model for CIFAR-100\nurl = \"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/vit_ti_cifar100_9857b21357.pth.tar\"\npath = \"checkpoints/vit_ti_cifar100_9857b21357.pth.tar\"\nmodels.download(url=url, path=path)\n\nmodel = timm.models.vision_transformer.VisionTransformer(\n    num_classes=100, img_size=32, patch_size=2,  # for CIFAR\n    embed_dim=192, depth=12, num_heads=3, qkv_bias=False,  # for ViT-Ti \n)\nmodel.name = \"vit_ti\"\nmodels.stats(model)\nmap_location = \"cuda\" if torch.cuda.is_available() else \"cpu\"\ncheckpoint = torch.load(path, map_location=map_location)\nmodel.load_state_dict(checkpoint[\"state_dict\"])\n\n\n# b. model → blocks. 
`blocks` is a sequence of blocks\n\nclass PatchEmbed(nn.Module):\n    def __init__(self, model):\n        super().__init__()\n        self.model = copy.deepcopy(model)\n        \n    def forward(self, x, **kwargs):\n        x = self.model.patch_embed(x)\n        cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)\n        x = torch.cat((cls_token, x), dim=1)\n        x = self.model.pos_drop(x + self.model.pos_embed)\n        return x\n\n\nclass Residual(nn.Module):\n    def __init__(self, *fn):\n        super().__init__()\n        self.fn = nn.Sequential(*fn)\n        \n    def forward(self, x, **kwargs):\n        return self.fn(x, **kwargs) + x\n    \n    \nclass Lambda(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n        \n    def forward(self, x):\n        return self.fn(x)\n\n\ndef flatten(xs_list):\n    return [x for xs in xs_list for x in xs]\n\n\n# model → blocks. `blocks` is a sequence of blocks\nblocks = [\n    PatchEmbed(model),\n    *flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)] \n              for b in model.blocks]),\n    nn.Sequential(model.norm, Lambda(lambda x: x[:, 0]), model.head),\n]\n```\n\n  \n```python\n# PiT-Ti\nimport copy\nimport math\nimport timm\nimport models\n\nimport torch\nimport torch.nn as nn\n\n# a. 
download and load a pretrained model for CIFAR-100\nurl = \"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/pit_ti_cifar100_0645889efb.pth.tar\"\npath = \"checkpoints/pit_ti_cifar100_0645889efb.pth.tar\"\nmodels.download(url=url, path=path)\n\nmodel = timm.models.pit.PoolingVisionTransformer(\n    num_classes=100, img_size=32, patch_size=2, stride=1,  # for CIFAR-100\n    base_dims=[32, 32, 32], depth=[2, 6, 4], heads=[2, 4, 8], mlp_ratio=4,  # for PiT-Ti\n)\nmodel.name = \"pit_ti\"\nmodels.stats(model)\nmap_location = \"cuda\" if torch.cuda.is_available() else \"cpu\"\ncheckpoint = torch.load(path, map_location=map_location)\nmodel.load_state_dict(checkpoint[\"state_dict\"])\n\n\n# b. model → blocks. `blocks` is a sequence of blocks\n\nclass PatchEmbed(nn.Module):\n    def __init__(self, model):\n        super().__init__()\n        self.model = copy.deepcopy(model)\n        \n    def forward(self, x, **kwargs):\n        x = self.model.patch_embed(x)\n        x = self.model.pos_drop(x + self.model.pos_embed)\n        cls_tokens = self.model.cls_token.expand(x.shape[0], -1, -1)\n\n        return (x, cls_tokens)\n\n    \nclass Concat(nn.Module):\n    def __init__(self, model):\n        super().__init__()\n        self.model = copy.deepcopy(model)\n        \n    def forward(self, x, **kwargs):\n        x, cls_tokens = x\n        B, C, H, W = x.shape\n        token_length = cls_tokens.shape[1]\n\n        x = x.flatten(2).transpose(1, 2)\n        x = torch.cat((cls_tokens, x), dim=1)\n\n        return x\n    \n    \nclass Pool(nn.Module):\n    def __init__(self, block, token_length):\n        super().__init__()\n        self.block = copy.deepcopy(block)\n        self.token_length = token_length\n        \n    def forward(self, x, **kwargs):\n        cls_tokens = x[:, :self.token_length]\n        x = x[:, self.token_length:]\n        B, N, C = x.shape\n        H, W = int(math.sqrt(N)), int(math.sqrt(N))\n        x = x.transpose(1, 2).reshape(B, 
C, H, W)\n\n        x, cls_tokens = self.block(x, cls_tokens)\n        \n        return x, cls_tokens\n    \n    \nclass Classifier(nn.Module):\n    def __init__(self, norm, head):\n        super().__init__()\n        self.head = copy.deepcopy(head)\n        self.norm = copy.deepcopy(norm)\n        \n    def forward(self, x, **kwargs):\n        x = x[:,0]\n        x = self.norm(x)\n        x = self.head(x)\n        return x\n\n    \nclass Residual(nn.Module):\n    def __init__(self, *fn):\n        super().__init__()\n        self.fn = nn.Sequential(*fn)\n        \n    def forward(self, x, **kwargs):\n        return self.fn(x, **kwargs) + x\n\n    \ndef flatten(xs_list):\n    return [x for xs in xs_list for x in xs]\n\n\nblocks = [\n    nn.Sequential(PatchEmbed(model), Concat(model),),\n    *flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)] \n              for b in model.transformers[0].blocks]),\n    nn.Sequential(Pool(model.transformers[0].pool, 1), Concat(model),),\n    *flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)] \n              for b in model.transformers[1].blocks]),\n    nn.Sequential(Pool(model.transformers[1].pool, 1), Concat(model),),\n    *flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)] \n              for b in model.transformers[2].blocks]),\n    Classifier(model.norm, model.head),\n]\n```\n\n\n```python\n# Swin-Ti\nimport copy\nimport timm\nimport models\n\nimport torch\nimport torch.nn as nn\n\n# a. 
download and load a pretrained model for CIFAR-100\nurl = \"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/swin_ti_cifar100_ec2894492b.pth.tar\"\npath = \"checkpoints/swin_ti_cifar100_ec2894492b.pth.tar\"\nmodels.download(url=url, path=path)\n\nmodel = timm.models.swin_transformer.SwinTransformer(\n    num_classes=100, img_size=32, patch_size=1, window_size=4,  # for CIFAR-100\n    embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), qkv_bias=False,  # for Swin-Ti\n)\nmodel.name = \"swin_ti\"\nmodels.stats(model)\nmap_location = \"cuda\" if torch.cuda.is_available() else \"cpu\"\ncheckpoint = torch.load(path, map_location=map_location)\nmodel.load_state_dict(checkpoint[\"state_dict\"])\n\n\n# b. model → blocks. `blocks` is a sequence of blocks\n\nclass Attn(nn.Module):\n    def __init__(self, block):\n        super().__init__()\n        self.block = copy.deepcopy(block)\n        self.block.mlp = nn.Identity()\n        self.block.norm2 = nn.Identity()\n        \n    def forward(self, x, **kwargs):\n        x = self.block(x)\n        x = x / 2\n        \n        return x\n\nclass MLP(nn.Module):\n    def __init__(self, block):\n        super().__init__()\n        block = copy.deepcopy(block)\n        self.mlp = block.mlp\n        self.norm2 = block.norm2\n        \n    def forward(self, x, **kwargs):\n        x = x + self.mlp(self.norm2(x))\n\n        return x\n\n    \nclass Classifier(nn.Module):\n    def __init__(self, norm, head):\n        super().__init__()\n        self.norm = copy.deepcopy(norm)\n        self.head = copy.deepcopy(head)\n        \n    def forward(self, x, **kwargs):\n        x = self.norm(x)\n        x = x.mean(dim=1)\n        x = self.head(x)\n\n        return x\n\n    \ndef flatten(xs_list):\n    return [x for xs in xs_list for x in xs]\n\n\nblocks = [\n    model.patch_embed,\n    *flatten([[Attn(block), MLP(block)] for block in model.layers[0].blocks]),\n    model.layers[0].downsample,\n    
*flatten([[Attn(block), MLP(block)] for block in model.layers[1].blocks]),\n    model.layers[1].downsample,\n    *flatten([[Attn(block), MLP(block)] for block in model.layers[2].blocks]),\n    model.layers[2].downsample,\n    *flatten([[Attn(block), MLP(block)] for block in model.layers[3].blocks]),\n    Classifier(model.norm, model.head)\n]\n```\n\u003c/details\u003e\n\n\n\n## Fourier Analysis of Representations \n\nRefer to [```fourier_analysis.ipynb```](fourier_analysis.ipynb) ([Colab notebook](https://colab.research.google.com/github/xxxnell/how-do-vits-work/blob/transformer/fourier_analysis.ipynb)) to analyze feature maps through the lens of the Fourier transform. Run all cells to visualize Fourier-transformed feature maps. Fourier analysis shows that MSAs reduce high-frequency signals, while Convs amplify high-frequency components.\n\n\n## Measuring Feature Map Variances\n\nRefer to [```featuremap_variance.ipynb```](featuremap_variance.ipynb) ([Colab notebook](https://colab.research.google.com/github/xxxnell/how-do-vits-work/blob/transformer/featuremap_variance.ipynb)) to measure feature map variance. Run all cells to visualize feature map variances. Feature map variance shows that MSAs aggregate feature maps, but Convs and MLPs diversify them.\n\n\n## Visualizing the Loss Landscapes\n\nRefer to [```losslandscape.ipynb```](losslandscape.ipynb) ([Colab notebook](https://colab.research.google.com/github/xxxnell/how-do-vits-work/blob/transformer/losslandscape.ipynb)) or [the original repo](https://github.com/tomgoldstein/loss-landscape) for exploring the loss landscapes. Run all cells to evaluate the predictive performance of the model over a grid in weight space. 
Loss landscape visualization shows that ViT has a flatter loss landscape than ResNet.\n\n\n## Evaluating Robustness on Corrupted Datasets\n\nRefer to [```robustness.ipynb```](robustness.ipynb) ([Colab notebook](https://colab.research.google.com/github/xxxnell/how-do-vits-work/blob/transformer/robustness.ipynb)) to evaluate corruption robustness on [corrupted datasets](https://github.com/hendrycks/robustness) such as CIFAR-10-C and CIFAR-100-C. Run all cells to get the predictive performance of the model on datasets that consist of data corrupted by 15 different corruption types, each with 5 levels of intensity. \n\n\n## How to Apply MSA to Your Own Model\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"resources/vit/buildup_v.gif\" style=\"width:90%;\"\u003e\n\u003c/p\u003e\n\nWe find that MSA complements Conv (rather than replacing it), and *MSA closer to the end of a stage* improves predictive performance significantly. Based on these insights, we propose the following build-up rules:\n\n1. Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model. \n2. If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA. \n3. Use more heads and higher hidden dimensions for MSA blocks in late stages.\n\nIn the animation above, we replace Convs of ResNet with MSAs one by one according to the build-up rules. Note that several MSAs in `c3` harm the accuracy, but the MSA at the end of `c2` improves it. As a result, surprisingly, the model with MSAs following the appropriate build-up rule outperforms CNNs even in small data regimes, e.g., CIFAR-100!\n\n\n\n\n## Investigate Loss Landscapes and Hessians With L2 Regularization on Augmented Datasets\n\nTwo common mistakes are investigating loss landscapes and Hessians (1) *'without considering L2 regularization'* on (2) *'clean datasets'*. However, note that NNs are optimized with L2 regularization on augmented datasets. 
Therefore, it is appropriate to visualize *'NLL + L2'* on *'augmented datasets'*. Measuring criteria without L2 on clean datasets would give incorrect results.\n\n\n\n## Citation\n\nIf you find this useful, please consider citing 📑 the paper and starring 🌟 this repository. Please do not hesitate to contact Namuk Park (email: namuk.park at gmail dot com, twitter: [xxxnell](https://twitter.com/xxxnell)) with any comments or feedback.\n\n```\n@inproceedings{park2022how,\n  title={How Do Vision Transformers Work?},\n  author={Namuk Park and Songkuk Kim},\n  booktitle={International Conference on Learning Representations},\n  year={2022}\n}\n```\n\n\n## License\n\nAll code is available to you under Apache License 2.0. CNN models build off the torchvision models which are BSD licensed. ViTs build off the [PyTorch Image Models](https://github.com/rwightman/pytorch-image-models) and [Vision Transformer - Pytorch](https://github.com/lucidrains/vit-pytorch) which are Apache 2.0 and MIT licensed.\n\nCopyright the maintainers.\n\n\n\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fxxxnell%2Fhow-do-vits-work","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fxxxnell%2Fhow-do-vits-work","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fxxxnell%2Fhow-do-vits-work/lists"}