Hybrid Vision Transformer (HVT) for Image Classification
https://github.com/codewithdark-git/hybrid-vision-transformer-hvt
- Host: GitHub
- URL: https://github.com/codewithdark-git/hybrid-vision-transformer-hvt
- Owner: codewithdark-git
- License: MIT
- Created: 2025-01-24T07:22:52.000Z
- Default Branch: main
- Last Pushed: 2025-01-30T13:33:17.000Z
- Last Synced: 2025-01-30T14:31:36.943Z
- Topics: cnn, encoder, encoder-decoder-architecture, paper, python, python3, reasearch, transformer
- Language: Jupyter Notebook
- Size: 2.3 MB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Hybrid Vision Transformer (HVT)
[PyTorch](https://pytorch.org/get-started/locally/)
[License: MIT](LICENSE)
[Paper](https://arxiv.org/abs/)

This repository contains the official implementation of the paper "Hybrid Vision Transformer (HVT) for Image Classification".
## Overview
HVT is a novel architecture that combines the strengths of CNNs and Vision Transformers for efficient and accurate image classification. It reaches 90.74% test accuracy on CIFAR-10 and 96.08% on SVHN (see the Results table below).
%20Architecture%20(1).jpg)
## Key Features
- 🚀 Strong accuracy (90.74% on CIFAR-10, 96.08% on SVHN)
- 🔄 Hybrid CNN-Transformer architecture
- 💡 Novel Informative Feature Fusion module
- 📊 Comprehensive evaluation and benchmarks
- 🛠️ Easy-to-use training pipeline

## Installation
```bash
# Clone the repository
git clone https://github.com/codewithdark-git/Hybrid-Vision-Transformer-HVT-.git
cd Hybrid-Vision-Transformer-HVT-

# Install dependencies
pip install -r requirements.txt
```

## Quick Start
```python
import torch
import torch.nn as nn

# Stub signature for illustration; the full model lives in the repository notebook
class HybridVisionTransformer(nn.Module):
    def __init__(self, num_classes, backbone, embed_dim, num_heads, transformer_layers):
        super().__init__()
        self.num_classes = num_classes

# Initialize model
model = HybridVisionTransformer(
    num_classes=10,
    backbone='tf_efficientnetv2_s',
    embed_dim=510,
    num_heads=10,
    transformer_layers=8
)
```
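
With the full implementation from the repository notebook in place, a quick shape check might look like the sketch below. The 224×224 input size is an assumption, and it presumes the notebook's `forward()` maps images to class logits:

```python
# Smoke test, assuming the full model's forward() returns (batch, num_classes) logits
dummy = torch.randn(1, 3, 224, 224)  # one RGB image; input size assumed
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)  # expected: torch.Size([1, 10])
```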
## Model Architecture
The HVT consists of four main components; a combined, runnable sketch follows the list.

1. **CNN Backbone (EfficientNetV2)**
   - Extracts hierarchical features
   - Pretrained on ImageNet

2. **Feature Projectors**
```python
# One 1x1-conv projector per backbone stage; `backbone_channels` is the
# list of per-stage channel counts (e.g. from timm's feature_info)
feature_projectors = nn.ModuleList([
    nn.Sequential(
        nn.Conv2d(channels, embed_dim, 1),  # unify channel width
        nn.BatchNorm2d(embed_dim),
        nn.GELU()
    )
    for channels in backbone_channels
])
```

3. **Transformer Encoder**
   - 8 layers with 10 attention heads
   - GELU activation and LayerNorm

4. **Informative Feature Fusion**
   - Cross-modal attention
   - Adaptive feature gating
   - Feature refinement
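
The snippet below sketches how these four pieces can compose, using timm's `features_only` API for the backbone and PyTorch's built-in transformer encoder. The per-stage 7×7 pooling and the sigmoid-gated fusion are illustrative assumptions; the README does not specify the exact tokenization or fusion math.

```python
import torch
import torch.nn as nn
import timm

embed_dim, num_heads, depth = 510, 10, 8

# 1. Backbone: multi-scale feature maps via timm's features_only API
backbone = timm.create_model('tf_efficientnetv2_s', pretrained=True, features_only=True)
stage_channels = backbone.feature_info.channels()  # per-stage channel counts

# 2. Feature projectors: one 1x1-conv block per stage, plus 7x7 pooling
#    so token counts stay small (the pool size is an assumption)
projectors = nn.ModuleList([
    nn.Sequential(nn.Conv2d(c, embed_dim, 1), nn.BatchNorm2d(embed_dim), nn.GELU())
    for c in stage_channels
])
pool = nn.AdaptiveAvgPool2d(7)

# 3. Transformer encoder: 8 layers, 10 heads, GELU + LayerNorm
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
                               activation='gelu', batch_first=True),
    num_layers=depth)

# 4. "Adaptive feature gating" guess: sigmoid gate blending CNN tokens
#    with encoder output (the fusion math is not spelled out in the README)
gate = nn.Sequential(nn.Linear(2 * embed_dim, embed_dim), nn.Sigmoid())

x = torch.randn(2, 3, 224, 224)
feats = backbone(x)                                     # list of (B, C_i, H_i, W_i)
tokens = torch.cat([pool(p(f)).flatten(2).transpose(1, 2)  # (B, 49, embed_dim) each
                    for p, f in zip(projectors, feats)], dim=1)
encoded = encoder(tokens)                               # (B, N_tokens, embed_dim)
g = gate(torch.cat([tokens, encoded], dim=-1))
fused = g * encoded + (1 - g) * tokens                  # gated CNN/transformer mix
```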
## Results

| Model      | Dataset  | Test Accuracy | Test Loss |
|------------|----------|---------------|-----------|
| ResNet-50  | CIFAR-10 | 85.13%        | 0.4141    |
| HVT (Ours) | CIFAR-10 | 90.74%        | 0.3229    |
| ResNet-50  | SVHN     | 95.98%        | 0.1671    |
| HVT (Ours) | SVHN     | 96.08%        | 0.1675    |

## Training Configuration
```python
# Hyperparameters
config = {
'embed_dim': 510,
'num_heads': 10,
'transformer_layers': 8,
'dropout': 0.3,
'batch_size': 32,
'learning_rate': 1e-4,
'epochs': 10
}
```
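
For context, here is a minimal training-loop sketch wired to the config above. The AdamW optimizer, cross-entropy loss, and plain `ToTensor` transform are assumptions (the README does not pin these down); `model` is the HVT from Quick Start.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# CIFAR-10 loader using the configured batch size (transform is a placeholder)
train_loader = DataLoader(
    datasets.CIFAR10('./data', train=True, download=True,
                     transform=transforms.ToTensor()),
    batch_size=config['batch_size'], shuffle=True)

# Optimizer and loss choices are assumptions, not confirmed by the README
optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'])
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(config['epochs']):
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
```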
## Requirements

- Python 3.8+
- PyTorch 1.8+
- timm
- numpy
- torchvision
- CUDA-capable GPU (8GB+ VRAM recommended)

## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## Acknowledgments
- EfficientNet implementation from [timm](https://github.com/rwightman/pytorch-image-models)
- Vision Transformer implementation inspired by [ViT](https://github.com/google-research/vision_transformer)

## Contact
- Ahsan Umar - [LinkedIn](https://www.linkedin.com/in/codewithdark)
- Email: [email protected]