Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/francescosaveriozuppichini/vit
Implementing Vi(sion)T(transformer)
- Host: GitHub
- URL: https://github.com/francescosaveriozuppichini/vit
- Owner: FrancescoSaverioZuppichini
- Created: 2021-01-01T15:14:05.000Z (almost 4 years ago)
- Default Branch: main
- Last Pushed: 2023-03-19T11:55:02.000Z (over 1 year ago)
- Last Synced: 2023-10-20T19:53:25.041Z (about 1 year ago)
- Topics: computer-vision, deep-learning
- Homepage:
- Size: 2.17 MB
- Stars: 274
- Watchers: 8
- Forks: 49
- Open Issues: 4
Metadata Files:
- Readme: README.ipynb
README
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Implementing Vi(sion) T(ransformer) in PyTorch\n",
"\n",
"Hi guys, happy new year! Today we are going to implement the famous **Vi**(sion) **T**(ransformer) proposed in [AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE](https://arxiv.org/pdf/2010.11929.pdf).\n",
"\n",
"\n",
"The code is here; an interactive version of this article can be downloaded from [here](https://github.com/FrancescoSaverioZuppichini/ViT).\n",
"\n",
"ViT will soon be available in my **new computer vision library called [glasses](https://github.com/FrancescoSaverioZuppichini/glasses)**\n",
"\n",
"This is a technical tutorial, not your normal medium post where you find out about the top 5 secret pandas functions to make you rich. \n",
"\n",
"So, before beginning, I highly recommend that you:\n",
"\n",
"- have a look at the amazing [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/) website\n",
"- watch [Yannic Kilcher's video about ViT](https://www.youtube.com/watch?v=TrdevFK_am4&t=1000s)\n",
"- read the [Einops](https://github.com/arogozhnikov/einops/) docs\n",
"\n",
"So, ViT uses a standard transformer (the one proposed in [Attention is All You Need](https://arxiv.org/abs/1706.03762)) applied to images. But how?\n",
"\n",
"The following picture shows ViT's architecture\n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/ViT/blob/main/images/ViT.png?raw=true)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The input image is split into patches of 16x16 pixels that are then flattened (the image is not to scale). The patches are embedded using a normal fully connected layer, a special `cls` token is prepended to them, and the `positional encoding` is added. The resulting tensor is passed first to a standard Transformer encoder and then to a classification head. That's it. \n",
"\n",
"The article is structured into the following sections:\n",
"\n",
"- Data\n",
"- Patches Embeddings\n",
"    - CLS Token\n",
"    - Position Embedding\n",
"- Transformer\n",
"    - Attention\n",
"    - Residuals\n",
"    - MLP\n",
"    - TransformerEncoder\n",
"- Head\n",
"- ViT\n",
"\n",
"We are going to implement the model block by block with a bottom-up approach. We can start by importing all the required packages"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from torch import nn\n",
"from torch import Tensor\n",
"from PIL import Image\n",
"from torchvision.transforms import Compose, Resize, ToTensor\n",
"from einops import rearrange, reduce, repeat\n",
"from einops.layers.torch import Rearrange, Reduce\n",
"from torchsummary import summary\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nothing fancy here, just PyTorch + stuff"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data\n",
"\n",
"First of all, we need a picture, a cute cat works just fine :)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"img = Image.open('./cat.jpg')\n",
"\n",
"fig = plt.figure()\n",
"plt.imshow(img)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![alt](https://github.com/FrancescoSaverioZuppichini/ViT/blob/main/images/output_5_1.png?raw=true)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we need to preprocess it"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 3, 224, 224])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# resize to imagenet size \n",
"transform = Compose([Resize((224, 224)), ToTensor()])\n",
"x = transform(img)\n",
"x = x.unsqueeze(0) # add batch dim\n",
"x.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Patches Embeddings\n",
"\n",
"The first step is to break the image down into multiple patches and flatten them.\n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/ViT/blob/main/images/Patches.png?raw=true)\n",
"\n",
"Quoting from the paper:\n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/ViT/blob/main/images/paper1.png?raw=true)\n",
"\n",
"\n",
"This can be easily done using einops. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"patch_size = 16 # 16 pixels\n",
"patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)"
]
},
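{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, here is a minimal sketch showing that the einops pattern above is equivalent to splitting and permuting the axes by hand with plain PyTorch ops. It assumes a random tensor (`dummy`, a hypothetical name) in place of the cat image, since only the shape matters here.\n",
"\n",
"```python\n",
"import torch\n",
"from einops import rearrange\n",
"\n",
"dummy = torch.randn(1, 3, 224, 224)  # random stand-in for the preprocessed image\n",
"patch_size = 16\n",
"\n",
"patches = rearrange(dummy, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)\n",
"print(patches.shape)  # torch.Size([1, 196, 768]): 14 * 14 patches of 16 * 16 * 3 values\n",
"\n",
"# the same reshuffle by hand: split H into (h, s1) and W into (w, s2), then reorder the axes\n",
"b, c, H, W = dummy.shape\n",
"h, w = H // patch_size, W // patch_size\n",
"manual = (dummy.reshape(b, c, h, patch_size, w, patch_size)\n",
"          .permute(0, 2, 4, 3, 5, 1)  # -> b h w s1 s2 c\n",
"          .reshape(b, h * w, patch_size * patch_size * c))\n",
"print(torch.equal(patches, manual))  # True\n",
"```\n",
"\n",
"Writing it out by hand makes the ordering visible: patches are enumerated row by row, and inside each flattened patch the values are ordered by pixel row, then pixel column, then channel."
]
},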
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we need to project them using a normal linear layer\n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/ViT/blob/main/images/PatchesProjected.png?raw=true)\n",
"\n",
"We can create a `PatchEmbedding` class to keep our code nice and clean"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 196, 768])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class PatchEmbedding(nn.Module):\n",
"    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):\n",
"        super().__init__()\n",
"        self.patch_size = patch_size\n",
"        self.projection = nn.Sequential(\n",
"            # break the image down into s1 x s2 patches and flatten them\n",
"            Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),\n",
"            nn.Linear(patch_size * patch_size * in_channels, emb_size)\n",
"        )\n",
"\n",
"    def forward(self, x: Tensor) -> Tensor:\n",
"        x = self.projection(x)\n",
"        return x\n",
"\n",
"PatchEmbedding()(x).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note** After checking out the original implementation, I found out that the authors use a Conv2d layer instead of a Linear one for performance gains. This is obtained by setting both `kernel_size` and `stride` to `patch_size`: the kernel never overlaps two patches, so the convolution is applied to each patch individually and is equivalent to a linear projection of it. So, we first apply the conv layer and then flatten the resulting feature map."
]
},
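{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why the two versions are interchangeable, here is a small sketch (random input, freshly initialized layers) that copies the weights of a `Conv2d` with `kernel_size = stride = patch_size` into an `nn.Linear` and checks that both produce the same embeddings. Note that, unlike the `(s1 s2 c)` order used earlier, this sketch flattens each patch in `(c s1 s2)` order to match the layout of the conv kernel; since the projection is learned, any fixed order works.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"from einops import rearrange\n",
"\n",
"torch.manual_seed(0)\n",
"in_channels, patch_size, emb_size = 3, 16, 768\n",
"dummy = torch.randn(2, in_channels, 224, 224)  # random stand-in for a batch of images\n",
"\n",
"conv = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)\n",
"linear = nn.Linear(in_channels * patch_size * patch_size, emb_size)\n",
"\n",
"# copy the conv kernel, shaped (emb, c, s1, s2), into the linear layer\n",
"with torch.no_grad():\n",
"    linear.weight.copy_(conv.weight.reshape(emb_size, -1))\n",
"    linear.bias.copy_(conv.bias)\n",
"\n",
"out_conv = rearrange(conv(dummy), 'b e h w -> b (h w) e')\n",
"# flatten each patch in (c s1 s2) order so it lines up with the copied kernel\n",
"patches = rearrange(dummy, 'b c (h s1) (w s2) -> b (h w) (c s1 s2)', s1=patch_size, s2=patch_size)\n",
"out_linear = linear(patches)\n",
"\n",
"print(out_conv.shape)  # torch.Size([2, 196, 768])\n",
"print(torch.allclose(out_conv, out_linear, atol=1e-4))  # True\n",
"```\n",
"\n",
"The two results only differ by floating-point rounding, which is why the comparison uses a small tolerance."
]
},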
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 196, 768])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class PatchEmbedding(nn.Module):\n",
"    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):\n",
"        super().__init__()\n",
"        self.patch_size = patch_size\n",
"        self.projection = nn.Sequential(\n",
"            # using a conv layer instead of a linear one -> performance gains\n",
"            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),\n",
"            Rearrange('b e (h) (w) -> b (h w) e'),\n",
"        )\n",
"\n",
"    def forward(self, x: Tensor) -> Tensor:\n",
"        x = self.projection(x)\n",
"        return x\n",
"\n",
"PatchEmbedding()(x).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CLS Token"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next step is to add the `cls token` and the position embedding. The `cls token` is just a learnable vector of size `emb_size` placed in front of **each** sequence (of projected patches)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 197, 768])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class PatchEmbedding(nn.Module):\n",
"    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):\n",
"        super().__init__()\n",
"        self.patch_size = patch_size\n",
"        self.projection = nn.Sequential(\n",
"            # using a conv layer instead of a linear one -> performance gains\n",
"            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),\n",
"            Rearrange('b e (h) (w) -> b (h w) e'),\n",
"        )\n",
"        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))\n",
"\n",
"    def forward(self, x: Tensor) -> Tensor:\n",
"        b, _, _, _ = x.shape\n",
"        x = self.projection(x)\n",
"        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)\n",
"        # prepend the cls token to the input\n",
"        x = torch.cat([cls_tokens, x], dim=1)\n",
"        return x\n",
"\n",
"PatchEmbedding()(x).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`cls_token` is a randomly initialized torch Parameter. In the `forward` method it is copied `b` (batch) times and prepended to the projected patches using `torch.cat`"
]
},
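{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sketch of what `repeat` and `torch.cat` do here, using a random tensor in place of the projected patches and a batch of 4 so the batch dimension is visible. The last check shows that, as far as I can tell, `expand` is a plain PyTorch alternative to `einops.repeat` for this case.\n",
"\n",
"```python\n",
"import torch\n",
"from einops import repeat\n",
"\n",
"emb_size, batch_size = 768, 4\n",
"cls_token = torch.nn.Parameter(torch.randn(1, 1, emb_size))\n",
"patches = torch.randn(batch_size, 196, emb_size)  # random stand-in for the projected patches\n",
"\n",
"cls_tokens = repeat(cls_token, '() n e -> b n e', b=batch_size)  # (4, 1, 768)\n",
"tokens = torch.cat([cls_tokens, patches], dim=1)                 # (4, 197, 768)\n",
"print(tokens.shape)  # torch.Size([4, 197, 768])\n",
"\n",
"# every sequence starts with the same learned vector, followed by its own patches\n",
"print(torch.equal(tokens[:, 0], cls_token[0].expand(batch_size, -1)))  # True\n",
"print(torch.equal(tokens[:, 1:], patches))  # True\n",
"# plain PyTorch alternative to einops.repeat\n",
"print(torch.equal(cls_tokens, cls_token.expand(batch_size, -1, -1)))  # True\n",
"```"
]
},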
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Position Embedding\n",
"\n",
"So far, the model has no idea about the original position of the patches. We need to pass this spatial information. This can be done in different ways; in ViT we let the model learn it. The position embedding is just a learnable tensor of shape `N_PATCHES + 1 (token), EMBED_SIZE` that is added to the projected patches.\n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/ViT/blob/main/images/PatchesPositionEmbedding.png?raw=true)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 197, 768])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class PatchEmbedding(nn.Module):\n",
"    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):\n",
"        super().__init__()\n",
"        self.patch_size = patch_size\n",
"        self.projection = nn.Sequential(\n",
"            # using a conv layer instead of a linear one -> performance gains\n",
"            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),\n",
"            Rearrange('b e (h) (w) -> b (h w) e'),\n",
"        )\n",
"        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))\n",
"        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))\n",
"\n",
"    def forward(self, x: Tensor) -> Tensor:\n",
"        b, _, _, _ = x.shape\n",
"        x = self.projection(x)\n",
"        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)\n",
"        # prepend the cls token to the input\n",
"        x = torch.cat([cls_tokens, x], dim=1)\n",
"        # add position embedding\n",
"        x += self.positions\n",
"        return x\n",
"\n",
"PatchEmbedding()(x).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We stored the position embedding in the `.positions` field and add it to the patches in the `.forward` method"
]
},
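{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `(N_PATCHES + 1, EMBED_SIZE)` parameter has no batch dimension, so broadcasting adds the same positional vectors to every image in the batch; this is what makes `x += self.positions` work. A minimal sketch with random tensors:\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"img_size, patch_size, emb_size, batch_size = 224, 16, 768, 2\n",
"n_tokens = (img_size // patch_size) ** 2 + 1  # 14 * 14 patches + 1 cls token = 197\n",
"positions = nn.Parameter(torch.randn(n_tokens, emb_size))\n",
"\n",
"tokens = torch.randn(batch_size, n_tokens, emb_size)  # random stand-in for [cls] + patches\n",
"out = tokens + positions  # (197, 768) is broadcast over the batch dimension\n",
"print(n_tokens, out.shape)  # 197 torch.Size([2, 197, 768])\n",
"\n",
"# each sample in the batch receives the same positional vectors\n",
"print(torch.allclose(out[1] - tokens[1], positions, atol=1e-5))  # True\n",
"```\n",
"\n",
"Because `positions` is sized from `img_size`, this layout assumes every input image has the size used at construction time."
]
},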
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Transformer\n",
"\n",
"Now we need to implement the Transformer. In ViT only the Encoder is used; its architecture is visualized in the following picture.\n",
"\n",
"\n",
"\n",
"Let's start with the Attention part"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Attention\n",
"\n",
"So, attention takes three inputs, the famous queries, keys, and values, computes the attention matrix using the queries and keys, and uses it to \"attend\" to the values. In this case, we are using multi-head attention, meaning that the computation is split across n heads with smaller input size.\n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/ViT/blob/main/images/TransformerBlockAttention.png?raw=true)\n",
"\n",
"We can use `nn.MultiheadAttention` from PyTorch or implement our own. For completeness, I will show what it looks like:"
]
},
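{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, each head computes the scaled dot-product attention of [Attention is All You Need](https://arxiv.org/abs/1706.03762). Here $d_k$ is the per-head size, `emb_size / num_heads` (768 / 8 = 96 with the defaults used in this notebook), so `self.scaling` in the code corresponds to $1/\\sqrt{d_k}$:\n",
"\n",
"$$\\mathrm{Attention}(Q, K, V) = \\mathrm{softmax}\\left(\\frac{Q K^\\top}{\\sqrt{d_k}}\\right) V$$"
]
},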
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 197, 768])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class MultiHeadAttention(nn.Module):\n",
"    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):\n",
"        super().__init__()\n",
"        self.emb_size = emb_size\n",
"        self.num_heads = num_heads\n",
"        self.keys = nn.Linear(emb_size, emb_size)\n",
"        self.queries = nn.Linear(emb_size, emb_size)\n",
"        self.values = nn.Linear(emb_size, emb_size)\n",
"        self.att_drop = nn.Dropout(dropout)\n",
"        self.projection = nn.Linear(emb_size, emb_size)\n",
"        # 1 / sqrt(head_dim), applied to the energy before the softmax\n",
"        self.scaling = (self.emb_size // num_heads) ** -0.5\n",
"\n",
"    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:\n",
"        # split keys, queries and values in num_heads\n",
"        queries = rearrange(self.queries(x), \"b n (h d) -> b h n d\", h=self.num_heads)\n",
"        keys = rearrange(self.keys(x), \"b n (h d) -> b h n d\", h=self.num_heads)\n",
"        values = rearrange(self.values(x), \"b n (h d) -> b h n d\", h=self.num_heads)\n",
"        # sum up over the last axis\n",
"        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len\n",
"        if mask is not None:\n",
"            # masked positions (mask == False) get the lowest float, i.e. ~zero attention\n",
"            fill_value = torch.finfo(energy.dtype).min\n",
"            energy = energy.masked_fill(~mask, fill_value)\n",
"\n",
"        att = F.softmax(energy * self.scaling, dim=-1)\n",
"        att = self.att_drop(att)\n",
"        # sum up over the third axis\n",
"        out = torch.einsum('bhal, bhlv -> bhav', att, values)\n",
"        out = rearrange(out, \"b h n d -> b n (h d)\")\n",
"        out = self.projection(out)\n",
"        return out\n",
"\n",
"patches_embedded = PatchEmbedding()(x)\n",
"MultiHeadAttention()(patches_embedded).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, step by step. We have 4 fully connected layers: one each for queries, keys, and values, and a final projection layer. Dropout is applied to the attention weights.\n",
"\n",
"Okay, the idea (really go and read [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)) is to use the product between the queries and the keys to know \"how much\" each element in the sequence matters with respect to the rest. Then, we use this information to weight the values.\n",
"\n",
"The `forward` method takes the output of the previous layer and projects it with the three linear layers to get the queries, keys, and values. Since we are implementing multi-head attention, we have to rearrange the result into multiple heads. \n",
"\n",
"This is done by using `rearrange` from einops. \n",
"\n",
"Since this is self-attention, *queries, keys and values* are all computed from the same input, so for simplicity I have only one input (`x`). \n",
"\n",
"```python\n",
"queries = rearrange(self.queries(x), \"b n (h d) -> b h n d\", h=self.num_heads)\n",
"keys = rearrange(self.keys(x), \"b n (h d) -> b h n d\", h=self.num_heads)\n",
"values = rearrange(self.values(x), \"b n (h d) -> b h n d\", h=self.num_heads)\n",
"```\n",
"\n",
"The resulting keys, queries, and values have a shape of `BATCH, HEADS, SEQUENCE_LEN, EMBEDDING_SIZE / HEADS`.\n",
"\n",
"To compute the attention matrix we first have to perform a matrix multiplication between queries and keys, a.k.a. sum up over the last axis. This can be easily done using `torch.einsum`\n",
"\n",
"```python\n",
"energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)\n",
"```\n",
"\n",
"The resulting tensor has the shape `BATCH, HEADS, QUERY_LEN, KEY_LEN`. The attention is then the softmax (over the last axis) of this tensor multiplied by the scaling factor `1 / sqrt(EMBEDDING_SIZE / HEADS)`. \n",
"\n",
"Lastly, we use the attention to scale the values\n",
"\n",
"```python\n",
"torch.einsum('bhal, bhlv -> bhav ', att, values)\n",
"```\n",
"\n",
"and we obtain a tensor of size `BATCH, HEADS, QUERY_LEN, EMBEDDING_SIZE / HEADS`. We concatenate the heads together, apply the final projection, and return the result"
]
},
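{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the einsum notation feels opaque, the same computation can be written with batched matrix multiplications. A minimal sketch with random queries, keys and values (the shapes assume 8 heads of size 96, as with the defaults above) that checks both forms agree and that each row of the attention matrix sums to one:\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"b, h, n, d = 1, 8, 197, 96  # batch, heads, tokens, head size (768 / 8)\n",
"queries, keys, values = (torch.randn(b, h, n, d) for _ in range(3))\n",
"scaling = d ** -0.5\n",
"\n",
"# einsum version, as in MultiHeadAttention.forward\n",
"energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)\n",
"att = F.softmax(energy * scaling, dim=-1)\n",
"out_einsum = torch.einsum('bhal, bhlv -> bhav', att, values)\n",
"\n",
"# the same computation with batched matrix multiplications\n",
"att_mm = F.softmax(queries @ keys.transpose(-2, -1) * scaling, dim=-1)\n",
"out_mm = att_mm @ values\n",
"\n",
"print(out_einsum.shape)  # torch.Size([1, 8, 197, 96])\n",
"print(torch.allclose(out_einsum, out_mm, atol=1e-5))  # True\n",
"print(torch.allclose(att.sum(dim=-1), torch.ones(b, h, n), atol=1e-5))  # True: rows sum to 1\n",
"```"
]
},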
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note** we can use a single matrix to compute in one shot `queries, keys and values`. "
]
},
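{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fused layer produces `3 * emb_size` features that `rearrange` splits into queries, keys and values. With the `(h d qkv)` pattern used in the next cell, the q/k/v index is the innermost (fastest-changing) factor, so the query features are every third output of the fused layer. A small sketch with random input and freshly initialized weights checks this by copying those rows into a standalone query projection (`q_proj` is a hypothetical name):\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"from einops import rearrange\n",
"\n",
"torch.manual_seed(0)\n",
"emb_size, num_heads = 768, 8\n",
"tokens = torch.randn(2, 197, emb_size)  # random stand-in for the embedded patches\n",
"\n",
"qkv = nn.Linear(emb_size, emb_size * 3)\n",
"fused = rearrange(qkv(tokens), 'b n (h d qkv) -> qkv b h n d', h=num_heads, qkv=3)\n",
"queries, keys, values = fused[0], fused[1], fused[2]\n",
"print(queries.shape)  # torch.Size([2, 8, 197, 96])\n",
"\n",
"# with '(h d qkv)', output feature j belongs to queries if j % 3 == 0,\n",
"# to keys if j % 3 == 1 and to values if j % 3 == 2\n",
"q_proj = nn.Linear(emb_size, emb_size)\n",
"with torch.no_grad():\n",
"    q_proj.weight.copy_(qkv.weight[0::3])\n",
"    q_proj.bias.copy_(qkv.bias[0::3])\n",
"q_separate = rearrange(q_proj(tokens), 'b n (h d) -> b h n d', h=num_heads)\n",
"print(torch.allclose(queries, q_separate, atol=1e-5))  # True\n",
"```\n",
"\n",
"When training from scratch the split order does not matter, since the weights are learned; it only matters when loading weights from another implementation that may use a different order."
]
},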
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 197, 768])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class MultiHeadAttention(nn.Module):\n",
"    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):\n",
"        super().__init__()\n",
"        self.emb_size = emb_size\n",
"        self.num_heads = num_heads\n",
"        # fuse the queries, keys and values in one matrix\n",
"        self.qkv = nn.Linear(emb_size, emb_size * 3)\n",
"        self.att_drop = nn.Dropout(dropout)\n",
"        self.projection = nn.Linear(emb_size, emb_size)\n",
"\n",
"    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:\n",
"        # split keys, queries and values in num_heads\n",
"        qkv = rearrange(self.qkv(x), \"b n (h d qkv) -> (qkv) b h n d\", h=self.num_heads, qkv=3)\n",
"        queries, keys, values = qkv[0], qkv[1], qkv[2]\n",
"        # sum up over the last axis\n",
"        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len\n",
"        if mask is not None:\n",
"            fill_value = torch.finfo(energy.dtype).min\n",
"            energy = energy.masked_fill(~mask, fill_value)\n",
"\n",
"        # scale by 1 / sqrt(head_dim) *before* the softmax, so each row still sums to 1\n",
"        scaling = (self.emb_size // self.num_heads) ** -0.5\n",
"        att = F.softmax(energy * scaling, dim=-1)\n",
"        att = self.att_drop(att)\n",
"        # sum up over the third axis\n",
"        out = torch.einsum('bhal, bhlv -> bhav', att, values)\n",
"        out = rearrange(out, \"b h n d -> b n (h d)\")\n",
"        out = self.projection(out)\n",
"        return out\n",
"\n",
"patches_embedded = PatchEmbedding()(x)\n",
"MultiHeadAttention()(patches_embedded).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Residuals\n",
"\n",
"The transformer block has residual connections\n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/ViT/blob/main/images/TransformerBlockAttentionRes.png?raw=true)\n",
"\n",
"We can create a nice wrapper to perform the residual addition, it will be handy later on"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"class ResidualAdd(nn.Module):\n",
"    def __init__(self, fn):\n",
"        super().__init__()\n",
"        self.fn = fn\n",
"\n",
"    def forward(self, x, **kwargs):\n",
"        res = x\n",
"        x = self.fn(x, **kwargs)\n",
"        x += res\n",
"        return x"
]
},
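{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check that the wrapper computes `fn(x) + x`, using a hypothetical `LayerNorm` + `Linear` block as `fn`; the wrapper is repeated so the snippet runs on its own. The last lines show a subtlety of the in-place `x += res`: if `fn` returns its input unchanged (for example `nn.Identity`), the addition also modifies the caller's tensor. Writing `x = x + res` would avoid this.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"class ResidualAdd(nn.Module):  # same wrapper as above\n",
"    def __init__(self, fn):\n",
"        super().__init__()\n",
"        self.fn = fn\n",
"\n",
"    def forward(self, x, **kwargs):\n",
"        res = x\n",
"        x = self.fn(x, **kwargs)\n",
"        x += res\n",
"        return x\n",
"\n",
"torch.manual_seed(0)\n",
"block = nn.Sequential(nn.LayerNorm(768), nn.Linear(768, 768))  # hypothetical fn\n",
"tokens = torch.randn(1, 197, 768)\n",
"print(torch.allclose(ResidualAdd(block)(tokens), block(tokens) + tokens))  # True\n",
"\n",
"# in-place pitfall: nn.Identity returns the very same tensor, so += also changes the input\n",
"ones = torch.ones(3)\n",
"out = ResidualAdd(nn.Identity())(ones)\n",
"print(out, ones)  # tensor([2., 2., 2.]) tensor([2., 2., 2.])\n",
"```"
]
},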
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### MLP\n",
"\n",
"The attention's output is passed to a feed-forward block made of two fully connected layers: the first expands the input by a factor of `expansion`, the second projects it back to `emb_size`.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class FeedForwardBlock(nn.Sequential):\n",
"    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):\n",
"        super().__init__(\n",
"            nn.Linear(emb_size, expansion * emb_size),\n",
"            nn.GELU(),\n",
"            nn.Dropout(drop_p),\n",
"            nn.Linear(expansion * emb_size, emb_size),\n",
"        )"
]
},
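{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter counts of the two linear layers follow directly from their shapes (`in * out` weights plus `out` biases). A short sketch, repeating the class so it runs on its own, lists them for `emb_size = 768`; the numbers match the `Linear` rows of the model summary at the end of the article.\n",
"\n",
"```python\n",
"from torch import nn\n",
"\n",
"class FeedForwardBlock(nn.Sequential):  # same block as above\n",
"    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):\n",
"        super().__init__(\n",
"            nn.Linear(emb_size, expansion * emb_size),\n",
"            nn.GELU(),\n",
"            nn.Dropout(drop_p),\n",
"            nn.Linear(expansion * emb_size, emb_size),\n",
"        )\n",
"\n",
"ff = FeedForwardBlock(768)\n",
"counts = [sum(p.numel() for p in layer.parameters()) for layer in ff]\n",
"print(counts)  # [2362368, 0, 0, 2360064]\n",
"# 768 * 3072 + 3072 = 2,362,368 and 3072 * 768 + 768 = 2,360,064\n",
"```"
]
},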
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just a quick side note. I don't know why, but I've never seen people subclassing `nn.Sequential` to avoid writing the `forward` method. Start doing it, this is how object-oriented programming works!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Finally**, we can create the Transformer Encoder Block\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ResidualAdd` allows us to define this block in an elegant way"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"class TransformerEncoderBlock(nn.Sequential):\n",
"    def __init__(self,\n",
"                 emb_size: int = 768,\n",
"                 drop_p: float = 0.,\n",
"                 forward_expansion: int = 4,\n",
"                 forward_drop_p: float = 0.,\n",
"                 **kwargs):\n",
"        super().__init__(\n",
"            ResidualAdd(nn.Sequential(\n",
"                nn.LayerNorm(emb_size),\n",
"                MultiHeadAttention(emb_size, **kwargs),\n",
"                nn.Dropout(drop_p)\n",
"            )),\n",
"            ResidualAdd(nn.Sequential(\n",
"                nn.LayerNorm(emb_size),\n",
"                FeedForwardBlock(\n",
"                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),\n",
"                nn.Dropout(drop_p)\n",
"            ))\n",
"        )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's test it"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 197, 768])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"patches_embedded = PatchEmbedding()(x)\n",
"TransformerEncoderBlock()(patches_embedded).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also use PyTorch's built-in multi-head attention, `nn.MultiheadAttention`, but it expects 3 inputs: queries, keys, and values. You can subclass it and pass the same input three times"
]
},
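{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of that idea, assuming a PyTorch version recent enough to support `batch_first=True` (1.9 or later, to my knowledge). `SelfAttention` is a hypothetical name: it forwards the same tensor as query, key and value and drops the attention weights.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"class SelfAttention(nn.MultiheadAttention):\n",
"    # hypothetical helper: self-attention with a single input, like MultiHeadAttention above\n",
"    def forward(self, x, **kwargs):\n",
"        out, _ = super().forward(x, x, x, need_weights=False, **kwargs)\n",
"        return out\n",
"\n",
"att = SelfAttention(embed_dim=768, num_heads=8, batch_first=True)\n",
"tokens = torch.randn(1, 197, 768)  # (batch, sequence, embedding) because batch_first=True\n",
"print(att(tokens).shape)  # torch.Size([1, 197, 768])\n",
"```\n",
"\n",
"It could stand in for `MultiHeadAttention` inside `TransformerEncoderBlock`, keeping in mind that its constructor arguments are named differently (`embed_dim`, `num_heads`, `dropout`)."
]
},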
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TransformerEncoder\n",
"\n",
"In ViT only the Encoder part of the original transformer is used. The encoder is simply a stack of `L` `TransformerEncoderBlock`s.\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"class TransformerEncoder(nn.Sequential):\n",
"    def __init__(self, depth: int = 12, **kwargs):\n",
"        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Easy peasy!"
]
},
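{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since every block is identical, the encoder's size is easy to estimate by hand. A small arithmetic sketch for `emb_size = 768` and `forward_expansion = 4`, counting the weights and biases of each layer in one `TransformerEncoderBlock` (with the fused `qkv` layer; three separate projections give the same total). It should agree with the per-block rows of the summary at the end.\n",
"\n",
"```python\n",
"emb, hidden, depth = 768, 4 * 768, 12\n",
"\n",
"layer_norm = 2 * emb                    # weight + bias\n",
"qkv = emb * 3 * emb + 3 * emb           # fused queries/keys/values projection\n",
"att_projection = emb * emb + emb        # output projection of the attention\n",
"mlp = (emb * hidden + hidden) + (hidden * emb + emb)\n",
"\n",
"block = 2 * layer_norm + qkv + att_projection + mlp\n",
"print(block, depth * block)  # 7087872 85054464\n",
"```"
]
},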
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Head\n",
"\n",
"The last layer is a normal fully connected layer that outputs the class scores. Before it, we take a simple mean over the whole sequence.\n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/ViT/blob/main/images/ClassificationHead.png?raw=true)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"class ClassificationHead(nn.Sequential):\n",
"    def __init__(self, emb_size: int = 768, n_classes: int = 1000):\n",
"        super().__init__(\n",
"            Reduce('b n e -> b e', reduction='mean'),\n",
"            nn.LayerNorm(emb_size),\n",
"            nn.Linear(emb_size, n_classes))"
]
},
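{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the paper, the image representation fed to the classifier is the encoder output at the `cls` token position, while this head averages over all tokens. A minimal sketch comparing the two pooling choices on a random stand-in for the encoder output; `ClsPool` is a hypothetical helper, not part of the code above.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"from einops.layers.torch import Reduce\n",
"\n",
"emb_size, n_classes = 768, 1000\n",
"encoded = torch.randn(2, 197, emb_size)  # random stand-in for the encoder output\n",
"\n",
"# pooling used by ClassificationHead above: mean over all 197 tokens\n",
"mean_head = nn.Sequential(Reduce('b n e -> b e', reduction='mean'),\n",
"                          nn.LayerNorm(emb_size),\n",
"                          nn.Linear(emb_size, n_classes))\n",
"\n",
"class ClsPool(nn.Module):\n",
"    # hypothetical helper: keep only the output at the cls token position (index 0)\n",
"    def forward(self, x):\n",
"        return x[:, 0]\n",
"\n",
"cls_head = nn.Sequential(ClsPool(), nn.LayerNorm(emb_size), nn.Linear(emb_size, n_classes))\n",
"\n",
"print(mean_head(encoded).shape, cls_head(encoded).shape)\n",
"# torch.Size([2, 1000]) torch.Size([2, 1000])\n",
"```\n",
"\n",
"Both heads have the same parameters and output shape; they only differ in which tokens summarize the image."
]
},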
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Vi(sion) T(ransformer)\n",
"\n",
"We can compose `PatchEmbedding`, `TransformerEncoder` and `ClassificationHead` to create the final ViT architecture."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"class ViT(nn.Sequential):\n",
"    def __init__(self,\n",
"                 in_channels: int = 3,\n",
"                 patch_size: int = 16,\n",
"                 emb_size: int = 768,\n",
"                 img_size: int = 224,\n",
"                 depth: int = 12,\n",
"                 n_classes: int = 1000,\n",
"                 **kwargs):\n",
"        super().__init__(\n",
"            PatchEmbedding(in_channels, patch_size, emb_size, img_size),\n",
"            TransformerEncoder(depth, emb_size=emb_size, **kwargs),\n",
"            ClassificationHead(emb_size, n_classes)\n",
"        )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can use `torchsummary` to check the number of parameters"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 768, 14, 14] 590,592\n",
" Rearrange-2 [-1, 196, 768] 0\n",
" PatchEmbedding-3 [-1, 197, 768] 0\n",
" LayerNorm-4 [-1, 197, 768] 1,536\n",
" Linear-5 [-1, 197, 2304] 1,771,776\n",
" Dropout-6 [-1, 8, 197, 197] 0\n",
" Linear-7 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-8 [-1, 197, 768] 0\n",
" Dropout-9 [-1, 197, 768] 0\n",
" ResidualAdd-10 [-1, 197, 768] 0\n",
" LayerNorm-11 [-1, 197, 768] 1,536\n",
" Linear-12 [-1, 197, 3072] 2,362,368\n",
" GELU-13 [-1, 197, 3072] 0\n",
" Dropout-14 [-1, 197, 3072] 0\n",
" Linear-15 [-1, 197, 768] 2,360,064\n",
" Dropout-16 [-1, 197, 768] 0\n",
" ResidualAdd-17 [-1, 197, 768] 0\n",
" LayerNorm-18 [-1, 197, 768] 1,536\n",
" Linear-19 [-1, 197, 2304] 1,771,776\n",
" Dropout-20 [-1, 8, 197, 197] 0\n",
" Linear-21 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-22 [-1, 197, 768] 0\n",
" Dropout-23 [-1, 197, 768] 0\n",
" ResidualAdd-24 [-1, 197, 768] 0\n",
" LayerNorm-25 [-1, 197, 768] 1,536\n",
" Linear-26 [-1, 197, 3072] 2,362,368\n",
" GELU-27 [-1, 197, 3072] 0\n",
" Dropout-28 [-1, 197, 3072] 0\n",
" Linear-29 [-1, 197, 768] 2,360,064\n",
" Dropout-30 [-1, 197, 768] 0\n",
" ResidualAdd-31 [-1, 197, 768] 0\n",
" LayerNorm-32 [-1, 197, 768] 1,536\n",
" Linear-33 [-1, 197, 2304] 1,771,776\n",
" Dropout-34 [-1, 8, 197, 197] 0\n",
" Linear-35 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-36 [-1, 197, 768] 0\n",
" Dropout-37 [-1, 197, 768] 0\n",
" ResidualAdd-38 [-1, 197, 768] 0\n",
" LayerNorm-39 [-1, 197, 768] 1,536\n",
" Linear-40 [-1, 197, 3072] 2,362,368\n",
" GELU-41 [-1, 197, 3072] 0\n",
" Dropout-42 [-1, 197, 3072] 0\n",
" Linear-43 [-1, 197, 768] 2,360,064\n",
" Dropout-44 [-1, 197, 768] 0\n",
" ResidualAdd-45 [-1, 197, 768] 0\n",
" LayerNorm-46 [-1, 197, 768] 1,536\n",
" Linear-47 [-1, 197, 2304] 1,771,776\n",
" Dropout-48 [-1, 8, 197, 197] 0\n",
" Linear-49 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-50 [-1, 197, 768] 0\n",
" Dropout-51 [-1, 197, 768] 0\n",
" ResidualAdd-52 [-1, 197, 768] 0\n",
" LayerNorm-53 [-1, 197, 768] 1,536\n",
" Linear-54 [-1, 197, 3072] 2,362,368\n",
" GELU-55 [-1, 197, 3072] 0\n",
" Dropout-56 [-1, 197, 3072] 0\n",
" Linear-57 [-1, 197, 768] 2,360,064\n",
" Dropout-58 [-1, 197, 768] 0\n",
" ResidualAdd-59 [-1, 197, 768] 0\n",
" LayerNorm-60 [-1, 197, 768] 1,536\n",
" Linear-61 [-1, 197, 2304] 1,771,776\n",
" Dropout-62 [-1, 8, 197, 197] 0\n",
" Linear-63 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-64 [-1, 197, 768] 0\n",
" Dropout-65 [-1, 197, 768] 0\n",
" ResidualAdd-66 [-1, 197, 768] 0\n",
" LayerNorm-67 [-1, 197, 768] 1,536\n",
" Linear-68 [-1, 197, 3072] 2,362,368\n",
" GELU-69 [-1, 197, 3072] 0\n",
" Dropout-70 [-1, 197, 3072] 0\n",
" Linear-71 [-1, 197, 768] 2,360,064\n",
" Dropout-72 [-1, 197, 768] 0\n",
" ResidualAdd-73 [-1, 197, 768] 0\n",
" LayerNorm-74 [-1, 197, 768] 1,536\n",
" Linear-75 [-1, 197, 2304] 1,771,776\n",
" Dropout-76 [-1, 8, 197, 197] 0\n",
" Linear-77 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-78 [-1, 197, 768] 0\n",
" Dropout-79 [-1, 197, 768] 0\n",
" ResidualAdd-80 [-1, 197, 768] 0\n",
" LayerNorm-81 [-1, 197, 768] 1,536\n",
" Linear-82 [-1, 197, 3072] 2,362,368\n",
" GELU-83 [-1, 197, 3072] 0\n",
" Dropout-84 [-1, 197, 3072] 0\n",
" Linear-85 [-1, 197, 768] 2,360,064\n",
" Dropout-86 [-1, 197, 768] 0\n",
" ResidualAdd-87 [-1, 197, 768] 0\n",
" LayerNorm-88 [-1, 197, 768] 1,536\n",
" Linear-89 [-1, 197, 2304] 1,771,776\n",
" Dropout-90 [-1, 8, 197, 197] 0\n",
" Linear-91 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-92 [-1, 197, 768] 0\n",
" Dropout-93 [-1, 197, 768] 0\n",
" ResidualAdd-94 [-1, 197, 768] 0\n",
" LayerNorm-95 [-1, 197, 768] 1,536\n",
" Linear-96 [-1, 197, 3072] 2,362,368\n",
" GELU-97 [-1, 197, 3072] 0\n",
" Dropout-98 [-1, 197, 3072] 0\n",
" Linear-99 [-1, 197, 768] 2,360,064\n",
" Dropout-100 [-1, 197, 768] 0\n",
" ResidualAdd-101 [-1, 197, 768] 0\n",
" LayerNorm-102 [-1, 197, 768] 1,536\n",
" Linear-103 [-1, 197, 2304] 1,771,776\n",
" Dropout-104 [-1, 8, 197, 197] 0\n",
" Linear-105 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-106 [-1, 197, 768] 0\n",
" Dropout-107 [-1, 197, 768] 0\n",
" ResidualAdd-108 [-1, 197, 768] 0\n",
" LayerNorm-109 [-1, 197, 768] 1,536\n",
" Linear-110 [-1, 197, 3072] 2,362,368\n",
" GELU-111 [-1, 197, 3072] 0\n",
" Dropout-112 [-1, 197, 3072] 0\n",
" Linear-113 [-1, 197, 768] 2,360,064\n",
" Dropout-114 [-1, 197, 768] 0\n",
" ResidualAdd-115 [-1, 197, 768] 0\n",
" LayerNorm-116 [-1, 197, 768] 1,536\n",
" Linear-117 [-1, 197, 2304] 1,771,776\n",
" Dropout-118 [-1, 8, 197, 197] 0\n",
" Linear-119 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-120 [-1, 197, 768] 0\n",
" Dropout-121 [-1, 197, 768] 0\n",
" ResidualAdd-122 [-1, 197, 768] 0\n",
" LayerNorm-123 [-1, 197, 768] 1,536\n",
" Linear-124 [-1, 197, 3072] 2,362,368\n",
" GELU-125 [-1, 197, 3072] 0\n",
" Dropout-126 [-1, 197, 3072] 0\n",
" Linear-127 [-1, 197, 768] 2,360,064\n",
" Dropout-128 [-1, 197, 768] 0\n",
" ResidualAdd-129 [-1, 197, 768] 0\n",
" LayerNorm-130 [-1, 197, 768] 1,536\n",
" Linear-131 [-1, 197, 2304] 1,771,776\n",
" Dropout-132 [-1, 8, 197, 197] 0\n",
" Linear-133 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-134 [-1, 197, 768] 0\n",
" Dropout-135 [-1, 197, 768] 0\n",
" ResidualAdd-136 [-1, 197, 768] 0\n",
" LayerNorm-137 [-1, 197, 768] 1,536\n",
" Linear-138 [-1, 197, 3072] 2,362,368\n",
" GELU-139 [-1, 197, 3072] 0\n",
" Dropout-140 [-1, 197, 3072] 0\n",
" Linear-141 [-1, 197, 768] 2,360,064\n",
" Dropout-142 [-1, 197, 768] 0\n",
" ResidualAdd-143 [-1, 197, 768] 0\n",
" LayerNorm-144 [-1, 197, 768] 1,536\n",
" Linear-145 [-1, 197, 2304] 1,771,776\n",
" Dropout-146 [-1, 8, 197, 197] 0\n",
" Linear-147 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-148 [-1, 197, 768] 0\n",
" Dropout-149 [-1, 197, 768] 0\n",
" ResidualAdd-150 [-1, 197, 768] 0\n",
" LayerNorm-151 [-1, 197, 768] 1,536\n",
" Linear-152 [-1, 197, 3072] 2,362,368\n",
" GELU-153 [-1, 197, 3072] 0\n",
" Dropout-154 [-1, 197, 3072] 0\n",
" Linear-155 [-1, 197, 768] 2,360,064\n",
" Dropout-156 [-1, 197, 768] 0\n",
" ResidualAdd-157 [-1, 197, 768] 0\n",
" LayerNorm-158 [-1, 197, 768] 1,536\n",
" Linear-159 [-1, 197, 2304] 1,771,776\n",
" Dropout-160 [-1, 8, 197, 197] 0\n",
" Linear-161 [-1, 197, 768] 590,592\n",
"MultiHeadAttention-162 [-1, 197, 768] 0\n",
" Dropout-163 [-1, 197, 768] 0\n",
" ResidualAdd-164 [-1, 197, 768] 0\n",
" LayerNorm-165 [-1, 197, 768] 1,536\n",
" Linear-166 [-1, 197, 3072] 2,362,368\n",
" GELU-167 [-1, 197, 3072] 0\n",
" Dropout-168 [-1, 197, 3072] 0\n",
" Linear-169 [-1, 197, 768] 2,360,064\n",
" Dropout-170 [-1, 197, 768] 0\n",
" ResidualAdd-171 [-1, 197, 768] 0\n",
" Reduce-172 [-1, 768] 0\n",
" LayerNorm-173 [-1, 768] 1,536\n",
" Linear-174 [-1, 1000] 769,000\n",
"================================================================\n",
"Total params: 86,415,592\n",
"Trainable params: 86,415,592\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.57\n",
"Forward/backward pass size (MB): 364.33\n",
"Params size (MB): 329.65\n",
"Estimated Total Size (MB): 694.56\n",
"----------------------------------------------------------------\n",
"\n"
]
},
{
"data": {
"text/plain": [
"(tensor(86415592), tensor(86415592), tensor(329.6493), tensor(694.5562))"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summary(ViT(), (3, 224, 224), device='cpu')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"et voilà"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"================================================================\n",
"Total params: 86,415,592\n",
"Trainable params: 86,415,592\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.57\n",
"Forward/backward pass size (MB): 364.33\n",
"Params size (MB): 329.65\n",
"Estimated Total Size (MB): 694.56\n",
"---------------------------------------------------------------\n",
"```\n",
"\n",
"I checked the parameters with other implementations and they are the same!"
]
},
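{
"cell_type": "markdown",
"metadata": {},
"source": [
"One caveat on the total above, to the best of my understanding of how torchsummary works: it only counts the `weight` and `bias` tensors of each layer, so the two parameters created directly in `PatchEmbedding` (`cls_token` and `positions`) are left out, which is consistent with `PatchEmbedding-3` showing 0 params. The following arithmetic sketch rebuilds both totals from the per-layer numbers in the summary:\n",
"\n",
"```python\n",
"patch_projection = 3 * 16 * 16 * 768 + 768    # Conv2d-1: 590,592\n",
"block = 7_087_872                             # one TransformerEncoderBlock (sum of its summary rows)\n",
"head = 2 * 768 + 768 * 1000 + 1000            # LayerNorm-173 + Linear-174\n",
"cls_token = 1 * 1 * 768                       # PatchEmbedding.cls_token\n",
"positions = ((224 // 16) ** 2 + 1) * 768      # PatchEmbedding.positions: 197 x 768\n",
"\n",
"summary_total = patch_projection + 12 * block + head\n",
"full_total = summary_total + cls_token + positions\n",
"print(summary_total, full_total)  # 86415592 86567656\n",
"```\n",
"\n",
"With the `ViT` class defined above, `sum(p.numel() for p in ViT().parameters())` counts every parameter, including these two, and should therefore report the larger number."
]
},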
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Conclusions\n",
"\n",
"In this article, we have seen how to implement ViT in a nice, scalable, and customizable way. I hope it was useful.\n",
"\n",
"By the way, I am working on a **new computer vision library called [glasses](https://github.com/FrancescoSaverioZuppichini/glasses), check it out if you like**\n",
"\n",
"Take care :)\n",
"\n",
"Francesco"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}