Transformer from Scratch
https://github.com/ambidextrous9/transformer-from-scratch
- Host: GitHub
- URL: https://github.com/ambidextrous9/transformer-from-scratch
- Owner: ambideXtrous9
- Created: 2025-09-14T13:17:09.000Z (5 months ago)
- Default Branch: main
- Last Pushed: 2025-09-21T07:48:18.000Z (5 months ago)
- Last Synced: 2025-09-21T09:25:35.201Z (5 months ago)
- Topics: attention, decoder, encoder, encoder-decoder-model, masked-language-models, transformer
- Language: Python
- Size: 43.9 KB
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
# Transformer from Scratch
[PyTorch](https://pytorch.org/) · [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/) · [Python](https://www.python.org/) · [MIT License](https://opensource.org/licenses/MIT)
**A complete, production-ready implementation of the Transformer architecture from "Attention Is All You Need"**
*Built with PyTorch Lightning for scalable training and inference*
---
## What Makes This Special
- **Complete Implementation** - Every component from the original paper, meticulously crafted
- **Lightning Fast** - PyTorch Lightning integration for distributed training
- **Production Ready** - Proper error handling, logging, and checkpointing
- **Modular Design** - Each component is independently testable and reusable
- **Independent Testing** - Run each module separately for debugging and learning
- **Educational** - Clean, well-documented code perfect for learning
- **Modern Stack** - Uses the GPT-2 tokenizer and state-of-the-art practices
- **Multiple Architectures** - CrossAttention, DecoderOnly, and MoE implementations
- **Comprehensive Metrics** - BLEU, ROUGE, METEOR, and BERTScore evaluation
- **Advanced Features** - Mixture of Experts with Top-K routing and sparse computation
---
## Architecture Deep Dive
### Core Components
| Component | Description | Key Features |
|-----------|-------------|--------------|
| **TokenEmbedding** | Converts tokens to dense vectors | Scaling, padding handling, vocabulary mapping |
| **PositionalEmbedding** | Adds position information | Sinusoidal & learned encodings, flexible max positions |
| **MultiHeadSelfAttention** | The heart of the Transformer | Causal masking, cross-attention, scaled dot-product |
| **PositionwiseFeedForward** | Non-linear transformations | GELU activation, configurable dimensions |
| **AddNorm** | Residual connections + normalization | Layer normalization, dropout, gradient flow |
| **Encoder** | Processes input sequences | Stacked layers, self-attention, context building |
| **Decoder** | Generates output sequences | Masked attention, cross-attention, autoregressive |
| **MoE Components** | Mixture of Experts implementation | Top-K routing, sparse computation, expert specialization |
| **TopKRouter** | Expert selection mechanism | Dynamic routing, load balancing, efficient computation |
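The attention row in the table above is the piece most worth seeing in code. Below is a minimal sketch of masked scaled dot-product attention, the single-head core that a `MultiHeadSelfAttention` module builds on; head splitting, projections, and cross-attention are deliberately omitted, so treat it as an illustration rather than the repository's implementation.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    """Single-head attention: softmax(QK^T / sqrt(d_k)) V."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)          # (batch, seq_q, seq_k)
    if mask is not None:
        # Positions where mask == 0 (e.g. future tokens) are excluded from attention.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v

# Toy usage with a causal (lower-triangular) mask
seq_len, d_model = 5, 256
x = torch.randn(1, seq_len, d_model)
causal = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
print(scaled_dot_product_attention(x, x, x, mask=causal).shape)  # torch.Size([1, 5, 256])
```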
### Model Architectures
| Architecture | Description | Use Cases | Key Features |
|--------------|-------------|-----------|--------------|
| **CrossAttentionSeq2Seq** | Full encoder-decoder with cross-attention | Translation, summarization | Bidirectional encoding, cross-attention |
| **DecoderOnly** | GPT-style autoregressive model | Text generation, completion | Causal masking, next-token prediction |
| **DecoderOnlyMoE** | Decoder-only with Mixture of Experts | Large-scale text generation | Sparse activation, expert routing |
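To make the DecoderOnly row concrete, here is a hedged sketch of the usual next-token objective (inputs and targets shifted by one position). The stand-in model below is a toy embedding-plus-linear head, not the repository's `DecoderOnlySeq2SeqModel`, and the logits shape is an assumption.

```python
import torch
import torch.nn.functional as F

def next_token_loss(model, token_ids, pad_id=0):
    """Autoregressive objective: predict token t+1 from tokens <= t."""
    inputs, targets = token_ids[:, :-1], token_ids[:, 1:]      # shift by one position
    logits = model(inputs)                                     # assumed shape: (batch, seq, vocab)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           targets.reshape(-1),
                           ignore_index=pad_id)                # skip padding positions

# Toy check with a stand-in "model" (embedding + linear head)
vocab, d_model = 100, 32
toy = torch.nn.Sequential(torch.nn.Embedding(vocab, d_model),
                          torch.nn.Linear(d_model, vocab))
print(next_token_loss(toy, torch.randint(1, vocab, (2, 10))).item())
```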
### Data Flow
```mermaid
graph TD
A[Input Text] --> B[Tokenization]
B --> C[Token Embedding]
C --> D[Positional Encoding]
D --> E[Encoder Stack]
E --> F[Context Vectors]
F --> G[Decoder Stack]
G --> H[Output Logits]
H --> I[Generated Text]
```
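The same flow in code, as a rough sketch: the GPT-2 tokenizer matches the stack described in this README, but `nn.Transformer` is used here only as a stand-in for the repository's own Embedding/Encoder/Decoder modules.

```python
import torch
from torch import nn
from transformers import GPT2TokenizerFast

tok = GPT2TokenizerFast.from_pretrained("gpt2")               # tokenization
src = torch.tensor([tok.encode("The rise of renewable energy is")])
tgt = torch.tensor([tok.encode("reducing dependence on fossil fuels")])

d_model = 256
embed = nn.Embedding(tok.vocab_size, d_model)                 # token embedding (+ positions in the repo)
core = nn.Transformer(d_model=d_model, nhead=4,
                      num_encoder_layers=2, num_decoder_layers=2,
                      dim_feedforward=512, batch_first=True)  # encoder + decoder stacks (stand-in)
lm_head = nn.Linear(d_model, tok.vocab_size)                  # output logits

logits = lm_head(core(embed(src), embed(tgt)))                # (batch, tgt_len, vocab)
print(logits.shape)
```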
---
## Quick Start
### 1. Installation
```bash
# Clone the repository
git clone https://github.com/ambidextrous9/transformer-from-scratch.git
cd transformer-from-scratch
# Install dependencies
pip install torch pytorch-lightning transformers pandas numpy sacrebleu rouge_score bert_score nltk
```
### 2. Training
Choose from multiple model architectures:
#### CrossAttention Seq2Seq Model
```bash
# Train encoder-decoder with cross-attention
python Trainer.py
```
#### Decoder-Only Model (GPT-style)
```bash
# Train decoder-only autoregressive model
python DecoderOnlyTrainer.py
```
#### Decoder-Only with Mixture of Experts
```bash
# Train MoE model with expert routing
python DecoderMoETrainer.py
```
**Training Features:**
- **Automatic checkpointing** - Best model saved automatically
- **Real-time monitoring** - Loss tracking and validation metrics
- **GPU acceleration** - Runs on GPU when one is available
- **Progress tracking** - Detailed logging and progress bars
- **MoE Support** - Sparse computation with expert routing
- **Comprehensive Metrics** - BLEU, ROUGE, METEOR, BERTScore evaluation
### 3. Inference
Choose the appropriate inference script for your model:
#### CrossAttention Seq2Seq Model
```bash
# Generate text completions with encoder-decoder
python Inference.py
```
#### Decoder-Only Model
```bash
# Generate text with decoder-only model
python DecoderOnlyInference.py
```
#### Decoder-Only with MoE
```bash
# Generate text with MoE model
python DecoderMoEInference.py
```
**Inference Features:**
- **Greedy decoding** - Deterministic text generation (a minimal loop is sketched below)
- **Fast inference** - Optimized for production use
- **Flexible input** - Handles variable-length sequences
- **Easy integration** - Simple API for your applications
- **MoE Support** - Efficient expert routing during inference
- **Multiple Models** - Support for different architectures
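A minimal sketch of the greedy decoding mentioned above, assuming a decoder-only model that maps token ids to logits; the repository's inference scripts additionally handle checkpoint loading and tokenizer setup, so the signature below is illustrative.

```python
import torch

@torch.no_grad()
def greedy_generate(model, tokenizer, prompt, max_new_tokens=32):
    """Deterministic generation: always take the highest-probability next token."""
    model.eval()
    ids = torch.tensor([tokenizer.encode(prompt)])
    for _ in range(max_new_tokens):
        logits = model(ids)                                   # assumed shape: (1, seq, vocab)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
        if next_id.item() == tokenizer.eos_token_id:          # stop at end-of-sequence
            break
    return tokenizer.decode(ids[0].tolist())

# Usage (model and tokenizer loaded elsewhere):
# print(greedy_generate(model, tokenizer, "The rise of renewable energy is"))
```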
### 4. Independent Module Testing
Each component can be run independently for testing and experimentation:
```bash
# Test individual components
python Embedding.py # Test token & positional embeddings
python MultiHeadSelfAttention.py # Test attention mechanism
python FFN.py # Test feed-forward network
python AddNorm.py # Test residual connections & normalization
python Encoder.py # Test encoder stack
python Decoder.py # Test decoder stack
python Seq2SeqModel.py # Test complete model
```
**Independent Testing Features:**
- **Component isolation** - Test each part separately
- **Debugging friendly** - Easy to identify issues in specific components
- **Learning focused** - Understand each component's behavior individually
- **Quick validation** - Fast testing without the full training pipeline (a smoke-test pattern is sketched below)
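If you want to add a similar smoke test of your own, a generic pattern might look like the following; the import path and constructor arguments here are illustrative assumptions, not the repository's exact signatures.

```python
# Illustrative smoke test for a single component (names and arguments are assumed)
import torch

if __name__ == "__main__":
    from MultiHeadSelfAttention import MultiHeadSelfAttention  # hypothetical import

    batch, seq_len, d_model = 2, 16, 256
    attn = MultiHeadSelfAttention(d_model=d_model, num_heads=4)  # constructor args assumed
    out = attn(torch.randn(batch, seq_len, d_model))
    assert out.shape == (batch, seq_len, d_model), "attention should preserve shape"
    print("MultiHeadSelfAttention OK:", out.shape)
```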
---
## Evaluation Metrics
The codebase includes comprehensive evaluation metrics for assessing model performance:
### Automatic Metrics
| Metric | Description | Range | Use Case |
|--------|-------------|-------|----------|
| **BLEU** | N-gram overlap with reference | 0-100 | Translation quality, text similarity |
| **ROUGE-1** | Unigram overlap | 0-1 | Content coverage, summarization |
| **ROUGE-2** | Bigram overlap | 0-1 | Phrase-level similarity |
| **ROUGE-L** | Longest common subsequence | 0-1 | Structural similarity |
| **METEOR** | Semantic similarity with synonyms | 0-1 | Meaning preservation |
| **BERTScore** | Contextual embedding similarity | 0-1 | Semantic understanding |
### Implementation Features
- **Real-time Tracking** - Metrics computed during validation
- **Progress Monitoring** - All metrics logged to PyTorch Lightning
- **Automatic Evaluation** - No manual intervention required
- **Efficient Computation** - Optimized for large-scale evaluation
- **Comprehensive Coverage** - Multiple evaluation perspectives
### Usage
All metrics are computed automatically during validation steps and logged to the progress bar and TensorBoard logs.
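For reference, the same scores can also be computed standalone with the libraries from the install step; a small sketch using sacrebleu and rouge_score (BERTScore and METEOR follow the same pattern with their respective packages):

```python
import sacrebleu
from rouge_score import rouge_scorer

predictions = ["reducing dependence on fossil fuels and lowering emissions."]
references  = ["reducing dependence on fossil fuels and lowering emissions."]

# Corpus-level BLEU (0-100)
print(f"BLEU: {sacrebleu.corpus_bleu(predictions, [references]).score:.2f}")

# ROUGE-1/2/L F-measures (0-1)
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
for name, score in scorer.score(references[0], predictions[0]).items():
    print(f"{name}: {score.fmeasure:.3f}")
```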
---
## Dataset & Task
**Versatile Text Completion Dataset**
- **2,000 examples** of diverse text completion pairs
- **Task**: Complete partial sentences with meaningful continuations
- **Format**: `"partial sentence..." → "completion text"`
- **Train/Val Split**: 80/20 automatic split (see the loading sketch below)
- **Diverse Topics**: Covers multiple domains and contexts
**Example:**
```
Input: "The rise of renewable energy is changing global markets and Experts predict this shift will redefine economies"
Output: "reducing dependence on fossil fuels and lowering emissions."
```
**Dataset Features:**
- ๐ **Educational Content** - Science, technology, and general knowledge
- ๐ **Multiple Formats** - Various sentence structures and completion types
- ๐ฏ **Quality Controlled** - Curated for meaningful learning objectives
- ๐ **Balanced Distribution** - Even representation across different topics
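A sketch of how the CSV and the 80/20 split could be loaded with pandas; the assumption that the first two columns hold the partial sentence and its completion is illustrative, so check the actual header.

```python
import pandas as pd

df = pd.read_csv("versatile_dataset_2000.csv")

# Assumption: column 0 = partial sentence, column 1 = completion (verify against the CSV header).
pairs = list(zip(df.iloc[:, 0], df.iloc[:, 1]))

# 80/20 train/validation split, as described above
split = int(0.8 * len(pairs))
train_pairs, val_pairs = pairs[:split], pairs[split:]
print(len(train_pairs), "train /", len(val_pairs), "val examples")
```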
---
## Configuration
### Model Architecture
| Parameter | Default | Description |
|-----------|---------|-------------|
| `d_model` | 256 | Model dimension (embedding size) |
| `num_heads` | 4-8 | Number of attention heads |
| `num_encoder_layers` | 2-6 | Encoder stack depth |
| `num_decoder_layers` | 2-6 | Decoder stack depth |
| `d_ff` | 128-1024 | Feed-forward dimension |
| `dropout` | 0.1 | Dropout rate |
| `max_positions` | 32-512 | Maximum sequence length |
| `use_sinusoidal_pos` | True | Use sinusoidal positional encoding |
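The `use_sinusoidal_pos` flag in the table above refers to the fixed encodings from the original paper; here is a minimal sketch with the defaults listed (`d_model=256`, `max_positions` up to 512), independent of the repository's `PositionalEmbedding` class.

```python
import math
import torch

def sinusoidal_positions(max_positions: int, d_model: int) -> torch.Tensor:
    """Fixed sin/cos positional encodings from 'Attention Is All You Need'."""
    pos = torch.arange(max_positions, dtype=torch.float32).unsqueeze(1)      # (P, 1)
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                    * (-math.log(10000.0) / d_model))                        # (d_model/2,)
    pe = torch.zeros(max_positions, d_model)
    pe[:, 0::2] = torch.sin(pos * div)   # even dimensions
    pe[:, 1::2] = torch.cos(pos * div)   # odd dimensions
    return pe

print(sinusoidal_positions(max_positions=512, d_model=256).shape)  # torch.Size([512, 256])
```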
### MoE Configuration (DecoderOnlyMoE)
| Parameter | Default | Description |
|-----------|---------|-------------|
| `num_experts` | 4 | Number of expert networks |
| `top_k` | 2 | Number of experts to activate per token |
| `expert_capacity` | Auto | Maximum tokens per expert |
### Training Configuration
| Parameter | Value | Description |
|-----------|-------|-------------|
| `batch_size` | 4 | Training batch size |
| `learning_rate` | 1e-3 | Adam optimizer learning rate |
| `max_epochs` | 100 | Maximum training epochs |
| `gradient_clip` | 1.0 | Gradient clipping threshold |
| `checkpoint_monitor` | val_loss_epoch | Model selection metric |
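A hedged sketch of how the values in this table might map onto a PyTorch Lightning setup; the repository's `Trainer.py` may organise this differently, and the learning rate and batch size would live in the LightningModule and DataLoader rather than here.

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the best checkpoint according to validation loss, as in the table above.
checkpoint_cb = ModelCheckpoint(monitor="val_loss_epoch", mode="min", save_top_k=1)

trainer = pl.Trainer(
    max_epochs=100,          # maximum training epochs
    gradient_clip_val=1.0,   # gradient clipping threshold
    accelerator="auto",      # pick up a GPU automatically when available
    callbacks=[checkpoint_cb],
)
# trainer.fit(model, train_loader, val_loader)  # model and dataloaders defined elsewhere
```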
---
## Project Structure
```
transformer-from-scratch/
├── Core Components
│   ├── Embedding.py                   # Token & positional embeddings
│   ├── MultiHeadSelfAttention.py      # Multi-head attention mechanism
│   ├── FFN.py                         # Position-wise feed-forward
│   └── AddNorm.py                     # Residual connections + normalization
├── Architecture Models
│   ├── Encoder.py                     # Encoder stack implementation
│   ├── Decoder.py                     # Decoder stack implementation
│   ├── CrossAttentionSeq2SeqModel.py  # Full encoder-decoder model
│   ├── DecoderOnlySeq2SeqModel.py     # GPT-style decoder-only model
│   └── DecoderMoE.py                  # Decoder-only with Mixture of Experts
├── Training Scripts
│   ├── Trainer.py                     # CrossAttention training pipeline
│   ├── DecoderOnlyTrainer.py          # Decoder-only training pipeline
│   └── DecoderMoETrainer.py           # MoE training pipeline
├── Inference Scripts
│   ├── Inference.py                   # CrossAttention inference
│   ├── DecoderOnlyInference.py        # Decoder-only inference
│   └── DecoderMoEInference.py         # MoE inference
├── Data
│   ├── versatile_dataset_2000.csv     # Main training dataset
│   └── synthetic_text_completion.csv  # Legacy dataset
├── Checkpoints
│   ├── Seq2SeqCheckpoints/            # CrossAttention model checkpoints
│   ├── DecoderOnlyCheckpoints/        # Decoder-only model checkpoints
│   └── DecoderMoECheckpoints/         # MoE model checkpoints
└── Logs
    └── lightning_logs/                # Training logs and metrics
```
---
## Use Cases
### Perfect For:
- **Learning** - Understanding the Transformer architecture
- **Research** - Experimenting with attention mechanisms
- **Prototyping** - Quick seq2seq model development
- **Component Testing** - Debug and validate individual modules
### Applications:
#### CrossAttention Seq2Seq Model
- **Summarization** - Generate concise summaries
- **Translation** - Sequence-to-sequence translation
- **Question Answering** - Context-aware responses
- **Data-to-Text** - Convert structured data to natural language
#### Decoder-Only Models
- **Text Completion** - Auto-complete sentences
- **Chatbots** - Conversational AI systems
- **Creative Writing** - Story and content generation
- **Code Generation** - Programming assistance
#### MoE Models
- **Large-Scale Generation** - Efficient text generation at scale
- **Specialized Tasks** - Expert routing for domain-specific content
- **Resource Optimization** - Sparse computation for better efficiency
- **Multi-Domain Learning** - Handle diverse topics with specialized experts
---
## Mixture of Experts (MoE) Implementation
### Key Features
The MoE implementation includes several advanced features for efficient sparse computation:
#### Expert Architecture
- **ExpertMLP** - Individual expert networks with GELU activation
- **TopKRouter** - Intelligent routing mechanism for expert selection
- **Sparse Computation** - Only activate selected experts per token
- **Load Balancing** - Automatic expert capacity management
#### Routing Strategy
- **Softmax Gating** - Probabilistic expert selection
- **Top-K Selection** - Activate only the most relevant experts
- **Dynamic Routing** - Adaptive expert selection based on input
- **Load Balancing** - Prevent expert overloading
#### Performance Optimizations
- **Sparse Activation** - Reduce computational overhead
- **Memory Efficient** - Only store active expert outputs
- **Batch Processing** - Efficient parallel expert computation
- **Gradient Flow** - Proper backpropagation through routing
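Here is a minimal sketch of the routing idea described above (softmax gating, top-k selection, weighted mixing). It illustrates the mechanism rather than reproducing the repository's `TopKRouter`/`ExpertMLP` classes, and it evaluates every expert densely for clarity instead of dispatching sparsely.

```python
import torch
from torch import nn

class TopKMoELayer(nn.Module):
    """Sketch: route each token to its top-k experts and mix their outputs."""
    def __init__(self, d_model=256, d_ff=512, num_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts)   # gating scores per expert
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        )

    def forward(self, x):                                # x: (batch, seq, d_model)
        gate = torch.softmax(self.router(x), dim=-1)     # softmax gating
        topk_w, topk_idx = gate.topk(self.top_k, dim=-1) # keep only the top-k experts
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
        out = torch.zeros_like(x)
        # Dense loop over experts for clarity; real implementations dispatch sparsely.
        for e, expert in enumerate(self.experts):
            mask = (topk_idx == e)                            # (batch, seq, top_k)
            weight = (topk_w * mask).sum(dim=-1, keepdim=True)  # 0 where expert unused
            out = out + weight * expert(x)
        return out

layer = TopKMoELayer()
print(layer(torch.randn(2, 8, 256)).shape)   # torch.Size([2, 8, 256])
```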
### Usage Example
```python
# Initialize MoE model
model = DecoderOnlyMoEModel(
    vocab_size=vocab_size,
    d_model=256,
    num_experts=4,   # Number of expert networks
    top_k=2,         # Activate top 2 experts per token
    num_layers=6,
    tokenizer=tokenizer
)

# Training automatically handles expert routing
trainer.fit(model, train_loader, val_loader)
```
---
## Contributing
We welcome contributions! Here's how you can help:
1. **Fork** the repository
2. **Create** a feature branch (`git checkout -b feature/AmazingFeature`)
3. **Commit** your changes (`git commit -m 'Add AmazingFeature'`)
4. **Push** to the branch (`git push origin feature/AmazingFeature`)
5. **Open** a Pull Request
### Areas for Contribution:
- **Performance optimizations**
- **Additional attention mechanisms**
- **More datasets and tasks**
- **Documentation improvements**
- **Bug fixes and testing**
---
## References & Learning
### Papers
1. **Vaswani, A., et al.** (2017). "Attention is all you need." *NeurIPS 2017*
2. **Devlin, J., et al.** (2018). "BERT: Pre-training of Deep Bidirectional Transformers." *NAACL 2019*
### Resources
- [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)
- [PyTorch Lightning Documentation](https://pytorch-lightning.readthedocs.io/)
- [Attention Mechanism Explained](https://distill.pub/2016/augmented-rnns/)
- [Transformer from Scratch](https://www.youtube.com/watch?v=ISNdQcPhsts)
---
**⭐ Star this repository if you found it helpful!**
Made with ❤️ and lots of ☕
[Report Bug](https://github.com/ambidextrous9/transformer-from-scratch/issues) · [Request Feature](https://github.com/ambidextrous9/transformer-from-scratch/issues) · [Documentation](https://github.com/ambidextrous9/transformer-from-scratch/wiki)