{"id":17349562,"url":"https://github.com/jiangtaoxie/SoT","last_synced_at":"2025-02-26T02:31:54.986Z","repository":{"id":44030804,"uuid":"359854953","full_name":"jiangtaoxie/SoT","owner":"jiangtaoxie","description":"SoT: Delving Deeper into Classification Head for Transformer","archived":false,"fork":false,"pushed_at":"2021-12-24T08:42:16.000Z","size":2357,"stargazers_count":48,"open_issues_count":2,"forks_count":6,"subscribers_count":4,"default_branch":"main","last_synced_at":"2024-12-13T19:11:44.068Z","etag":null,"topics":["deep-learning","pytorch"],"latest_commit_sha":null,"homepage":"https://peihuali.org/SoT/","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/jiangtaoxie.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2021-04-20T14:58:30.000Z","updated_at":"2024-11-18T15:45:55.000Z","dependencies_parsed_at":"2022-09-13T14:40:44.893Z","dependency_job_id":null,"html_url":"https://github.com/jiangtaoxie/SoT","commit_stats":null,"previous_names":["jiangtaoxie/so-vit"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jiangtaoxie%2FSoT","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jiangtaoxie%2FSoT/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jiangtaoxie%2FSoT/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jiangtaoxie%2FSoT/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/jiangtaoxie","download_url":"https://codeload.github.com/jiangtaoxie/SoT/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"http
s://github.com","kind":"github","repositories_count":240780758,"owners_count":19856418,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","pytorch"],"created_at":"2024-10-15T16:56:21.429Z","updated_at":"2025-02-26T02:31:54.976Z","avatar_url":"https://github.com/jiangtaoxie.png","language":"Python","funding_links":[],"categories":["Table of Contents"],"sub_categories":["微软Transformer霸榜模型"],"readme":"# SoT: Delving Deeper into Classification Head for Transformer\n\n\n\u003cdiv\u003e\n\u0026emsp;\u0026emsp;\u0026emsp;\u0026emsp;\u0026emsp;\u0026emsp;\u003cimg src=\"images/overview.jpg\" width=\"100%\"/\u003e\n\u003c/div\u003e\n\n## Contents\n\n1. [Introduction](#introduction)\n2. [Installation](#installation)\n3. [Usage](#usage)\n4. [Classification results on CV tasks](#classification-results-on-cv-tasks)\n5. [Classification results on NLP tasks](#classification-results-on-nlp-tasks)\n6. [Visualization](#visualization)\n7. [Change log](#change-log)\n8. [Acknowledgments](#acknowledgments)\n9. [Contact](#contact)\n\n\n\n## Introduction\n\nThis repository is the official implementation of \"[SoT: Delving Deeper into Classification Head for Transformer](https://arxiv.org/pdf/2104.10935.pdf)\". It contains the **PyTorch** source code and the models for image classification and text classification tasks.\n\n### Citation\n\nPlease consider citing the paper if you find it useful. 
\n\n    @article{SoT,\n        author = {Jiangtao Xie and Ruiren Zeng and Qilong Wang and Ziqi Zhou and Peihua Li},\n        title = {SoT: Delving Deeper into Classification Head for Transformer},\n        journal = {arXiv:2104.10935v2},\n        year = {2021}\n    }\n\n### Motivation and Contributions\n\nFor classification tasks in both CV and NLP, current works based on pure transformer architectures pay little attention to the classification head: they feed only the **Classification token** (ClassT) to the classifier, neglecting the **Word tokens** (WordT), which contain rich information. Our experiments show that ClassT and WordT are highly complementary and that fusing all tokens further boosts performance. We therefore propose a novel classification paradigm that jointly utilizes ClassT and WordT, in which multi-headed global cross-covariance pooling with singular value power normalization is proposed to effectively harness the rich information in WordT. We evaluate the proposed classification scheme on both CV and NLP tasks, achieving very competitive performance compared with its counterparts.\n\n## Installation\n\n- Clone the repository\n```sh\ngit clone https://github.com/jiangtaoxie/SoT.git\ncd SoT/\n```\n- Install dependencies\n```sh\npip install -r requirments.txt\n```\nMain libraries: torch (\u003e=1.7.0), timm (==0.3.4), apex (optional)\n- Install the package\n```sh\npython setup.py install\n```\n\n## Usage\n\n### Prepare dataset\n\nPlease organize the dataset with the following file structure:\n```sh\n.\n├── train\n│   ├── class1\n│   │   ├── class1_001.jpg\n│   │   ├── class1_002.jpg\n│   │   └── ...\n│   ├── class2\n│   ├── class3\n│   ├── ...\n│   ├── ...\n│   └── classN\n└── val\n    ├── class1\n    │   ├── class1_001.jpg\n    │   ├── class1_002.jpg\n    │   └── ...\n    ├── class2\n    ├── class3\n    ├── ...\n    ├── ...\n    └── classN\n```\n\n\n### Using our proposed SoT model\n\n- Training from scratch:\n\nYou can train models of the SoT family by using 
the command:\n\n```sh\nsh ./distributed_train.sh $NODE_NUM $DATA_ROOT --model $MODEL_NAME -b $BATCH_SIZE --lr $INIT_LR \\\n--weight-decay $WEIGHT_DECAY \\\n--img-size $RESOLUTION \\\n--amp\n```\nBasic hyper-parameters of SoT:\n\n| Hyper-parameter | SoT-Tiny | SoT-Small | SoT-Base |\n|:--:|:------:|:------:|:------:|\n| Batch size | 1024 | 1024 | 512 |\n| Init. LR | 1e-3 | 1e-3 | 5e-4 |\n| Weight Decay | 3e-2 | 3e-2 | 6.5e-2 |\n\nWe also provide shell scripts in `./scripts` for convenient reproduction:\n```sh\nsh ./scripts/train_SoT_Tiny.sh # reproduce SoT-Tiny\nsh ./scripts/train_SoT_Small.sh # reproduce SoT-Small\nsh ./scripts/train_SoT_Base.sh # reproduce SoT-Base\n```\n\n- Evaluation\n\nOn the validation set of ImageNet-1K:\n\n```sh\npython main.py $DATA_ROOT $MODEL_NAME --b 256 --eval_checkpoint $CHECKPOINT_PATH\n```\n\nOn ImageNet-A:\n\n```sh\npython main.py $DATA_ROOT $MODEL_NAME --b 256 --eval_checkpoint $CHECKPOINT_PATH --IN_A\n```\n\n`$MODEL_NAME` can be `SoT_Tiny`, `SoT_Small`, or `SoT_Base`.\n\n### Using our proposed classification head in your architecture\n\n- Import the sot_src package\n```python\nfrom sot_src.model import Classifier, OnlyVisualTokensClassifier\n```\n- Define the classification head\n```python\nclassification_head_config = dict(\n    type='MGCrP',  # multi-headed global cross-covariance pooling\n    fusion_type='sum_fc',\n    args=dict(\n        dim=256,  # must equal the backbone's embedding dimension\n        num_heads=6,\n        wr_dim=14,\n        normalization=dict(\n            type='svPN',  # singular value power normalization\n            alpha=0.5,\n            iterNum=1,\n            svNum=1,\n            regular=None,  # or nn.Dropout(0.5)\n            input_dim=14,\n        ),\n    ),\n)\n\nclassifier = Classifier(classification_head_config)\n```\nNotes:\n- If your backbone has no classification token, use `OnlyVisualTokensClassifier` instead of `Classifier`.\n- Key arguments:\n    - dim: equal to the embedding dimension\n    - wr_dim: dimension of W and R; adjust it to control the final representation dimension\n    - 
regular: optional dropout regularization to alleviate overfitting\n\nWe also provide reference implementations based on DeiT and Swin Transformer for CV tasks and on BERT for NLP tasks.\n\n### Using the proposed visual tokens in your architecture\n\nYou can also use the proposed token embedding module, `TokenEmbed`, which is implemented with DenseNet blocks:\n\n```python\nfrom sot_src import TokenEmbed\n\npatch_embed_config = dict(\n    type='DenseNet',\n    embedding_dim=64,\n    large_output=False,  # for 224x224 input images: True gives a 56x56 output, False a 14x14 output\n)\n\npatch_embed = TokenEmbed(patch_embed_config)\n```\n\n\n\n## Classification results on CV tasks\n\nTop-1 accuracy (%, single 224x224 crop) on the validation set of ImageNet-1K and on ImageNet-A.\n\n### Our SoT family\n\n| Backbone | ImageNet Top-1 Acc. | ImageNet-A Top-1 Acc. | #Params (M) | GFLOPs | Weight |\n|:--:|:-------:|:----------:|:------:|:------:|:------:|\n| SoT-Tiny | 80.3 | 21.5 | 7.7 | 2.5 | [Coming soon]() |\n| SoT-Small | 82.7 | 31.8 | 26.9 | 5.8 | [Coming soon]() |\n| SoT-Base | 83.5 | 34.6 | 76.8 | 14.5 | [Coming soon]() |\n\n### DeiT family\n\n| Backbone | ImageNet Top-1 Acc. | ImageNet-A Top-1 Acc. | #Params (M) | GFLOPs | Weight |\n|:--:|:-------:|:----------:|:------:|:------:|:------:|\n| DeiT-T | 72.2 | 7.3 | 5.7 | 1.3 | [model](https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth) |\n| DeiT-T + ours | 78.6 | 17.5 | 7.0 | 2.3 | [Coming soon]() |\n| DeiT-S | 79.8 | 18.9 | 22.1 | 4.6 | [model](https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth) |\n| DeiT-S + ours | 82.7 | 31.8 | 26.9 | 5.8 | [Coming soon]() |\n| DeiT-B | 81.8 | 27.4 | 86.6 | 17.6 | [model](https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth) |\n| DeiT-B + ours | 82.9 | 29.1 | 94.9 | 18.2 | [Coming soon]() |\n\n### Swin Transformer family\n\n| Backbone | ImageNet Top-1 Acc. | ImageNet-A Top-1 Acc. 
| #Params (M) | GFLOPs | Weight |\n|:--:|:-------:|:----------:|:------:|:------:|:------:|\n| Swin-T | 81.3 | 21.6 | 28.3 | 4.5 | [model](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth) |\n| Swin-T + ours | 83.0 | 33.5 | 31.6 | 6.0 | [Coming soon]() |\n| Swin-B | 83.5 | 35.8 | 87.8 | 15.4 | [model](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth) |\n| Swin-B + ours | 84.0 | 42.9 | 95.9 | 16.9 | [Coming soon]() |\n\nNotes:\n- `+ ours` means that the proposed classification head and token embedding module are adopted on top of the other architectures.\n- All reported results are obtained by training from scratch on ImageNet-1K.\n\n## Classification results on NLP tasks\n\nAccuracy (Top-1, %) on four tasks selected from the General Language Understanding Evaluation ([GLUE](https://gluebenchmark.com/)) benchmark:\n\n- CoLA (The Corpus of Linguistic Acceptability): judge whether an English sentence is grammatical.\n- RTE (The Recognizing Textual Entailment datasets): determine whether the first sentence of a given pair entails the second.\n- MNLI (The Multi-Genre Natural Language Inference Corpus): classify a sentence pair drawn from multiple genres as entailment, contradiction, or neutral.\n- QNLI (Question-answering Natural Language Inference Corpus): decide whether a question-sentence pair is an entailment, i.e., whether the sentence contains the answer to the question.\n\n| Backbone | CoLA | RTE | MNLI | QNLI | Weight |\n|:--:|:-------:|:----------:|:------:|:------:|:------:|\n| GPT | 54.32 | 63.17 | 82.10 | 86.36 | [model](https://github.com/openai/finetune-transformer-lm/tree/master/model) |\n| GPT + ours | 57.25 | 65.35 | 82.41 | 87.13 | [Coming soon]() |\n||\n| BERT-base | 54.82 | 67.15 | 83.47 | 90.11 | [model](https://huggingface.co/bert-base-cased) |\n| BERT-base + ours | 58.03 | 69.31 | 84.20 | 90.78 | [Coming soon]() |\n| BERT-large | 60.63 | 73.65 | 85.90 | 91.82 | 
[model](https://huggingface.co/bert-large-cased) |\n| BERT-large + ours | 61.82 | 75.09 | 86.46 | 92.37 | [Coming soon]() |\n||\n| SpanBERT-base | 57.48 | 73.65 | 85.53 | 92.71 | [model](https://dl.fbaipublicfiles.com/fairseq/models/spanbert_hf_base.tar.gz) |\n| SpanBERT-base + ours | 63.77 | 77.26 | 86.13 | 93.31 | [Coming soon]() |\n| SpanBERT-large | 64.32 | 78.34 | 87.89 | 94.22 | [model](https://dl.fbaipublicfiles.com/fairseq/models/spanbert_hf.tar.gz) |\n| SpanBERT-large + ours | 65.94 | 79.79 | 88.16 | 94.49 | [Coming soon]() |\n||\n| RoBERTa-base | 61.58 | 77.60 | 87.50 | 92.70 | [model](https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz) |\n| RoBERTa-base + ours | 65.28 | 80.50 | 87.90 | 93.10 | [Coming soon]() |\n| RoBERTa-large | 67.98 | 86.60 | 90.20 | 94.70 | [model](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz) |\n| RoBERTa-large + ours | 70.90 | 88.10 | 90.50 | 95.00 | [Coming soon]() |\n\n## Visualization\n\nWe further analyze the models for CV and NLP tasks through visualization, using SoT-Tiny and BERT-base as the backbones for the two tasks, respectively. We compare three variants built on SoT-Tiny and BERT-base:\n- **ClassT**: only the classification token is used by the classifier\n- **WordT**: only the word tokens are used by the classifier\n- **ClassT+WordT**: both the classification token and the word tokens are used by the classifier, combined with the sum scheme.\n\n\u003cp align=\"center\" style=\"color:rgb(255,0,0);\"\u003e\u0026radic;:\u003cfont color=\"black\"\u003e correct prediction;\u003c/font\u003e \u0026#10007;: \u003cfont color=\"black\"\u003eincorrect prediction\u003c/font\u003e\u003c/p\u003e\n\n\n\u003cdiv\u003e\n\u0026emsp;\u0026emsp;\u0026emsp;\u0026emsp;\u0026emsp;\u0026emsp;\u003cimg src=\"images/vis.png\" width=\"100%\"/\u003e\n\u003c/div\u003e\n\nWe can see that **ClassT** is better suited to classifying categories associated with the background and the whole context. 
**WordT** classifies primarily based on local discriminative regions. Our **ClassT+WordT** makes full use of the merits of both the word tokens and the classification token, focusing on the most important regions for better classification by exploiting both local and global information.\n\n\n\u003cdiv\u003e\n\u0026emsp;\u0026emsp;\u0026emsp;\u0026emsp;\u0026emsp;\u0026emsp;\u003cimg src=\"images/nlp_vis.png\" width=\"100%\"/\u003e\n\u003c/div\u003e\n\nWe selected some examples from the CoLA task, which aims to judge whether an English sentence is grammatical. A greener background denotes a stronger impact of the word on the classification, while a bluer one implies a weaker impact. The proposed **ClassT+WordT** highlights the important words in each sentence while the other two variants fail to, which helps boost classification performance.\n\n## Change log\n\n\n## Acknowledgments\n\n\npytorch: https://github.com/pytorch/pytorch\n\ntimm: https://github.com/rwightman/pytorch-image-models\n\nT2T-ViT: https://github.com/yitu-opensource/T2T-ViT\n\n## Contact\n\n**If you have any questions or suggestions, please contact us**\n\n`jiangtaoxie@mail.dlut.edu.cn`; `coke990921@mail.dlut.edu.cn`\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fjiangtaoxie%2FSoT","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fjiangtaoxie%2FSoT","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fjiangtaoxie%2FSoT/lists"}