{"id":13710892,"url":"https://github.com/taki0112/vit-tensorflow","last_synced_at":"2025-09-07T21:40:35.548Z","repository":{"id":37705074,"uuid":"473880076","full_name":"taki0112/vit-tensorflow","owner":"taki0112","description":"Vision Transformer Cookbook with Tensorflow","archived":false,"fork":false,"pushed_at":"2022-03-28T07:19:21.000Z","size":7693,"stargazers_count":334,"open_issues_count":3,"forks_count":52,"subscribers_count":5,"default_branch":"main","last_synced_at":"2025-05-20T02:06:50.975Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/taki0112.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2022-03-25T05:28:34.000Z","updated_at":"2025-05-12T03:31:40.000Z","dependencies_parsed_at":"2022-07-13T03:10:41.520Z","dependency_job_id":null,"html_url":"https://github.com/taki0112/vit-tensorflow","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/taki0112/vit-tensorflow","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/taki0112%2Fvit-tensorflow","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/taki0112%2Fvit-tensorflow/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/taki0112%2Fvit-tensorflow/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/taki0112%2Fvit-tensorflow/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/taki0112","download_url":"https://codeload.github.com/taki0112/vit-tensorflow/tar.gz/refs/heads/main","sbom_url":"htt
ps://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/taki0112%2Fvit-tensorflow/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":274101618,"owners_count":25222446,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","status":"online","status_checked_at":"2025-09-07T02:00:09.463Z","response_time":67,"last_error":null,"robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":true,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-08-02T23:01:01.925Z","updated_at":"2025-09-07T21:40:35.499Z","avatar_url":"https://github.com/taki0112.png","language":"Python","funding_links":[],"categories":["Computer Vision"],"sub_categories":["General Purpose CV"],"readme":"# Vision Transformer Cookbook with Tensorflow\n\u003cimg src=\"./images/vit.gif\" width=\"500px\"\u003e\u003c/img\u003e\n\n## Author\n* [Junho Kim](http://bit.ly/jhkim_resume)\n\n### Acknowledgement\n* Appreciate to [@lucidrains](https://github.com/lucidrains) for his permission to release this repository.\n* [vit-pytorch](https://github.com/lucidrains/vit-pytorch)\n\n## Table of Contents\n- [Vision Transformer - Tensorflow](#vision-transformer---tensorflow)\n- [Usage](#usage)\n- [Parameters](#parameters)\n- [Distillation](#distillation)\n- [Deep ViT](#deep-vit)\n- [CaiT](#cait)\n- [Token-to-Token ViT](#token-to-token-vit)\n- [CCT](#cct)\n- [Cross ViT](#cross-vit)\n- [PiT](#pit)\n- [LeViT](#levit)\n- [CvT](#cvt)\n- [Twins SVT](#twins-svt)\n- [CrossFormer](#crossformer)\n- 
[RegionViT](#regionvit)\n- [ScalableViT](#scalablevit)\n- [NesT](#nest)\n- [MobileViT](#mobilevit)\n- [Masked Autoencoder](#masked-autoencoder)\n- [Simple Masked Image Modeling](#simple-masked-image-modeling)\n- [Masked Patch Prediction](#masked-patch-prediction)\n- [Adaptive Token Sampling](#adaptive-token-sampling)\n- [Patch Merger](#patch-merger)\n- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)\n- [Parallel ViT](#parallel-vit)\n- [FAQ](#faq)\n- [Resources](#resources)\n\n## Vision Transformer - Tensorflow ( \u003e= 2.3.0)\nImplementation of \u003ca href=\"https://openreview.net/pdf?id=YicbFdNTTy\"\u003eVision Transformer\u003c/a\u003e, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Tensorflow. Significance is further explained in \u003ca href=\"https://www.youtube.com/watch?v=TrdevFK_am4\"\u003eYannic Kilcher's\u003c/a\u003e video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.\n\n## Usage\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow import ViT\n\nv = ViT(\n    image_size = 256,\n    patch_size = 32,\n    num_classes = 1000,\n    dim = 1024,\n    depth = 6,\n    heads = 16,\n    mlp_dim = 2048,\n    dropout = 0.1,\n    emb_dropout = 0.1\n)\n\nimg = tf.random.normal([1, 256, 256, 3])\n\npreds = v(img) # (1, 1000)\n```\n\n## Parameters\n\n- `image_size`: int.  \nImage size. If you have rectangular images, make sure your image size is the maximum of the width and height.\n- `patch_size`: int.  \nSize of each patch, in pixels. `image_size` must be divisible by `patch_size`.  \nThe number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.\n- `num_classes`: int.  \nNumber of classes to classify.\n- `dim`: int.  \nLast dimension of output tensor after linear transformation `nn.Linear(..., dim)`.\n- `depth`: int.  \nNumber of Transformer blocks.\n- `heads`: int.  
\nNumber of heads in Multi-head Attention layer. \n- `mlp_dim`: int.  \nDimension of the MLP (FeedForward) layer. \n- `dropout`: float between `[0, 1]`, default `0.`.  \nDropout rate. \n- `emb_dropout`: float between `[0, 1]`, default `0`.  \nEmbedding dropout rate.\n- `pool`: string, either `cls` token pooling or `mean` pooling\n\n\n## Distillation\n\n\u003cimg src=\"./images/distill.png\" width=\"300px\"\u003e\u003c/img\u003e\n\nA recent \u003ca href=\"https://arxiv.org/abs/2012.12877\"\u003epaper\u003c/a\u003e has shown that use of a distillation token for distilling knowledge from convolutional nets to a vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.\n\nex. distilling from Resnet50 (or any teacher) to a vision transformer\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.distill import DistillableViT, DistillWrapper\n\nteacher = tf.keras.applications.resnet50.ResNet50()\n\nv = DistillableViT(\n    image_size = 256,\n    patch_size = 32,\n    num_classes = 1000,\n    dim = 1024,\n    depth = 6,\n    heads = 8,\n    mlp_dim = 2048,\n    dropout = 0.1,\n    emb_dropout = 0.1\n)\n\ndistiller = DistillWrapper(\n    student = v,\n    teacher = teacher,\n    temperature = 3,           # temperature of distillation\n    alpha = 0.5,               # trade-off between main loss and distillation loss\n    hard = False               # whether to use soft or hard distillation\n)\n\nimg = tf.random.normal([2, 256, 256, 3])\nlabels = tf.random.uniform(shape=[2, ], minval=0, maxval=1000, dtype=tf.int32)\nlabels = tf.one_hot(labels, depth=1000, axis=-1)\n\nloss = distiller([img, labels])\n\n# after lots of training above ...\n\npred = v(img) # (2, 1000)\n```\n\n\n## Deep ViT\n\nThis \u003ca 
href=\"https://arxiv.org/abs/2103.11886\"\u003epaper\u003c/a\u003e notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the \u003ca href=\"https://github.com/lucidrains/x-transformers#talking-heads-attention\"\u003eTalking Heads\u003c/a\u003e paper from NLP.\n\nYou can use it as follows\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.deepvit import DeepViT\n\nv = DeepViT(\n    image_size = 256,\n    patch_size = 32,\n    num_classes = 1000,\n    dim = 1024,\n    depth = 6,\n    heads = 16,\n    mlp_dim = 2048,\n    dropout = 0.1,\n    emb_dropout = 0.1\n)\n\nimg = tf.random.normal([1, 256, 256, 3])\n\npreds = v(img) # (1, 1000)\n```\n\n## CaiT\n\n\u003ca href=\"https://arxiv.org/abs/2103.17239\"\u003eThis paper\u003c/a\u003e also notes difficulty in training vision transformers at greater depths and proposes two solutions. First it proposes to do per-channel multiplication of the output of the residual block. 
Second, it proposes to have the patches attend to one another, and only allow the CLS token to attend to the patches in the last few layers.\n\nThey also add \u003ca href=\"https://github.com/lucidrains/x-transformers#talking-heads-attention\"\u003eTalking Heads\u003c/a\u003e, noting improvements\n\nYou can use this scheme as follows\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.cait import CaiT\n\nv = CaiT(\n    image_size = 256,\n    patch_size = 32,\n    num_classes = 1000,\n    dim = 1024,\n    depth = 12,             # depth of transformer for patch to patch attention only\n    cls_depth = 2,          # depth of cross attention of CLS tokens to patch\n    heads = 16,\n    mlp_dim = 2048,\n    dropout = 0.1,\n    emb_dropout = 0.1,\n    layer_dropout = 0.05    # randomly dropout 5% of the layers\n)\n\nimg = tf.random.normal([1, 256, 256, 3])\n\npreds = v(img) # (1, 1000)\n```\n\n## Token-to-Token ViT\n\n\u003cimg src=\"./images/t2t.png\" width=\"400px\"\u003e\u003c/img\u003e\n\n\u003ca href=\"https://arxiv.org/abs/2101.11986\"\u003eThis paper\u003c/a\u003e proposes that the first couple layers should downsample the image sequence by unfolding, leading to overlapping image data in each token as shown in the figure above. 
You can use this variant of the `ViT` as follows.\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.t2t import T2TViT\n\nv = T2TViT(\n    dim = 512,\n    image_size = 224,\n    depth = 5,\n    heads = 8,\n    mlp_dim = 512,\n    num_classes = 1000,\n    t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layer of the initial token to token module\n)\n\nimg = tf.random.normal([1, 224, 224, 3])\n\npreds = v(img) # (1, 1000)\n```\n\n## CCT\n\n\u003cimg src=\"https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png\" width=\"400px\"\u003e\u003c/img\u003e\n\n\u003ca href=\"https://arxiv.org/abs/2104.05704\"\u003eCCT\u003c/a\u003e proposes compact transformers\nby using convolutions instead of patching and performing sequence pooling. This\nallows for CCT to have high accuracy and a low number of parameters.\n\nYou can use it in two ways\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.cct import CCT\n\ncct = CCT(\n    img_size = (224, 448),\n    embedding_dim = 384,\n    n_conv_layers = 2,\n    kernel_size = 7,\n    stride = 2,\n    padding = 3,\n    pooling_kernel_size = 3,\n    pooling_stride = 2,\n    pooling_padding = 1,\n    num_layers = 14,\n    num_heads = 6,\n    mlp_radio = 3.,\n    num_classes = 1000,\n    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']\n)\n\nimg = tf.random.normal(shape=[1, 224, 448, 3])\npreds = cct(img) # (1, 
1000)\n```\n\nAlternatively you can use one of several pre-defined models `[2,4,6,7,8,14,16]`\nwhich pre-define the number of layers, number of attention heads, the mlp ratio,\nand the embedding dimension.\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.cct import cct_14\n\ncct = cct_14(\n    img_size = 224,\n    n_conv_layers = 1,\n    kernel_size = 7,\n    stride = 2,\n    padding = 3,\n    pooling_kernel_size = 3,\n    pooling_stride = 2,\n    pooling_padding = 1,\n    num_classes = 1000,\n    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']\n)\n```\n\u003ca href=\"https://github.com/SHI-Labs/Compact-Transformers\"\u003eOfficial\nRepository\u003c/a\u003e includes links to pretrained model checkpoints.\n\n\n## Cross ViT\n\n\u003cimg src=\"./images/cross_vit.png\" width=\"400px\"\u003e\u003c/img\u003e\n\n\u003ca href=\"https://arxiv.org/abs/2103.14899\"\u003eThis paper\u003c/a\u003e proposes to have two vision transformers processing the image at different scales, cross attending to each other every so often. 
They show improvements on top of the base vision transformer.\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.cross_vit import CrossViT\n\nv = CrossViT(\n    image_size = 256,\n    num_classes = 1000,\n    depth = 4,               # number of multi-scale encoding blocks\n    sm_dim = 192,            # high res dimension\n    sm_patch_size = 16,      # high res patch size (should be smaller than lg_patch_size)\n    sm_enc_depth = 2,        # high res depth\n    sm_enc_heads = 8,        # high res heads\n    sm_enc_mlp_dim = 2048,   # high res feedforward dimension\n    lg_dim = 384,            # low res dimension\n    lg_patch_size = 64,      # low res patch size\n    lg_enc_depth = 3,        # low res depth\n    lg_enc_heads = 8,        # low res heads\n    lg_enc_mlp_dim = 2048,   # low res feedforward dimensions\n    cross_attn_depth = 2,    # cross attention rounds\n    cross_attn_heads = 8,    # cross attention heads\n    dropout = 0.1,\n    emb_dropout = 0.1\n)\n\nimg = tf.random.normal([1, 256, 256, 3])\n\npred = v(img) # (1, 1000)\n```\n\n## PiT\n\n\u003cimg src=\"./images/pit.png\" width=\"400px\"\u003e\u003c/img\u003e\n\n\u003ca href=\"https://arxiv.org/abs/2103.16302\"\u003eThis paper\u003c/a\u003e proposes to downsample the tokens through a pooling procedure using depth-wise convolutions.\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.pit import PiT\n\nv = PiT(\n    image_size = 224,\n    patch_size = 14,\n    dim = 256,\n    num_classes = 1000,\n    depth = (3, 3, 3),     # list of depths, indicating the number of rounds of each stage before a downsample\n    heads = 16,\n    mlp_dim = 2048,\n    dropout = 0.1,\n    emb_dropout = 0.1\n)\n\n# forward pass now returns predictions and the attention maps\n\nimg = tf.random.normal([1, 224, 224, 3])\n\npreds = v(img) # (1, 1000)\n```\n\n## LeViT\n\n\u003cimg src=\"./images/levit.png\" width=\"300px\"\u003e\u003c/img\u003e\n\n\u003ca href=\"https://arxiv.org/abs/2104.01136\"\u003eThis 
paper\u003c/a\u003e proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection (2) downsampling in stages (3) extra non-linearity in attention (4) 2d relative positional biases instead of initial absolute positional bias (5) batchnorm in place of layernorm.\n\n\u003ca href=\"https://github.com/facebookresearch/LeViT\"\u003eOfficial repository\u003c/a\u003e\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.levit import LeViT\n\nlevit = LeViT(\n    image_size = 224,\n    num_classes = 1000,\n    stages = 3,             # number of stages\n    dim = (256, 384, 512),  # dimensions at each stage\n    depth = 4,              # transformer of depth 4 at each stage\n    heads = (4, 6, 8),      # heads at each stage\n    mlp_mult = 2,\n    dropout = 0.1\n)\n\nimg = tf.random.normal([1, 224, 224, 3])\n\nlevit(img) # (1, 1000)\n```\n\n## CvT\n\n\u003cimg src=\"./images/cvt.png\" width=\"400px\"\u003e\u003c/img\u003e\n\n\u003ca href=\"https://arxiv.org/abs/2103.15808\"\u003eThis paper\u003c/a\u003e proposes mixing convolutions and attention. Specifically, convolutions are used to embed and downsample the image / feature map in three stages. 
Depthwise convolution is also used to project the queries, keys, and values for attention.\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.cvt import CvT\n\nv = CvT(\n    num_classes = 1000,\n    s1_emb_dim = 64,        # stage 1 - dimension\n    s1_emb_kernel = 7,      # stage 1 - conv kernel\n    s1_emb_stride = 4,      # stage 1 - conv stride\n    s1_proj_kernel = 3,     # stage 1 - attention ds-conv kernel size\n    s1_kv_proj_stride = 2,  # stage 1 - attention key / value projection stride\n    s1_heads = 1,           # stage 1 - heads\n    s1_depth = 1,           # stage 1 - depth\n    s1_mlp_mult = 4,        # stage 1 - feedforward expansion factor\n    s2_emb_dim = 192,       # stage 2 - (same as above)\n    s2_emb_kernel = 3,\n    s2_emb_stride = 2,\n    s2_proj_kernel = 3,\n    s2_kv_proj_stride = 2,\n    s2_heads = 3,\n    s2_depth = 2,\n    s2_mlp_mult = 4,\n    s3_emb_dim = 384,       # stage 3 - (same as above)\n    s3_emb_kernel = 3,\n    s3_emb_stride = 2,\n    s3_proj_kernel = 3,\n    s3_kv_proj_stride = 2,\n    s3_heads = 4,\n    s3_depth = 10,\n    s3_mlp_mult = 4,\n    dropout = 0.\n)\n\nimg = tf.random.normal([1, 224, 224, 3])\n\npred = v(img) # (1, 1000)\n```\n\n## Twins SVT\n\n\u003cimg src=\"./images/twins_svt.png\" width=\"400px\"\u003e\u003c/img\u003e\n\nThis \u003ca href=\"https://arxiv.org/abs/2104.13840\"\u003epaper\u003c/a\u003e proposes mixing local and global attention, along with a position encoding generator (proposed in \u003ca href=\"https://arxiv.org/abs/2102.10882\"\u003eCPVT\u003c/a\u003e) and global average pooling, to achieve the same results as \u003ca href=\"https://arxiv.org/abs/2103.14030\"\u003eSwin\u003c/a\u003e, without the extra complexity of shifted windows, CLS tokens, or positional embeddings.\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.twins_svt import TwinsSVT\n\nmodel = TwinsSVT(\n    num_classes = 1000,       # number of output classes\n    s1_emb_dim = 64,          # stage 1 - patch 
embedding projected dimension\n    s1_patch_size = 4,        # stage 1 - patch size for patch embedding\n    s1_local_patch_size = 7,  # stage 1 - patch size for local attention\n    s1_global_k = 7,          # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper\n    s1_depth = 1,             # stage 1 - number of transformer blocks (local attn -\u003e ff -\u003e global attn -\u003e ff)\n    s2_emb_dim = 128,         # stage 2 (same as above)\n    s2_patch_size = 2,\n    s2_local_patch_size = 7,\n    s2_global_k = 7,\n    s2_depth = 1,\n    s3_emb_dim = 256,         # stage 3 (same as above)\n    s3_patch_size = 2,\n    s3_local_patch_size = 7,\n    s3_global_k = 7,\n    s3_depth = 5,\n    s4_emb_dim = 512,         # stage 4 (same as above)\n    s4_patch_size = 2,\n    s4_local_patch_size = 7,\n    s4_global_k = 7,\n    s4_depth = 4,\n    peg_kernel_size = 3,      # positional encoding generator kernel size\n    dropout = 0.              # dropout\n)\n\nimg = tf.random.normal([1, 224, 224, 3])\n\npred = model(img) # (1, 1000)\n```\n\n## RegionViT\n\n\u003cimg src=\"./images/regionvit.png\" width=\"400px\"\u003e\u003c/img\u003e\n\n\u003cimg src=\"./images/regionvit2.png\" width=\"400px\"\u003e\u003c/img\u003e\n\n\u003ca href=\"https://arxiv.org/abs/2106.02689\"\u003eThis paper\u003c/a\u003e proposes to divide up the feature map into local regions, whereby the local tokens attend to each other. 
Each local region has its own regional token which then attends to all its local tokens, as well as other regional tokens.\n\nYou can use it as follows\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.regionvit import RegionViT\n\nmodel = RegionViT(\n    dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage\n    depth = (2, 2, 8, 2),           # depth of the region to local transformer at each stage\n    window_size = 7,                # window size, which should be either 7 or 14\n    num_classes = 1000,             # number of output classes\n    tokenize_local_3_conv = False,  # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models\n    use_peg = False,                # whether to use positional generating module. they used this for object detection for a boost in performance\n)\n\nimg = tf.random.normal([1, 224, 224, 3])\n\npred = model(img) # (1, 1000)\n```\n\n## CrossFormer\n\n\u003cimg src=\"./images/crossformer.png\" width=\"400px\"\u003e\u003c/img\u003e\n\n\u003cimg src=\"./images/crossformer2.png\" width=\"400px\"\u003e\u003c/img\u003e\n\nThis \u003ca href=\"https://arxiv.org/abs/2108.00154\"\u003epaper\u003c/a\u003e beats PVT and Swin using alternating local and global attention. The global attention is done across the windowing dimension for reduced complexity, much like the scheme used for axial attention.\n\nThey also have a cross-scale embedding layer, which they show to be a generic layer that can improve all vision transformers. 
Dynamic relative positional bias was also formulated to allow the net to generalize to images of greater resolution.\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.crossformer import CrossFormer\n\nmodel = CrossFormer(\n    num_classes = 1000,                # number of output classes\n    dim = (64, 128, 256, 512),         # dimension at each stage\n    depth = (2, 2, 8, 2),              # depth of transformer at each stage\n    global_window_size = (8, 4, 2, 1), # global window sizes at each stage\n    local_window_size = 7,             # local window size (can be customized for each stage, but in paper, held constant at 7 for all stages)\n)\n\nimg = tf.random.normal([1, 224, 224, 3])\n\npred = model(img) # (1, 1000)\n```\n\n## ScalableViT\n\n\u003cimg src=\"./images/scalable-vit-1.png\" width=\"400px\"\u003e\u003c/img\u003e\n\n\u003cimg src=\"./images/scalable-vit-2.png\" width=\"400px\"\u003e\u003c/img\u003e\n\nThis Bytedance AI \u003ca href=\"https://arxiv.org/abs/2203.10790\"\u003epaper\u003c/a\u003e proposes the Scalable Self Attention (SSA) and the Interactive Windowed Self Attention (IWSA) modules. The SSA alleviates the computation needed at earlier stages by reducing the key / value feature map by some factor (`reduction_factor`), while modulating the dimension of the queries and keys (`ssa_dim_key`). The IWSA performs self attention within local windows, similar to other vision transformer papers. However, they add a residual of the values, passed through a convolution of kernel size 3, which they named Local Interactive Module (LIM).\n\nThey make the claim in this paper that this scheme outperforms Swin Transformer, and also demonstrate competitive performance against Crossformer.\n\nYou can use it as follows (ex. ScalableViT-S)\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.scalable_vit import ScalableViT\n\nmodel = ScalableViT(\n    num_classes = 1000,\n    dim = 64,                               # starting model dimension. 
at every stage, dimension is doubled\n    heads = (2, 4, 8, 16),                  # number of attention heads at each stage\n    depth = (2, 2, 20, 2),                  # number of transformer blocks at each stage\n    ssa_dim_key = (40, 40, 40, 32),         # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key)\n    reduction_factor = (8, 4, 2, 1),        # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)\n    window_size = (64, 32, None, None),     # window size of the IWSA at each stage. None means no windowing needed\n    dropout = 0.1,                          # attention and feedforward dropout\n)\n\nimg = tf.random.normal([1, 256, 256, 3])\n\npreds = model(img) # (1, 1000)\n```\n\n## NesT\n\n\u003cimg src=\"./images/nest.png\" width=\"400px\"\u003e\u003c/img\u003e\n\nThis \u003ca href=\"https://arxiv.org/abs/2105.12723\"\u003epaper\u003c/a\u003e processes the image in hierarchical stages, with attention only among tokens within local blocks, which are aggregated as the model moves up the hierarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.\n\nYou can use it with the following code (ex. 
NesT-T)\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.nest import NesT\n\nnest = NesT(\n    image_size = 224,\n    patch_size = 4,\n    dim = 96,\n    heads = 3,\n    num_hierarchies = 3,        # number of hierarchies\n    block_repeats = (2, 2, 8),  # the number of transformer blocks at each hierarchy, starting from the bottom\n    num_classes = 1000\n)\n\nimg = tf.random.normal([1, 224, 224, 3])\n\npred = nest(img) # (1, 1000)\n```\n\n## MobileViT\n\n\u003cimg src=\"./images/mbvit.png\" width=\"400px\"\u003e\u003c/img\u003e\n\nThis \u003ca href=\"https://arxiv.org/abs/2110.02178\"\u003epaper\u003c/a\u003e introduces MobileViT, a light-weight and general purpose vision transformer for mobile devices. MobileViT presents a different\nperspective for the global processing of information with transformers.\n\nYou can use it with the following code (ex. mobilevit_xs)\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.mobile_vit import MobileViT\n\nmbvit_xs = MobileViT(\n    image_size = (256, 256),\n    dims = [96, 120, 144],\n    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],\n    num_classes = 1000\n)\n\nimg = tf.random.normal([1, 256, 256, 3])\n\npred = mbvit_xs(img) # (1, 1000)\n```\n\n## Simple Masked Image Modeling\n\n\u003cimg src=\"./images/simmim.png\" width=\"400px\"/\u003e\n\nThis \u003ca href=\"https://arxiv.org/abs/2111.09886\"\u003epaper\u003c/a\u003e proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection off the masked tokens into pixel space followed by an L1 loss with the pixel values of the masked patches. 
Results are competitive with other more complicated approaches.\n\nYou can use this as follows\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow import ViT\nfrom vit_tensorflow.simmim import SimMIM\n\nv = ViT(\n    image_size = 256,\n    patch_size = 32,\n    num_classes = 1000,\n    dim = 1024,\n    depth = 6,\n    heads = 8,\n    mlp_dim = 2048\n)\n\nmim = SimMIM(\n    encoder = v,\n    masking_ratio = 0.5  # they found 50% to yield the best results\n)\n\nimages = tf.random.normal([8, 256, 256, 3])\n\nloss = mim(images)\n\n# that's all!\n# do the above in a for loop many times with a lot of images and your vision transformer will learn\n\n```\n\n\n## Masked Autoencoder\n\n\u003cimg src=\"./images/mae.png\" width=\"400px\"/\u003e\n\nA new \u003ca href=\"https://arxiv.org/abs/2111.06377\"\u003eKaiming He paper\u003c/a\u003e proposes a simple autoencoder scheme where the vision transformer attends to a set of unmasked patches, and a smaller decoder tries to reconstruct the masked pixel values.\n\n\u003ca href=\"https://www.youtube.com/watch?v=LKixq2S2Pz8\"\u003eDeepReader quick paper review\u003c/a\u003e\n\n\u003ca href=\"https://www.youtube.com/watch?v=Dp6iICL2dVI\"\u003eAI Coffeebreak with Letitia\u003c/a\u003e\n\nYou can use it with the following code\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow import ViT, MAE\n\nv = ViT(\n    image_size = 256,\n    patch_size = 32,\n    num_classes = 1000,\n    dim = 1024,\n    depth = 6,\n    heads = 8,\n    mlp_dim = 2048\n)\n\nmae = MAE(\n    encoder = v,\n    masking_ratio = 0.75,   # the paper recommended 75% masked patches\n    decoder_dim = 512,      # paper showed good results with just 512\n    decoder_depth = 6       # anywhere from 1 to 8\n)\n\nimages = tf.random.normal([8, 256, 256, 3])\n\nloss = mae(images)\n\n# that's all!\n# do the above in a for loop many times with a lot of images and your vision transformer will learn\n\n```\n\n## Masked Patch Prediction\n\nThanks to \u003ca 
href=\"https://github.com/zankner\"\u003eZach\u003c/a\u003e, you can train using the original masked patch prediction task presented in the paper, with the following code.\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow import ViT\nfrom vit_tensorflow.mpp import MPP\n\nmodel = ViT(\n    image_size=256,\n    patch_size=32,\n    num_classes=1000,\n    dim=1024,\n    depth=6,\n    heads=8,\n    mlp_dim=2048,\n    dropout=0.1,\n    emb_dropout=0.1\n)\n\nmpp_trainer = MPP(\n    transformer=model,\n    patch_size=32,\n    dim=1024,\n    mask_prob=0.15,          # probability of using token in masked prediction task\n    random_patch_prob=0.30,  # probability of randomly replacing a token being used for mpp\n    replace_prob=0.50,       # probability of replacing a token being used for mpp with the mask token\n)\n\ndef sample_unlabelled_images():\n    return tf.random.normal([20, 256, 256, 3])\n\nfor _ in range(100):\n    with tf.GradientTape() as tape:\n        images = sample_unlabelled_images()\n        loss = mpp_trainer(images)\n```\n\n## Adaptive Token Sampling\n\n\u003cimg src=\"./images/ats.png\" width=\"400px\"\u003e\u003c/img\u003e\n\nThis \u003ca href=\"https://arxiv.org/abs/2111.15667\"\u003epaper\u003c/a\u003e proposes to use the CLS attention scores, re-weighed by the norms of the value heads, as means to discard unimportant tokens at different layers.\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.ats_vit import ViT\n\nv = ViT(\n    image_size = 256,\n    patch_size = 16,\n    num_classes = 1000,\n    dim = 1024,\n    depth = 6,\n    max_tokens_per_depth = (256, 128, 64, 32, 16, 8), # a tuple that denotes the maximum number of tokens that any given layer should have. 
if the layer has greater than this amount, it will undergo adaptive token sampling\n    heads = 16,\n    mlp_dim = 2048,\n    dropout = 0.1,\n    emb_dropout = 0.1\n)\n\nimg = tf.random.normal([4, 256, 256, 3])\n\npreds = v(img) # (4, 1000)\n\n# you can also get a list of the final sampled patch ids\n# a value of -1 denotes padding\n\npreds, token_ids = v(img, return_sampled_token_ids = True) # (4, 1000), (4, \u003c=8)\n```\n\n## Patch Merger\n\n\n\u003cimg src=\"./images/patch_merger.png\" width=\"400px\"\u003e\u003c/img\u003e\n\nThis \u003ca href=\"https://arxiv.org/abs/2202.12015\"\u003epaper\u003c/a\u003e proposes a simple module (Patch Merger) for reducing the number of tokens at any layer of a vision transformer without sacrificing performance.\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.vit_with_patch_merger import ViT\n\nv = ViT(\n    image_size = 256,\n    patch_size = 16,\n    num_classes = 1000,\n    dim = 1024,\n    depth = 12,\n    heads = 8,\n    patch_merge_layer = 6,        # at which transformer layer to do patch merging\n    patch_merge_num_tokens = 8,   # the output number of tokens from the patch merge\n    mlp_dim = 2048,\n    dropout = 0.1,\n    emb_dropout = 0.1\n)\n\nimg = tf.random.normal([4, 256, 256, 3])\n\npreds = v(img) # (4, 1000)\n```\n\nOne can also use the `PatchMerger` module by itself\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.vit_with_patch_merger import PatchMerger\n\nmerger = PatchMerger(\n    dim = 1024,\n    num_tokens_out = 8   # output number of tokens\n)\n\nfeatures = tf.random.normal([4, 256, 1024]) # (batch, num tokens, dimension)\n\nout = merger(features) # (4, 8, 1024)\n```\n\n## Vision Transformer for Small Datasets\n\n\u003cimg src=\"./images/vit_for_small_datasets.png\" width=\"400px\"\u003e\u003c/img\u003e\n\nThis \u003ca href=\"https://arxiv.org/abs/2112.13492\"\u003epaper\u003c/a\u003e proposes a new image to patch function that incorporates shifts of the image, before normalizing 
and dividing the image into patches. I have found shifting to be extremely helpful in some other transformers work, so I decided to include this for further explorations. It also includes the `LSA` with the learned temperature and masking out of a token's attention to itself.\n\nYou can use it as follows:\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.vit_for_small_dataset import ViT\n\nv = ViT(\n    image_size = 256,\n    patch_size = 16,\n    num_classes = 1000,\n    dim = 1024,\n    depth = 6,\n    heads = 16,\n    mlp_dim = 2048,\n    dropout = 0.1,\n    emb_dropout = 0.1\n)\n\nimg = tf.random.normal([4, 256, 256, 3])\n\npreds = v(img) # (4, 1000)\n```\n\nYou can also use the `SPT` from this paper as a standalone module\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.vit_for_small_dataset import SPT\n\nspt = SPT(\n    dim = 1024,\n    patch_size = 16,\n    channels = 3\n)\n\nimg = tf.random.normal([4, 256, 256, 3])\n\ntokens = spt(img) # (4, 256, 1024)\n```\n\n## Parallel ViT\n\n\u003cimg src=\"./images/parallel-vit.png\" width=\"350px\"\u003e\u003c/img\u003e\n\nThis \u003ca href=\"https://arxiv.org/abs/2203.09795\"\u003epaper\u003c/a\u003e proposes parallelizing multiple attention and feedforward blocks per layer (2 blocks), claiming that it is easier to train without loss of performance.\n\nYou can try this variant as follows\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow.parallel_vit import ViT\n\nv = ViT(\n    image_size = 256,\n    patch_size = 16,\n    num_classes = 1000,\n    dim = 1024,\n    depth = 6,\n    heads = 8,\n    mlp_dim = 2048,\n    num_parallel_branches = 2,  # in paper, they claimed 2 was optimal\n    dropout = 0.1,\n    emb_dropout = 0.1\n)\n\nimg = tf.random.normal([4, 256, 256, 3])\n\npreds = v(img) # (4, 1000)\n```\n\n## FAQ\n\n- How do I pass in non-square images?\n\nYou can already pass in non-square images - you just have to make sure your height and width are less than or equal to the `image_size`, and 
both divisible by the `patch_size`\n\nex.\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow import ViT\n\nv = ViT(\n    image_size = 256,\n    patch_size = 32,\n    num_classes = 1000,\n    dim = 1024,\n    depth = 6,\n    heads = 16,\n    mlp_dim = 2048,\n    dropout = 0.1,\n    emb_dropout = 0.1\n)\n\nimg = tf.random.normal([1, 256, 128, 3]) # \u003c-- not a square\n\npreds = v(img) # (1, 1000)\n```\n\n- How do I pass in non-square patches?\n\n```python\nimport tensorflow as tf\nfrom vit_tensorflow import ViT\n\nv = ViT(\n    num_classes = 1000,\n    image_size = (256, 128),  # image size is a tuple of (height, width)\n    patch_size = (32, 16),    # patch size is a tuple of (height, width)\n    dim = 1024,\n    depth = 6,\n    heads = 16,\n    mlp_dim = 2048,\n    dropout = 0.1,\n    emb_dropout = 0.1\n)\n\nimg = tf.random.normal([1, 256, 128, 3])\n\npreds = v(img)\n```\n\n## Resources\n\nComing from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.\n\n1. \u003ca href=\"http://jalammar.github.io/illustrated-transformer/\"\u003eIllustrated Transformer\u003c/a\u003e - Jay Alammar\n\n2. \u003ca href=\"http://peterbloem.nl/blog/transformers\"\u003eTransformers from Scratch\u003c/a\u003e  - Peter Bloem\n\n3. \u003ca href=\"https://nlp.seas.harvard.edu/2018/04/03/attention.html\"\u003eThe Annotated Transformer\u003c/a\u003e - Harvard NLP\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Ftaki0112%2Fvit-tensorflow","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Ftaki0112%2Fvit-tensorflow","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Ftaki0112%2Fvit-tensorflow/lists"}