An open API service indexing awesome lists of open source software.

https://github.com/codewithdark-git/hybrid-vision-transformer-hvt

Hybrid Vision Transformer (HVT) for Image Classification
https://github.com/codewithdark-git/hybrid-vision-transformer-hvt

cnn encoder encoder-decoder-architecture paper python python3 reasearch transformer

Last synced: 3 months ago
JSON representation

Hybrid Vision Transformer (HVT) for Image Classification

Awesome Lists containing this project

README

        

# Hybrid Vision Transformer (HVT)

[![PyTorch](https://img.shields.io/badge/PyTorch-1.8+-ee4c2c?logo=pytorch)](https://pytorch.org/get-started/locally/)
[![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
[![arXiv](https://img.shields.io/badge/arXiv-2023.xxxxx-b31b1b.svg)](https://arxiv.org/abs/)

This repository contains the official implementation of "Hybrid Vision Transformer (HVT) for Image Classification" paper.

## Overview

HVT is a novel architecture that combines the strengths of CNNs and Vision Transformers for efficient and accurate image classification. It achieves state-of-the-art performance of 95.74% accuracy on CIFAR-10.

![HVT Architecture](images/Hybrid%20Vision%20Transformer%20(HVT)%20Architecture%20(1).jpg)

## Key Features

- 🚀 State-of-the-art accuracy (95.74% on CIFAR-10)
- 🔄 Hybrid CNN-Transformer architecture
- 💡 Novel Informative Feature Fusion module
- 📊 Comprehensive evaluation and benchmarks
- 🛠️ Easy-to-use training pipeline

## Installation

```bash
# Clone the repository
git clone https://github.com/codewithdark-git/Hybrid-Vision-Transformer-HVT-.git
cd Hybrid-Vision-Transformer-HVT-

# Install dependencies
pip install -r requirements.txt
```

## Quick Start

```python
import torch

class HybridVisionTransformer(nn.Module):
pass

# Initialize model
model = HybridVisionTransformer(
num_classes=10,
backbone='tf_efficientnetv2_s',
embed_dim=510,
num_heads=10,
transformer_layers=8
)

```

## Model Architecture

The HVT consists of four main components:

1. **CNN Backbone (EfficientNetV2)**
- Extracts hierarchical features
- Pretrained on ImageNet

2. **Feature Projectors**
```python
feature_projectors = nn.ModuleList([
nn.Sequential(
nn.Conv2d(channels, embed_dim, 1),
nn.BatchNorm2d(embed_dim),
nn.GELU()
)
])
```

3. **Transformer Encoder**
- 8 layers with 10 attention heads
- GELU activation and LayerNorm

4. **Informative Feature Fusion**
- Cross-modal attention
- Adaptive feature gating
- Feature refinement

## Results

| Model |Dataset| Test Accuracy | Test Loss |
|-------|--------|--------------|-----------|
| ResNet-50 |CIFAR-10| 85.13% | 0.4141 |
| HVT (Ours) |CIFAR-10| 90.74% | 0.3229 |
| ResNet-50 |SVHN| 95.98% | 0.1671 |
| HVT (Ours) |SVHN| 96.08% | 0.1675 |

## Training Configuration

```python
# Hyperparameters
config = {
'embed_dim': 510,
'num_heads': 10,
'transformer_layers': 8,
'dropout': 0.3,
'batch_size': 32,
'learning_rate': 1e-4,
'epochs': 10
}
```

## Requirements

- Python 3.8+
- PyTorch 1.8+
- timm
- numpy
- torchvision
- CUDA-capable GPU (8GB+ VRAM recommended)

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Acknowledgments

- EfficientNet implementation from [timm](https://github.com/rwightman/pytorch-image-models)
- Vision Transformer implementation inspired by [ViT](https://github.com/google-research/vision_transformer)

## Contact

- Ahsan Umar - [LinkedIn](https://www.linkedin.com/in/codewithdark)
- Email: [email protected]