# Kolmogorov–Arnold Transformer: A PyTorch Implementation

Yes, I kan!

🎉 This is a PyTorch/GPU implementation of the paper **Kolmogorov–Arnold Transformer (KAT)**, which replaces the MLP layers in transformers with KAN layers.

**Kolmogorov–Arnold Transformer**

πŸ“[[Paper](https://arxiv.org/abs/2409.10594)] >[[code](https://github.com/Adamdad/kat)] >[[CUDA kernel](https://github.com/Adamdad/rational_kat_cu)]

Xingyi Yang, Xinchao Wang

National University of Singapore

### 🔑 Key Insight:

Vanilla ViT + KAN struggles to scale effectively. We introduce KAT, which integrates GR-KAN (Group-Rational KAN) layers into transformers for large-scale training scenarios such as ImageNet, achieving significant performance improvements.




### 🎯 Our Solutions:
1. **Base function**: Replace the B-spline basis with a CUDA-implemented rational function.
2. **Group KAN**: Share activation weights among groups of edges for efficiency.
3. **Initialization**: Maintain activation magnitudes across layers (see the sketch below).
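
To make these ideas concrete, here is a minimal PyTorch sketch of a grouped rational activation in the safeguarded form P(x) / (1 + |Q(x)|). It is an illustration only, not the repo's optimized CUDA kernel; the polynomial degrees, group count, and initialization scale below are assumptions for demonstration.

```python
# Illustrative sketch only -- NOT the repo's CUDA kernel. The degrees
# (m=5, n=4), groups=8, and init scale are hypothetical choices.
import torch
import torch.nn as nn

class GroupRational(nn.Module):
    """Rational activation y = P(x) / (1 + |Q(x)|), one (P, Q) per channel group."""
    def __init__(self, channels: int, groups: int = 8, m: int = 5, n: int = 4):
        super().__init__()
        assert channels % groups == 0
        self.groups = groups
        # Numerator / denominator coefficients, shared within each channel group.
        self.p = nn.Parameter(torch.randn(groups, m + 1) * 0.1)
        self.q = nn.Parameter(torch.randn(groups, n) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        *lead, c = x.shape
        xg = x.reshape(*lead, self.groups, c // self.groups)
        # powers[..., k] = x^k for k = 0..m
        powers = torch.stack([xg ** k for k in range(self.p.shape[1])], dim=-1)
        num = (powers * self.p.view(self.groups, 1, -1)).sum(-1)
        qx = (powers[..., 1:self.q.shape[1] + 1]
              * self.q.view(self.groups, 1, -1)).sum(-1)
        # 1 + |Q(x)| keeps the denominator positive, so there are no poles.
        return (num / (1 + qx.abs())).reshape(*lead, c)

x = torch.randn(2, 196, 192)                  # (batch, tokens, channels), ViT-Tiny-like
print(GroupRational(channels=192)(x).shape)   # torch.Size([2, 196, 192])
```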

### ✅ Updates
- [x] Release the KAT paper, CUDA implementation, and ImageNet-1k training code.
- [ ] KAT detection and segmentation code.
- [ ] KAT on NLP tasks.

## 🛠️ Installation and Dataset
Please find our CUDA implementation at [https://github.com/Adamdad/rational_kat_cu.git](https://github.com/Adamdad/rational_kat_cu.git).
```shell
# install dependencies (assumes PyTorch is already installed)
pip install timm==1.0.3
pip install wandb  # optional: the author uses wandb to visualize results

# build and install the rational-function CUDA kernel
git clone https://github.com/Adamdad/rational_kat_cu.git
cd rational_kat_cu
pip install -e .
```
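
As a quick, hedged sanity check after installation (this only verifies the Python-side dependencies and GPU visibility, not the compiled rational kernel itself):

```python
# Minimal environment check; does not exercise the rational_kat_cu CUDA kernel.
import torch
import timm

print("torch:", torch.__version__)
print("timm:", timm.__version__)      # expected: 1.0.3
print("CUDA available:", torch.cuda.is_available())
```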

📦 Data preparation: arrange ImageNet in the folder structure below; you can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).

```
│imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
```
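
A small, hypothetical layout check (the root path and `.JPEG` extension are assumptions; adjust them to your setup):

```python
# Hypothetical sanity check for the ImageNet layout sketched above.
from pathlib import Path

root = Path("/local_home/dataset/imagenet")   # adjust to your DATA_PATH
for split in ("train", "val"):
    class_dirs = [d for d in (root / split).iterdir() if d.is_dir()]
    n_images = sum(1 for _ in (root / split).glob("*/*.JPEG"))
    print(f"{split}: {len(class_dirs)} classes, {n_images} images")
# ImageNet-1k should report 1000 classes in each split.
```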

## Usage

Refer to `example.py` for a detailed use case demonstrating how to use KAT with timm to classify an image.
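
For orientation, a condensed sketch of that flow; the registering import name and the image path are assumptions here, and `example.py` remains the authoritative reference:

```python
# Condensed sketch of classifying one image with a KAT model through timm.
import timm
import torch
from PIL import Image

import katransformer  # hypothetical: importing the repo module registers kat_* models with timm

model = timm.create_model("kat_tiny_patch16_224", pretrained=True)
model.eval()

# Build the preprocessing pipeline that matches the model's data config.
config = timm.data.resolve_data_config({}, model=model)
transform = timm.data.create_transform(**config)

img = Image.open("cat.jpg").convert("RGB")   # any test image
x = transform(img).unsqueeze(0)              # (1, 3, 224, 224)

with torch.no_grad():
    probs = model(x).softmax(dim=-1)
top5 = probs.topk(5)
print(top5.indices.tolist(), top5.values.tolist())
```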

## 📊 Model Checkpoints
Download pre-trained models or access training checkpoints:

|🏷️ Model |⚙️ Setup |📦 Params| 📈 Top-1 (%) |🔗 Link|
| ---|---|---| ---|---|
|KAT-T| From Scratch|5.7M | 74.6| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_small_patch16_224_32487885cf13d2c14e461c9016fac8ad43f7c769171f132530941e930aeb5fe2.pth)/[huggingface](https://huggingface.co/adamdad/kat_tiny_patch16_224)
|KAT-T | From ViT | 5.7M | 75.7| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_tiny_patch16_224-finetune_64f124d003803e4a7e1aba1ba23500ace359b544e8a5f0110993f25052e402fb.pth)/[huggingface](https://huggingface.co/adamdad/kat_tiny_patch16_224.vitft)
|KAT-S| From Scratch| 22.1M | 81.2| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth)/[huggingface](https://huggingface.co/adamdad/kat_small_patch16_224)
|KAT-S | From ViT |22.1M | 82.0| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_small_patch_224-finetune_3ae087a4c28e2993468eb377d5151350c52c80b2a70cc48ceec63d1328ba58e0.pth)/[huggingface](https://huggingface.co/adamdad/kat_small_patch16_224.vitft)
| KAT-B| From Scratch |86.6M| 82.3 | [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_base_patch16_224_abff874d925d756d15cde97303f772a3460ddbd44b9c53fb9ce5cf15be230fb6.pth)/[huggingface](https://huggingface.co/adamdad/kat_base_patch16_224)
| KAT-B | From ViT |86.6M| 82.8 | [link](https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_base_patch16_224-finetune_440bf1ead9dd8ecab642078cfb60ae542f1fa33ca65517260501e02c011e38f2.pth)/[huggingface](https://huggingface.co/adamdad/kat_base_patch16_224.vitft)|
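
Two hedged ways to load these weights: recent timm versions can pull directly from the Hugging Face Hub via the `hf_hub:` prefix (assuming the linked repos store timm-compatible weights), or you can download a release `.pth` and load its state dict:

```python
import timm
import torch
import katransformer  # hypothetical registering import, as in the usage sketch

# Option A (assumption: the Hugging Face repos above are timm-compatible)
model = timm.create_model("hf_hub:adamdad/kat_tiny_patch16_224", pretrained=True)

# Option B: load a downloaded release checkpoint into a locally created model
model = timm.create_model("kat_tiny_patch16_224")
state = torch.load("path/to/checkpoint.pth", map_location="cpu")
model.load_state_dict(state.get("state_dict", state))
```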

## 🎓 Model Training

All training scripts are under `scripts/`:
```shell
bash scripts/train_kat_tiny_8x128.sh
```

To change the hyper-parameters, edit the script, e.g.:
```shell
#!/bin/bash
DATA_PATH=/local_home/dataset/imagenet/

# kat_tiny_swish: rational activations initialized to approximate Swish
# (an inline comment after a trailing backslash would break the continuation)
bash ./dist_train.sh 8 $DATA_PATH \
--model kat_tiny_swish_patch16_224 \
-b 128 \
--opt adamw \
--lr 1e-3 \
--weight-decay 0.05 \
--epochs 300 \
--mixup 0.8 \
--cutmix 1.0 \
--sched cosine \
--smoothing 0.1 \
--drop-path 0.1 \
--aa rand-m9-mstd0.5 \
--remode pixel --reprob 0.25 \
--amp \
--crop-pct 0.875 \
--mean 0.485 0.456 0.406 \
--std 0.229 0.224 0.225 \
--model-ema \
--model-ema-decay 0.9999 \
--output output/kat_tiny_swish_patch16_224 \
--log-wandb
```
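
For scale reference, the `8x128` in the script name is the launch configuration spelled out in the script itself: 8 GPUs at a per-GPU batch size of 128, i.e. an effective batch size of 8 × 128 = 1024.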

## 🧪 Evaluation
To evaluate our `kat_tiny_patch16_224` model, run:

```shell
DATA_PATH=/local_home/dataset/imagenet/
CHECKPOINT_PATH=kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth
python validate.py $DATA_PATH --model kat_tiny_patch16_224 \
--checkpoint $CHECKPOINT_PATH -b 512

###################
Validating in float32. AMP not enabled.
Loaded state_dict from checkpoint 'kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth'
Model kat_tiny_patch16_224 created, param count: 5718328
Data processing configuration for current model + dataset:
input_size: (3, 224, 224)
interpolation: bicubic
mean: (0.485, 0.456, 0.406)
std: (0.229, 0.224, 0.225)
crop_pct: 0.875
crop_mode: center
Test: [ 0/98] Time: 3.453s (3.453s, 148.28/s) Loss: 0.6989 (0.6989) Acc@1: 84.375 ( 84.375) Acc@5: 96.875 ( 96.875)
.......
Test: [ 90/98] Time: 0.212s (0.592s, 864.23/s) Loss: 1.1640 (1.1143) Acc@1: 71.875 ( 74.270) Acc@5: 93.750 ( 92.220)
* Acc@1 74.558 (25.442) Acc@5 92.390 (7.610)
--result
{
"model": "kat_tiny_patch16_224",
"top1": 74.558,
"top1_err": 25.442,
"top5": 92.39,
"top5_err": 7.61,
"param_count": 5.72,
"img_size": 224,
"crop_pct": 0.875,
"interpolation": "bicubic"
}
```

## 🙏 Acknowledgments
We extend our gratitude to the authors of [rational_activations](https://github.com/ml-research/rational_activations) for their contributions to CUDA rational function implementations that inspired parts of this work. We thank [@yuweihao](https://github.com/yuweihao), [@florinshen](https://github.com/florinshen), [@Huage001](https://github.com/Huage001) and [@yu-rp](https://github.com/yu-rp) for valuable discussions.

## 📚 BibTeX
If you use this repository, please cite:
```bibtex
@misc{yang2024kat,
  title={Kolmogorov–Arnold Transformer},
  author={Xingyi Yang and Xinchao Wang},
  year={2024},
  eprint={2409.10594},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```