<div align="center">

  <a href="https://www.tensorflow.org">![TensorFlow](https://img.shields.io/badge/TensorFlow-2.X-orange?style=for-the-badge)</a>
  <a href="https://github.com/EMalagoli92/VAN-Classification-TensorFlow/blob/main/LICENSE">![License](https://img.shields.io/github/license/EMalagoli92/VAN-Classification-TensorFlow?style=for-the-badge)</a>
  <a href="https://www.python.org">![Python](https://img.shields.io/badge/python-%3E%3D%203.9-blue?style=for-the-badge)</a>

</div>

# VAN-Classification-TensorFlow
TensorFlow 2.X reimplementation of [Visual Attention Network](https://arxiv.org/abs/2202.09741v5), Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
- Exact TensorFlow reimplementation of the official PyTorch repo, including the `timm` modules used by the authors, preserving the model and layer structure.
- ImageNet pretrained weights ported from the official PyTorch implementation.

## Table of contents
- [Abstract](#abstract)
- [Results](#results)
- [Installation](#installation)
- [Usage](#usage)
- [Acknowledgement](#acknowledgement)
- [Citations](#citations)
- [License](#license)

<div id="abstract"/>

## Abstract
*While originally designed for natural language processing (NLP) tasks, the self-attention mechanism has recently taken various computer vision areas by storm. However, the 2D nature of images brings three challenges for applying self-attention in computer vision. (1) Treating images as 1D sequences neglects their 2D structures. (2) The quadratic complexity is too expensive for high-resolution images. (3) It only captures spatial adaptability but ignores channel adaptability. In this paper, the authors propose a novel large kernel attention (LKA) module to enable self-adaptive and long-range correlations in self-attention while avoiding the above issues. The authors further introduce a novel neural network based on LKA, namely Visual Attention Network (VAN). While extremely simple and efficient, VAN outperforms state-of-the-art vision transformers (ViTs) and convolutional neural networks (CNNs) by a large margin in extensive experiments, including image classification, object detection, semantic segmentation, instance segmentation, etc.*


![Comparison with different vision backbones](https://github.com/EMalagoli92/VAN-Classification-TensorFlow/blob/main/assets/images/Comparsion.png?raw=true)
<p align = "center"><sub>Figure 1. Comparison with different vision backbones on the ImageNet-1K validation set.</sub></p>


![Decomposition of a large-kernel convolution](https://github.com/EMalagoli92/VAN-Classification-TensorFlow/blob/main/assets/images/decomposition.png?raw=true)
<p align = "center"><sub>Figure 2. Decomposition diagram of a large-kernel convolution. A standard large-kernel convolution can be decomposed into three parts: a depth-wise convolution (DW-Conv), a depth-wise dilation convolution (DW-D-Conv), and a 1×1 convolution (1×1 Conv).</sub></p>


![Structure of the LKA, non-attention and self-attention modules and of a VAN stage](https://github.com/EMalagoli92/VAN-Classification-TensorFlow/blob/main/assets/images/LKA.png?raw=true)
<p align = "center"><sub>Figure 3. The structure of different modules: (a) the proposed Large Kernel Attention (LKA); (b) a non-attention module; (c) the self-attention module; (d) a stage of the Visual Attention Network (VAN). CFF denotes the convolutional feed-forward network. The difference between (a) and (b) is the element-wise multiplication. Note that (c) is designed for 1D sequences.</sub></p>


<div id="results"/>

## Results
The TensorFlow implementation and the ported ImageNet weights have been compared with the official PyTorch implementation on the [ImageNet-V2](https://www.tensorflow.org/datasets/catalog/imagenet_v2) test set.

### Models pre-trained on ImageNet-1K
| Configuration | Resolution | Top-1 (Original) | Top-1 (Ported) | Top-5 (Original) | Top-5 (Ported) | #Params |
| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
| VAN-B0 | 224x224 | 0.59 | 0.59 | 0.81 | 0.81 | 4.1M |
| VAN-B1 | 224x224 | 0.64 | 0.64 | 0.84 | 0.84 | 13.9M |
| VAN-B2 | 224x224 | 0.69 | 0.69 | 0.88 | 0.88 | 26.6M |
| VAN-B3 | 224x224 | 0.71 | 0.71 | 0.89 | 0.89 | 44.8M |

Metrics difference (original vs. ported): `0`.


<div id="installation"/>

## Installation
- Install from PyPI.
```
pip install van-classification-tensorflow
```
- Install from GitHub.
```
pip install git+https://github.com/EMalagoli92/VAN-Classification-TensorFlow
```
- Clone the repo and install the required packages.
```
git clone https://github.com/EMalagoli92/VAN-Classification-TensorFlow.git
cd VAN-Classification-TensorFlow
pip install -r requirements.txt
```
Tested on *Ubuntu 20.04.4 LTS x86_64*, *Python 3.9.7*.

<div id="usage"/>

## Usage
- Define a custom VAN configuration.
```python
from van_classification_tensorflow import VAN

# Define a custom VAN configuration
model = VAN(
    in_chans=3,
    num_classes=1000,
    embed_dims=[64, 128, 256, 512],
    mlp_ratios=[4, 4, 4, 4],
    drop_rate=0.0,
    drop_path_rate=0.0,
    depths=[3, 4, 6, 3],
    num_stages=4,
    include_top=True,
    classifier_activation="softmax",
    data_format="channels_last",
)
```
- Use a predefined VAN configuration.
```python
from van_classification_tensorflow import VAN

model = VAN(
    configuration="van_b0", data_format="channels_last", classifier_activation="softmax"
)

model.build((None, 224, 224, 3))
# summary() prints the table itself and returns None, so it is not wrapped in print()
model.summary()
```
```
Model: "van_b0"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 patch_embed1 (OverlapPatchE  ((None, 32, 56, 56),     4864
 mbed)                        (),
                              ())

 block1/0 (Block)            (None, 32, 56, 56)        25152

 block1/1 (Block)            (None, 32, 56, 56)        25152

 block1/2 (Block)            (None, 32, 56, 56)        25152

 norm1 (LayerNorm_)          (None, 3136, 32)          64

 patch_embed2 (OverlapPatchE  ((None, 64, 28, 28),     18752
 mbed)                        (),
                              ())

 block2/0 (Block)            (None, 64, 28, 28)        89216

 block2/1 (Block)            (None, 64, 28, 28)        89216

 block2/2 (Block)            (None, 64, 28, 28)        89216

 norm2 (LayerNorm_)          (None, 784, 64)           128

 patch_embed3 (OverlapPatchE  ((None, 160, 14, 14),    92960
 mbed)                        (),
                              ())

 block3/0 (Block)            (None, 160, 14, 14)       303040

 block3/1 (Block)            (None, 160, 14, 14)       303040

 block3/2 (Block)            (None, 160, 14, 14)       303040

 block3/3 (Block)            (None, 160, 14, 14)       303040

 block3/4 (Block)            (None, 160, 14, 14)       303040

 norm3 (LayerNorm_)          (None, 196, 160)          320

 patch_embed4 (OverlapPatchE  ((None, 256, 7, 7),      369920
 mbed)                        (),
                              ())

 block4/0 (Block)            (None, 256, 7, 7)         755200

 block4/1 (Block)            (None, 256, 7, 7)         755200

 norm4 (LayerNorm_)          (None, 49, 256)           512

 head (Linear_)              (None, 1000)              257000

 pred (Activation)           (None, 1000)              0

=================================================================
Total params: 4,113,224
Trainable params: 4,105,800
Non-trainable params: 7,424
_________________________________________________________________
```
- Train the model from scratch.
```python
# Example
# `x`: images of shape (N, 224, 224, 3); `y`: integer class labels of shape (N,)
model.compile(
    optimizer="sgd",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
)
model.fit(x, y)
```
- Use ported ImageNet pretrained weights.
```python
# Example
from van_classification_tensorflow import VAN

model = VAN(
    configuration="van_b1",
    pretrained=True,
    include_top=True,
    classifier_activation="softmax",
)
# `image`: a batch of images, e.g. a float32 tensor of shape (batch, 224, 224, 3)
y_pred = model(image)
```

- Use ported ImageNet pretrained weights for feature extraction (`include_top=False`).
```python
import tensorflow as tf

from van_classification_tensorflow import VAN

# Get Features
inputs = tf.keras.layers.Input(shape=(224, 224, 3), dtype="float32")
features = VAN(configuration="van_b0", pretrained=True, include_top=False)(inputs)

# Custom classification
num_classes = 10
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(features)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
```

<div id="acknowledgement"/>

## Acknowledgement
[VAN-Classification](https://github.com/Visual-Attention-Network/VAN-Classification) (Official PyTorch implementation).


<div id="citations"/>

## Citations
```bibtex
@article{guo2022visual,
  title={Visual Attention Network},
  author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
  journal={arXiv preprint arXiv:2202.09741},
  year={2022}
}
```


<div id="license"/>

## License
This work is made available under the [MIT License](https://github.com/EMalagoli92/VAN-Classification-TensorFlow/blob/main/LICENSE).
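
## Large Kernel Attention sketch

The Large Kernel Attention module of Figure 3(a) chains the three parts of Figure 2 (DW-Conv, DW-D-Conv, 1×1 Conv) to produce an attention map, which is then multiplied element-wise with the input. The minimal Keras sketch below illustrates that data flow. It is not the code shipped in this package: the class name `LKASketch` is hypothetical, it assumes a `channels_last` layout, and the kernel sizes (a 5×5 DW-Conv followed by a 7×7 DW-D-Conv with dilation 3, i.e. the decomposition of a 21×21 kernel) are assumptions taken from the paper rather than read from this repository.

```python
import tensorflow as tf


class LKASketch(tf.keras.layers.Layer):
    """Hypothetical Large Kernel Attention layer (channels_last), for illustration only."""

    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        # DW-Conv: 5x5 depth-wise convolution capturing local context (assumed size).
        self.dw_conv = tf.keras.layers.DepthwiseConv2D(kernel_size=5, padding="same")
        # DW-D-Conv: 7x7 depth-wise convolution with dilation 3 for long-range context
        # (assumed size); "same" padding keeps the spatial resolution unchanged.
        self.dw_d_conv = tf.keras.layers.DepthwiseConv2D(
            kernel_size=7, dilation_rate=3, padding="same"
        )
        # 1x1 Conv: mixes information across channels (channel adaptability).
        self.pw_conv = tf.keras.layers.Conv2D(dim, kernel_size=1)

    def call(self, x):
        # Attention map built from the three convolutions of Figure 2.
        attn = self.pw_conv(self.dw_d_conv(self.dw_conv(x)))
        # Element-wise multiplication: the map re-weights the input features.
        return x * attn


lka = LKASketch(dim=64)
x = tf.random.normal((2, 28, 28, 64))
y = lka(x)
print(tuple(y.shape))
print(lka.count_params())
```

The output keeps the input shape, which is what lets a module of this kind sit inside a residual block. The dilation only spreads the 7×7 taps apart, so it widens the neighborhood without adding weights; each convolution here also carries a bias of `dim` values.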
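
## Parameter cost of the large-kernel decomposition

The efficiency argument behind Figure 2 can be checked with plain arithmetic. The sketch below counts the weights (biases ignored) of a dense 21×21 convolution with C input and C output channels against the DW-Conv + DW-D-Conv + 1×1 Conv decomposition, for the stage widths of `van_b0` visible in the model summary (32, 64, 160, 256). The kernel sizes, 5×5 for DW-Conv and 7×7 with dilation 3 for DW-D-Conv, are assumptions taken from the paper. The helper names are hypothetical. It also computes the receptive field of the stacked stride-1 convolutions with the usual formula 1 + Σ(k − 1)·d.

```python
def full_conv_weights(kernel_size: int, channels: int) -> int:
    """Weights of a dense K x K convolution mapping C channels to C channels (no bias)."""
    return kernel_size * kernel_size * channels * channels


def lka_weights(channels: int, dw_kernel: int = 5, dilated_kernel: int = 7) -> int:
    """Weights of DW-Conv + DW-D-Conv + 1x1 Conv (no bias), kernel sizes assumed."""
    depthwise = dw_kernel * dw_kernel * channels          # one k x k filter per channel
    dilated = dilated_kernel * dilated_kernel * channels  # dilation adds no weights
    pointwise = channels * channels                        # full channel mixing
    return depthwise + dilated + pointwise


def receptive_field(kernels_and_dilations):
    """Receptive field of stacked stride-1 convolutions: 1 + sum((k - 1) * d)."""
    return 1 + sum((k - 1) * d for k, d in kernels_and_dilations)


for channels in (32, 64, 160, 256):
    dense = full_conv_weights(21, channels)
    decomposed = lka_weights(channels)
    print(channels, dense, decomposed, f"{decomposed / dense:.2%}")

# DW-Conv (5, dilation 1), DW-D-Conv (7, dilation 3), 1x1 Conv (1, dilation 1)
print(receptive_field([(5, 1), (7, 3), (1, 1)]))
```

Under these assumptions the decomposed module needs less than 1% of the weights of the dense 21×21 convolution at every stage width listed. Its receptive field of 23×23 still covers the full 21×21 neighborhood.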