# Transformer from Scratch

[![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=for-the-badge&logo=PyTorch&logoColor=white)](https://pytorch.org/)
[![Lightning](https://img.shields.io/badge/Lightning-792EE5?style=for-the-badge&logo=pytorchlightning&logoColor=white)](https://pytorch-lightning.readthedocs.io/)
[![Python](https://img.shields.io/badge/Python-3.11+-3776AB?style=for-the-badge&logo=python&logoColor=white)](https://www.python.org/)
[![License](https://img.shields.io/badge/License-MIT-yellow.svg?style=for-the-badge)](https://opensource.org/licenses/MIT)

**A complete, production-ready implementation of the Transformer architecture from "Attention Is All You Need"**

*Built with PyTorch Lightning for scalable training and inference*

---

## What Makes This Special

- **Complete Implementation** - Every component from the original paper, meticulously crafted
- **Lightning Fast** - PyTorch Lightning integration for distributed training
- **Production Ready** - Proper error handling, logging, and checkpointing
- **Modular Design** - Each component is independently testable and reusable
- **Independent Testing** - Run each module separately for debugging and learning
- **Educational** - Clean, well-documented code perfect for learning
- **Modern Stack** - Uses the GPT-2 tokenizer and modern PyTorch practices
- **Multiple Architectures** - CrossAttention, DecoderOnly, and MoE implementations
- **Comprehensive Metrics** - BLEU, ROUGE, METEOR, and BERTScore evaluation
- **Advanced Features** - Mixture of Experts with Top-K routing and sparse computation

---

## Architecture Deep Dive

### Core Components

| Component | Description | Key Features |
|-----------|-------------|--------------|
| **TokenEmbedding** | Converts tokens to dense vectors | Scaling, padding handling, vocabulary mapping |
| **PositionalEmbedding** | Adds position information | Sinusoidal & learned encodings, flexible max positions |
| **MultiHeadSelfAttention** | The heart of Transformers | Causal masking, cross-attention, scaled dot-product (sketched below) |
| **PositionwiseFeedForward** | Non-linear transformations | GELU activation, configurable dimensions |
| **AddNorm** | Residual connections + normalization | Layer normalization, dropout, gradient flow |
| **Encoder** | Processes input sequences | Stacked layers, self-attention, context building |
| **Decoder** | Generates output sequences | Masked attention, cross-attention, autoregressive |
| **MoE Components** | Mixture of Experts implementation | Top-K routing, sparse computation, expert specialization |
| **TopKRouter** | Expert selection mechanism | Dynamic routing, load balancing, efficient computation |
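
The attention column above boils down to scaled dot-product attention, optionally with a causal mask for decoder-style self-attention. The sketch below is illustrative PyTorch, not the repo's `MultiHeadSelfAttention` code; the tensor shapes and the `causal` flag are assumptions for demonstration.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v, causal=False):
    """Illustrative scaled dot-product attention with an optional causal mask.

    q, k, v are assumed to have shape (batch, heads, seq_len, head_dim).
    """
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))   # (B, H, T, T)
    if causal:
        # Upper-triangular mask: each position may only attend to itself and the past.
        t = scores.size(-1)
        mask = torch.triu(torch.ones(t, t, dtype=torch.bool, device=scores.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v                                          # (B, H, T, head_dim)
```

Multi-head attention runs this in parallel over `num_heads` projections of the input and concatenates the results; cross-attention uses the same formula with `q` taken from the decoder and `k`, `v` from the encoder.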

### Model Architectures

| Architecture | Description | Use Cases | Key Features |
|--------------|-------------|-----------|--------------|
| **CrossAttentionSeq2Seq** | Full encoder-decoder with cross-attention | Translation, summarization | Bidirectional encoding, cross-attention |
| **DecoderOnly** | GPT-style autoregressive model | Text generation, completion | Causal masking, next-token prediction |
| **DecoderOnlyMoE** | Decoder-only with Mixture of Experts | Large-scale text generation | Sparse activation, expert routing |

### Data Flow

```mermaid
graph TD
A[Input Text] --> B[Tokenization]
B --> C[Token Embedding]
C --> D[Positional Encoding]
D --> E[Encoder Stack]
E --> F[Context Vectors]
F --> G[Decoder Stack]
G --> H[Output Logits]
H --> I[Generated Text]
```

---

## Quick Start

### 1. Installation

```bash
# Clone the repository
git clone https://github.com/ambidextrous9/transformer-from-scratch.git
cd transformer-from-scratch

# Install dependencies
pip install torch pytorch-lightning transformers pandas numpy sacrebleu rouge_score bert_score nltk
```

### 2. Training

Choose from multiple model architectures:

#### CrossAttention Seq2Seq Model
```bash
# Train encoder-decoder with cross-attention
python Trainer.py
```

#### Decoder-Only Model (GPT-style)
```bash
# Train decoder-only autoregressive model
python DecoderOnlyTrainer.py
```

#### Decoder-Only with Mixture of Experts
```bash
# Train MoE model with expert routing
python DecoderMoETrainer.py
```

**Training Features:**
- **Automatic checkpointing** - Best model saved automatically
- **Real-time monitoring** - Loss tracking and validation metrics
- **GPU acceleration** - Runs on GPU when one is available
- **Progress tracking** - Detailed logging and progress bars
- **MoE Support** - Sparse computation with expert routing
- **Comprehensive Metrics** - BLEU, ROUGE, METEOR, BERTScore evaluation

### 3. Inference

Choose the appropriate inference script for your model:

#### CrossAttention Seq2Seq Model
```bash
# Generate text completions with encoder-decoder
python Inference.py
```

#### Decoder-Only Model
```bash
# Generate text with decoder-only model
python DecoderOnlyInference.py
```

#### Decoder-Only with MoE
```bash
# Generate text with MoE model
python DecoderMoEInference.py
```

**Inference Features:**
- **Greedy decoding** - Deterministic text generation (a minimal sketch follows this list)
- **Fast inference** - Optimized for production use
- **Flexible input** - Handle variable-length sequences
- **Easy integration** - Simple API for your applications
- **MoE Support** - Efficient expert routing during inference
- **Multiple Models** - Support for different architectures
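
The greedy decoding mentioned above fits in a few lines. This is a minimal, hedged sketch assuming a model that returns logits of shape `(batch, seq_len, vocab_size)` and the Hugging Face GPT-2 tokenizer; the function name and signature are illustrative, not the repo's actual inference API.

```python
import torch

@torch.no_grad()
def greedy_decode(model, tokenizer, prompt, max_new_tokens=32, device="cpu"):
    """Minimal greedy decoding loop (illustrative sketch)."""
    model.eval()
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    for _ in range(max_new_tokens):
        logits = model(ids)                               # assumed: (1, seq_len, vocab_size)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
        if next_id.item() == tokenizer.eos_token_id:      # stop at end-of-sequence
            break
    return tokenizer.decode(ids[0], skip_special_tokens=True)
```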

### 4. Independent Module Testing

Each component can be run independently for testing and experimentation:

```bash
# Test individual components
python Embedding.py # Test token & positional embeddings
python MultiHeadSelfAttention.py # Test attention mechanism
python FFN.py # Test feed-forward network
python AddNorm.py # Test residual connections & normalization
python Encoder.py # Test encoder stack
python Decoder.py # Test decoder stack
python Seq2SeqModel.py # Test complete model
```

**Independent Testing Features:**
- **Component isolation** - Test each part separately
- **Debugging friendly** - Easy to identify issues in specific components
- **Learning focused** - Understand each component's behavior individually
- **Quick validation** - Fast testing without the full training pipeline (see the sketch below)
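
As a rough idea of what such a standalone run might contain, here is a hypothetical smoke test for the attention module; the class name, constructor signature, and the repo's actual `__main__` blocks are assumptions and may differ.

```python
# Hypothetical smoke test in the spirit of `python MultiHeadSelfAttention.py`.
import torch
from MultiHeadSelfAttention import MultiHeadSelfAttention  # assumed class name

if __name__ == "__main__":
    batch, seq_len, d_model, num_heads = 2, 16, 256, 4
    attn = MultiHeadSelfAttention(d_model=d_model, num_heads=num_heads)  # assumed signature
    x = torch.randn(batch, seq_len, d_model)
    out = attn(x)
    # Self-attention should preserve the (batch, seq_len, d_model) shape.
    assert out.shape == (batch, seq_len, d_model)
    print("MultiHeadSelfAttention OK:", tuple(out.shape))
```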

---

## Evaluation Metrics

The codebase includes comprehensive evaluation metrics for assessing model performance:

### Automatic Metrics

| Metric | Description | Range | Use Case |
|--------|-------------|-------|----------|
| **BLEU** | N-gram overlap with reference | 0-100 | Translation quality, text similarity |
| **ROUGE-1** | Unigram overlap | 0-1 | Content coverage, summarization |
| **ROUGE-2** | Bigram overlap | 0-1 | Phrase-level similarity |
| **ROUGE-L** | Longest common subsequence | 0-1 | Structural similarity |
| **METEOR** | Semantic similarity with synonyms | 0-1 | Meaning preservation |
| **BERTScore** | Contextual embedding similarity | 0-1 | Semantic understanding |

### Implementation Features

- **Real-time Tracking** - Metrics computed during validation
- **Progress Monitoring** - All metrics logged to PyTorch Lightning
- **Automatic Evaluation** - No manual intervention required
- **Efficient Computation** - Optimized for large-scale evaluation
- **Comprehensive Coverage** - Multiple evaluation perspectives

### Usage

All metrics are computed automatically during validation steps and logged to the progress bar and TensorBoard logs.
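
For reference, here is a standalone sketch of how BLEU and ROUGE can be computed with the libraries listed in the install step (`sacrebleu`, `rouge_score`); the repo wires these metrics into its Lightning validation loop instead, and METEOR (`nltk`) and BERTScore (`bert_score`) follow similar patterns.

```python
import sacrebleu
from rouge_score import rouge_scorer

hyps = ["reducing dependence on fossil fuels and lowering emissions."]
refs = ["reducing dependence on fossil fuels and lowering emissions."]

# Corpus-level BLEU (reported on a 0-100 scale).
bleu = sacrebleu.corpus_bleu(hyps, [refs]).score

# ROUGE-1/2/L F-measures (0-1 scale), computed per hypothesis/reference pair.
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
rouge = scorer.score(refs[0], hyps[0])

print(f"BLEU: {bleu:.2f}")
print({name: round(score.fmeasure, 3) for name, score in rouge.items()})
```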

---

## Dataset & Task

**Versatile Text Completion Dataset**
- **2,000 examples** of diverse text completion pairs
- **Task**: Complete partial sentences with meaningful continuations
- **Format**: `"partial sentence..." → "completion text"`
- **Train/Val Split**: 80/20 automatic split (a loading sketch follows this section)
- **Diverse Topics**: Covers multiple domains and contexts

**Example:**
```
Input: "The rise of renewable energy is changing global markets and Experts predict this shift will redefine economies"
Output: "reducing dependence on fossil fuels and lowering emissions."
```

**Dataset Features:**
- ๐Ÿ“š **Educational Content** - Science, technology, and general knowledge
- ๐Ÿ”„ **Multiple Formats** - Various sentence structures and completion types
- ๐ŸŽฏ **Quality Controlled** - Curated for meaningful learning objectives
- ๐Ÿ“Š **Balanced Distribution** - Even representation across different topics
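
The loading sketch referenced above: a few lines of pandas reproduce the 80/20 split described in this section. The exact shuffling and column handling in the repo's data pipeline may differ, so treat this as a sketch.

```python
import pandas as pd

# Load the dataset and reproduce the 80/20 train/validation split described above.
df = pd.read_csv("versatile_dataset_2000.csv")
df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)  # shuffle for a fair split

split = int(0.8 * len(df))
train_df, val_df = df.iloc[:split], df.iloc[split:]
print(f"train: {len(train_df)} rows, val: {len(val_df)} rows")
```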

---

## Configuration

### Model Architecture

| Parameter | Default | Description |
|-----------|---------|-------------|
| `d_model` | 256 | Model dimension (embedding size) |
| `num_heads` | 4-8 | Number of attention heads |
| `num_encoder_layers` | 2-6 | Encoder stack depth |
| `num_decoder_layers` | 2-6 | Decoder stack depth |
| `d_ff` | 128-1024 | Feed-forward dimension |
| `dropout` | 0.1 | Dropout rate |
| `max_positions` | 32-512 | Maximum sequence length |
| `use_sinusoidal_pos` | True | Use sinusoidal positional encoding |
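
The `use_sinusoidal_pos` option refers to the fixed sine/cosine position encodings from the original paper. Below is a minimal sketch of how such a table can be built; it is not necessarily the repo's exact `PositionalEmbedding` code.

```python
import math
import torch

def sinusoidal_positions(max_positions: int, d_model: int) -> torch.Tensor:
    """Fixed sin/cos positional encodings from "Attention Is All You Need"."""
    pos = torch.arange(max_positions, dtype=torch.float32).unsqueeze(1)   # (P, 1)
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                    * (-math.log(10000.0) / d_model))                     # (d_model/2,)
    pe = torch.zeros(max_positions, d_model)
    pe[:, 0::2] = torch.sin(pos * div)   # even dimensions
    pe[:, 1::2] = torch.cos(pos * div)   # odd dimensions
    return pe                            # (max_positions, d_model), added to token embeddings

pe = sinusoidal_positions(max_positions=512, d_model=256)
```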

### MoE Configuration (DecoderOnlyMoE)

| Parameter | Default | Description |
|-----------|---------|-------------|
| `num_experts` | 4 | Number of expert networks |
| `top_k` | 2 | Number of experts to activate per token |
| `expert_capacity` | Auto | Maximum tokens per expert |

### Training Configuration

| Parameter | Value | Description |
|-----------|-------|-------------|
| `batch_size` | 4 | Training batch size |
| `learning_rate` | 1e-3 | Adam optimizer learning rate |
| `max_epochs` | 100 | Maximum training epochs |
| `gradient_clip` | 1.0 | Gradient clipping threshold |
| `checkpoint_monitor` | val_loss_epoch | Model selection metric |
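
As a rough guide, these settings map onto a PyTorch Lightning `Trainer` as shown below; the repo's training scripts may configure more (the learning rate and batch size live in the LightningModule and DataLoader), so treat this as a sketch rather than the exact contents of `Trainer.py`.

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the best checkpoint according to the validation loss, as in the table above.
checkpoint_cb = ModelCheckpoint(monitor="val_loss_epoch", mode="min", save_top_k=1)

trainer = pl.Trainer(
    max_epochs=100,
    gradient_clip_val=1.0,     # gradient clipping threshold
    accelerator="auto",        # uses a GPU when one is available
    callbacks=[checkpoint_cb],
)
# trainer.fit(model, train_loader, val_loader)
```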

---

## Project Structure

```
transformer-from-scratch/
├── Core Components
│   ├── Embedding.py                   # Token & positional embeddings
│   ├── MultiHeadSelfAttention.py      # Multi-head attention mechanism
│   ├── FFN.py                         # Position-wise feed-forward
│   └── AddNorm.py                     # Residual connections + normalization
├── Architecture Models
│   ├── Encoder.py                     # Encoder stack implementation
│   ├── Decoder.py                     # Decoder stack implementation
│   ├── CrossAttentionSeq2SeqModel.py  # Full encoder-decoder model
│   ├── DecoderOnlySeq2SeqModel.py     # GPT-style decoder-only model
│   └── DecoderMoE.py                  # Decoder-only with Mixture of Experts
├── Training Scripts
│   ├── Trainer.py                     # CrossAttention training pipeline
│   ├── DecoderOnlyTrainer.py          # Decoder-only training pipeline
│   └── DecoderMoETrainer.py           # MoE training pipeline
├── Inference Scripts
│   ├── Inference.py                   # CrossAttention inference
│   ├── DecoderOnlyInference.py        # Decoder-only inference
│   └── DecoderMoEInference.py         # MoE inference
├── Data
│   ├── versatile_dataset_2000.csv     # Main training dataset
│   └── synthetic_text_completion.csv  # Legacy dataset
├── Checkpoints
│   ├── Seq2SeqCheckpoints/            # CrossAttention model checkpoints
│   ├── DecoderOnlyCheckpoints/        # Decoder-only model checkpoints
│   └── DecoderMoECheckpoints/         # MoE model checkpoints
└── Logs
    └── lightning_logs/                # Training logs and metrics
```

---

## Use Cases

### Perfect For:
- **Learning** - Understanding Transformer architecture
- **Research** - Experimenting with attention mechanisms
- **Prototyping** - Quick seq2seq model development
- **Component Testing** - Debug and validate individual modules

### Applications:

#### CrossAttention Seq2Seq Model
- **Summarization** - Generate concise summaries
- **Translation** - Sequence-to-sequence translation
- **Question Answering** - Context-aware responses
- **Data-to-Text** - Convert structured data to natural language

#### Decoder-Only Models
- **Text Completion** - Auto-complete sentences
- **Chatbots** - Conversational AI systems
- **Creative Writing** - Story and content generation
- **Code Generation** - Programming assistance

#### MoE Models
- **Large-Scale Generation** - Efficient text generation at scale
- **Specialized Tasks** - Expert routing for domain-specific content
- **Resource Optimization** - Sparse computation for better efficiency
- **Multi-Domain Learning** - Handle diverse topics with specialized experts

---

## Mixture of Experts (MoE) Implementation

### Key Features

The MoE implementation includes several advanced features for efficient sparse computation:

#### Expert Architecture
- **ExpertMLP** - Individual expert networks with GELU activation
- **TopKRouter** - Intelligent routing mechanism for expert selection (see the router sketch below)
- **Sparse Computation** - Only activate selected experts per token
- **Load Balancing** - Automatic expert capacity management

#### Routing Strategy
- **Softmax Gating** - Probabilistic expert selection
- **Top-K Selection** - Activate only the most relevant experts
- **Dynamic Routing** - Adaptive expert selection based on input
- **Load Balancing** - Prevent expert overloading

#### Performance Optimizations
- **Sparse Activation** - Reduce computational overhead
- **Memory Efficient** - Only store active expert outputs
- **Batch Processing** - Efficient parallel expert computation
- **Gradient Flow** - Proper backpropagation through routing

### Usage Example

```python
# Initialize MoE model
model = DecoderOnlyMoEModel(
    vocab_size=vocab_size,
    d_model=256,
    num_experts=4,   # Number of expert networks
    top_k=2,         # Activate top 2 experts per token
    num_layers=6,
    tokenizer=tokenizer,
)

# Training automatically handles expert routing
trainer.fit(model, train_loader, val_loader)
```

---

## Contributing

We welcome contributions! Here's how you can help:

1. **Fork** the repository
2. **Create** a feature branch (`git checkout -b feature/AmazingFeature`)
3. **Commit** your changes (`git commit -m 'Add AmazingFeature'`)
4. **Push** to the branch (`git push origin feature/AmazingFeature`)
5. **Open** a Pull Request

### Areas for Contribution:
- **Performance optimizations**
- **Additional attention mechanisms**
- **More datasets and tasks**
- **Documentation improvements**
- **Bug fixes and testing**

---

## References & Learning

### Papers
1. **Vaswani, A., et al.** (2017). "Attention is all you need." *NeurIPS 2017*
2. **Devlin, J., et al.** (2018). "BERT: Pre-training of Deep Bidirectional Transformers." *NAACL 2019*

### Resources
- [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)
- [PyTorch Lightning Documentation](https://pytorch-lightning.readthedocs.io/)
- [Attention Mechanism Explained](https://distill.pub/2016/augmented-rnns/)
- [Transformer from Scratch](https://www.youtube.com/watch?v=ISNdQcPhsts)

---

**Star this repository if you found it helpful!**

Made with ❤️ and lots of ☕

[Report Bug](https://github.com/ambidextrous9/transformer-from-scratch/issues) · [Request Feature](https://github.com/ambidextrous9/transformer-from-scratch/issues) · [Documentation](https://github.com/ambidextrous9/transformer-from-scratch/wiki)