{"id":20599640,"url":"https://github.com/adamdad/kat","last_synced_at":"2025-05-15T09:04:15.060Z","repository":{"id":257690535,"uuid":"852897652","full_name":"Adamdad/kat","owner":"Adamdad","description":"[ICLR2025] Kolmogorov-Arnold Transformer","archived":false,"fork":false,"pushed_at":"2025-03-23T06:35:19.000Z","size":1756,"stargazers_count":753,"open_issues_count":11,"forks_count":45,"subscribers_count":11,"default_branch":"main","last_synced_at":"2025-04-07T14:01:38.761Z","etag":null,"topics":["computer-vision","kan","kolmogorov-arnold-networks","transformer"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/Adamdad.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-09-05T16:06:28.000Z","updated_at":"2025-04-07T09:06:30.000Z","dependencies_parsed_at":"2024-09-18T05:18:54.796Z","dependency_job_id":"76c5f1bf-2d71-4c5d-aeb8-06e543bd8c19","html_url":"https://github.com/Adamdad/kat","commit_stats":null,"previous_names":["adamdad/kat"],"tags_count":1,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Adamdad%2Fkat","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Adamdad%2Fkat/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Adamdad%2Fkat/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Adamdad%2Fkat/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/Adamdad","download_url":"https://codeload
.github.com/Adamdad/kat/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248919479,"owners_count":21183367,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["computer-vision","kan","kolmogorov-arnold-networks","transformer"],"created_at":"2024-11-16T08:33:38.551Z","updated_at":"2025-04-14T16:47:20.977Z","avatar_url":"https://github.com/Adamdad.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"\u003cdiv align=\"center\"\u003e\n  \u003ctable\u003e\n    \u003ctr\u003e\n      \u003ctd\u003e\u003cimg src=\"assets/logo.webp\" width=\"150\"\u003e\u003c/td\u003e\n      \u003ctd\u003e\u003ch1\u003eKolmogorov–Arnold Transformer: \u003cbr\u003eA PyTorch Implementation\u003c/h1\u003e\u003c/td\u003e\n    \u003c/tr\u003e\n  \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cp align=\"center\"\u003e\n\u003ca href=\"https://arxiv.org/abs/2409.10594\" alt=\"arXiv\"\u003e\n    \u003cimg src=\"https://img.shields.io/badge/arXiv-2409.10594-b31b1b.svg?style=flat\" /\u003e\u003c/a\u003e\n      \u003ca href=\"https://pytorch.org/\"\u003e\u003cimg src=\"https://img.shields.io/badge/PyTorch-1.x%20%7C%202.x-673ab7.svg\" alt=\"Tested PyTorch Versions\"\u003e\u003c/a\u003e\n  \u003ca href=\"https://opensource.org/licenses/MIT\"\u003e\u003cimg src=\"https://img.shields.io/badge/License-MIT-4caf50.svg\" alt=\"License\"\u003e\u003c/a\u003e\n\u003c/p\u003e\n\u003cp align=\"center\"\u003e\n\u003cb\u003eICLR 2025\u003c/b\u003e\n\u003c/p\u003e\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"assets/KAT.png\" 
width=\"300\"\u003e \u003cbr\u003e\nYes, I kan!\n\u003c/p\u003e\n\n🎉 This is a PyTorch/GPU implementation of the paper **Kolmogorov–Arnold Transformer (KAT)**, which replaces the MLP layers in transformers with KAN layers.\n\nFor more technical details, please refer to our ICLR'25 paper.\n\n\u003e **Kolmogorov–Arnold Transformer**  \n\u003e 📝[[Paper](https://arxiv.org/abs/2409.10594)] \u003c/\u003e[[code](https://github.com/Adamdad/kat)]  \u003c/\u003e[[Triton/CUDA kernel](https://github.com/Adamdad/rational_kat_cu)]  \n\u003e [Xingyi Yang](https://adamdad.github.io/), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)  \n\u003e National University of Singapore  \n\u003e International Conference on Learning Representations (**ICLR'25**)  \n\n### 🔑 Key Insight:\n\nA vanilla ViT combined with KAN struggles to scale effectively. We introduce the KAT model, which integrates Group-Rational KANs (GR-KANs) into transformers for large-scale training scenarios like ImageNet, achieving significant performance improvements.\n\n\u003cp align=\"center\"\u003e\n\u003cimg src=\"assets/kat3-1.png\"\u003e \u003cbr\u003e\n\u003c/p\u003e\n\n### 🎯 Our Solutions:\n1. **Base Function**: Replace B-splines with CUDA-implemented rational functions.\n2. **Group KAN**: Share weights among groups of edges for efficiency.\n3. **Initialization**: Maintain activation magnitudes across layers.\n\n### ✅ Updates\n- [x] Release the KAT paper, CUDA implementation and IN-1k training code.\n- [x] 🎉🎉🎉🎉 Triton Implementation, on 1D and 2D tasks. This is much easier to install than the CUDA version. 
Please see [https://github.com/Adamdad/rational_kat_cu](https://github.com/Adamdad/rational_kat_cu).\n- [ ] KAT Detection and segmentation code.\n- [ ] KAT on NLP tasks.\n\n## 🛠️ Installation and Dataset\nPlease find our CUDA implementation at [https://github.com/Adamdad/rational_kat_cu.git](https://github.com/Adamdad/rational_kat_cu.git).\n```shell\n# install torch and other things\npip install timm==1.0.3\npip install wandb # I personally use wandb for results visualizations\ngit clone https://github.com/Adamdad/rational_kat_cu.git\ncd rational_kat_cu\npip install -e .\n```\n\n📦 Data preparation: arrange ImageNet in the following folder structure. You can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).\n\n```\n│imagenet/\n├──train/\n│  ├── n01440764\n│  │   ├── n01440764_10026.JPEG\n│  │   ├── n01440764_10027.JPEG\n│  │   ├── ......\n│  ├── ......\n├──val/\n│  ├── n01440764\n│  │   ├── ILSVRC2012_val_00000293.JPEG\n│  │   ├── ILSVRC2012_val_00002138.JPEG\n│  │   ├── ......\n│  ├── ......\n```\n\n## Usage\n\nRefer to `example.py` for a detailed use case demonstrating how to use KAT with timm to classify an image.\n\n## 📊 Model Checkpoints\nDownload pre-trained models or access training checkpoints:\n\n|🏷️ Model |⚙️ Setup |📦 Param| 📈 Top1 |🔗 Link|\n| ---|---|---| ---|---|\n|KAT-T| From Scratch|5.7M | 74.6| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_small_patch16_224_32487885cf13d2c14e461c9016fac8ad43f7c769171f132530941e930aeb5fe2.pth)/[huggingface](https://huggingface.co/adamdad/kat_tiny_patch16_224)\n|KAT-T | From ViT | 5.7M | 75.7| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_tiny_patch16_224-finetune_64f124d003803e4a7e1aba1ba23500ace359b544e8a5f0110993f25052e402fb.pth)/[huggingface](https://huggingface.co/adamdad/kat_tiny_patch16_224.vitft)\n|KAT-S| From Scratch| 22.1M | 81.2| 
[link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth)/[huggingface](https://huggingface.co/adamdad/kat_small_patch16_224)\n|KAT-S | From ViT |22.1M | 82.0| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_small_patch_224-finetune_3ae087a4c28e2993468eb377d5151350c52c80b2a70cc48ceec63d1328ba58e0.pth)/[huggingface](https://huggingface.co/adamdad/kat_small_patch16_224.vitft)\n| KAT-B| From Scratch |86.6M| 82.3 | [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_base_patch16_224_abff874d925d756d15cde97303f772a3460ddbd44b9c53fb9ce5cf15be230fb6.pth)/[huggingface](https://huggingface.co/adamdad/kat_base_patch16_224)\n|  KAT-B | From ViT |86.6M| 82.8 | [link](https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_base_patch16_224-finetune_440bf1ead9dd8ecab642078cfb60ae542f1fa33ca65517260501e02c011e38f2.pth)/[huggingface](https://huggingface.co/adamdad/kat_base_patch16_224.vitft)|\n\n## 🎓 Model Training\n\nAll training scripts are under `scripts/`. For example:\n```shell\nbash scripts/train_kat_tiny_8x128.sh\n```\n\nTo change the hyper-parameters, edit the script:\n```shell\n#!/bin/bash\nDATA_PATH=/local_home/dataset/imagenet/\n\n# Rationals are initialized to be swish functions.\n# (A comment after a trailing backslash would break the line continuation, so it lives here.)\nbash ./dist_train.sh 8 $DATA_PATH \\\n    --model kat_tiny_swish_patch16_224 \\\n    -b 128 \\\n    --opt adamw \\\n    --lr 1e-3 \\\n    --weight-decay 0.05 \\\n    --epochs 300 \\\n    --mixup 0.8 \\\n    --cutmix 1.0 \\\n    --sched cosine \\\n    --smoothing 0.1 \\\n    --drop-path 0.1 \\\n    --aa rand-m9-mstd0.5 \\\n    --remode pixel --reprob 0.25 \\\n    --amp \\\n    --crop-pct 0.875 \\\n    --mean 0.485 0.456 0.406 \\\n    --std 0.229 0.224 0.225 \\\n    --model-ema \\\n    --model-ema-decay 0.9999 \\\n    --output output/kat_tiny_swish_patch16_224 \\\n    --log-wandb\n```\n\n## 🧪 Evaluation\nTo evaluate our `kat_tiny_patch16_224` models, 
run:\n\n```shell\nDATA_PATH=/local_home/dataset/imagenet/\nCHECKPOINT_PATH=kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth\npython validate.py $DATA_PATH --model kat_tiny_patch16_224 \\\n    --checkpoint $CHECKPOINT_PATH -b 512\n\n###################\nValidating in float32. AMP not enabled.\nLoaded state_dict from checkpoint 'kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth'\nModel kat_tiny_patch16_224 created, param count: 5718328\nData processing configuration for current model + dataset:\n        input_size: (3, 224, 224)\n        interpolation: bicubic\n        mean: (0.485, 0.456, 0.406)\n        std: (0.229, 0.224, 0.225)\n        crop_pct: 0.875\n        crop_mode: center\nTest: [   0/98]  Time: 3.453s (3.453s,  148.28/s)  Loss:  0.6989 (0.6989)  Acc@1:  84.375 ( 84.375)  Acc@5:  96.875 ( 96.875)\n.......\nTest: [  90/98]  Time: 0.212s (0.592s,  864.23/s)  Loss:  1.1640 (1.1143)  Acc@1:  71.875 ( 74.270)  Acc@5:  93.750 ( 92.220)\n * Acc@1 74.558 (25.442) Acc@5 92.390 (7.610)\n--result\n{\n    \"model\": \"kat_tiny_patch16_224\",\n    \"top1\": 74.558,\n    \"top1_err\": 25.442,\n    \"top5\": 92.39,\n    \"top5_err\": 7.61,\n    \"param_count\": 5.72,\n    \"img_size\": 224,\n    \"crop_pct\": 0.875,\n    \"interpolation\": \"bicubic\"\n}\n```\n\n\n## 🙏 Acknowledgments\nWe extend our gratitude to the authors of [rational_activations](https://github.com/ml-research/rational_activations) for their contributions to CUDA rational function implementations that inspired parts of this work. 
We thank [@yuweihao](https://github.com/yuweihao), [@florinshen](https://github.com/florinshen), [@Huage001](https://github.com/Huage001) and [@yu-rp](https://github.com/yu-rp) for valuable discussions.\n\n## 📚 BibTeX\nIf you use this repository, please cite:\n```bibtex\n@inproceedings{\n  yang2025kolmogorovarnold,\n  title={Kolmogorov-Arnold Transformer},\n  author={Xingyi Yang and Xinchao Wang},\n  booktitle={The Thirteenth International Conference on Learning Representations},\n  year={2025},\n  url={https://openreview.net/forum?id=BCeock53nt}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fadamdad%2Fkat","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fadamdad%2Fkat","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fadamdad%2Fkat/lists"}