https://github.com/junfanz1/minigpt-and-deepseek-mla-multi-head-latent-attention

An efficient and scalable attention module designed to reduce memory usage and improve inference speed in large language models: Multi-Head Latent Attention (MLA), implemented as a drop-in replacement for traditional multi-head attention (MHA).

# Part 1. DeepSeek Multi-Head Latent Attention (MLA) – LLM Attention Module Optimization

This document provides a detailed explanation of the DeepSeek multi-head latent attention (MLA) algorithm implementation. It covers the architecture and technology stack, explains the algorithm’s underlying principles, compares MLA with standard multi-head attention (MHA), and outlines key technical details useful for building large language model (LLM) systems.

- Technologies: Python, PyTorch, LLM internals, Low-Rank Approximation, Rotary Positional Embedding, Memory-Efficient Inference
- Designed and implemented a memory-efficient attention mechanism for long-context large language models, inspired by low-rank adaptation techniques.
- Implemented a novel multi-head latent attention (MLA) module in PyTorch, replacing standard multi-head attention (MHA) with low-rank compressed KV representations to significantly reduce inference memory footprint.
- Engineered a two-stage projection pipeline (down-projection + up-projection) for queries and key-values, enabling latent attention computation with minimal performance loss—an approach analogous to LoRA-style low-rank approximation.
- Developed a decoupled positional embedding mechanism, splitting representations into positional (RoPE) and non-positional (NoPE) components to prevent compression artifacts and improve embedding flexibility.
- Built and optimized the DeepSeekV2RotaryEmbedding module with precomputed sine/cosine caches, enabling fast and reusable rotary positional encoding during long-sequence inference.
- Designed a custom RMS normalization layer (DeepSeekV2RMSNorm) that improves numerical stability during training via a learnable scaling factor and per-feature variance computation.
- Achieved efficient attention computation by combining low-rank compression, RoPE decoupling, and causal masking, leading to significantly faster inference for long-context decoding.
- Benchmarked MLA against standard MHA, demonstrating comparable or superior model performance with a drastically smaller KV cache—a critical feature for scaling large LLMs on resource-constrained hardware.
- Ensured high modularity and extensibility of components (e.g., projection heads, rotary embeddings, normalization), enabling rapid experimentation with alternative compression strategies and positional encoding schemes.

---

## Table of Contents

- [Overview](#overview)
- [Architecture & Tech Stack](#architecture--tech-stack)
- [Algorithm Principles](#algorithm-principles)
  - [Low-Rank Compression of Queries and KV States](#low-rank-compression-of-queries-and-kv-states)
  - [Decoupled Rotary Positional Embedding](#decoupled-rotary-positional-embedding)
  - [Attention Computation](#attention-computation)
- [MLA vs. Standard MHA](#mla-vs-standard-mha)
- [Additional Engineering Details](#additional-engineering-details)
- [Usage Example](#usage-example)
- [Conclusion](#conclusion)
- [References](#references)

---

## Overview

DeepSeek’s MLA algorithm is an enhancement over traditional multi-head attention (MHA). Its key innovation is compressing the key–value (KV) representations into low-dimensional latent vectors. This compression reduces the memory footprint of the KV cache during inference while preserving (or even enhancing) model performance. In addition, MLA decouples positional encoding from the main semantic content by splitting the representations into “nope” (non-positional) and “rope” (rotary positional) parts.

Designed and implemented as a drop-in replacement for traditional multi-head attention, the MLA module targets inference efficiency and memory scalability without compromising model performance.

Key innovations include:

- Low-Rank Attention Compression: Applied LoRA-style down- and up-projection layers to compress and reconstruct query, key, and value states, drastically reducing KV cache size during long-context inference.
- Decoupled Positional Embedding: Separated rotary positional information (RoPE) from semantic content (NoPE), preserving attention quality under compression and enabling cleaner representation learning.
- Custom Rotary Embedding Module: Implemented DeepSeekV2RotaryEmbedding with precomputed sine/cosine caches, optimizing positional encoding for repeated use in long sequences.
- Memory-Efficient RMS Normalization: Developed DeepSeekV2RMSNorm, a customized normalization layer improving training stability through learnable variance scaling.
- Optimized Attention Flow: Engineered attention pipeline with causal masking, dropout regularization, and efficient tensor reshaping for scalable inference and autoregressive generation.
- Modular and Extensible Design: Built all components with modularity in mind, enabling adaptation to different model sizes, hardware constraints, and research directions.

This project showcases deep understanding of transformer internals, low-rank approximation, attention optimization techniques, and engineering best practices for production-ready LLM components.

---

## Architecture & Tech Stack

- **Framework:** PyTorch
- **Core Components:**
  - **DeepSeekV2RMSNorm:** A variant of RMS normalization that scales hidden states with a learnable weight and adjusts variance for stability.
  - **DeepSeekV2RotaryEmbedding:** Implements rotary positional embedding (RoPE) with caching for efficient reuse of sine and cosine embeddings.
  - **MLA Module:** Combines projection layers for query and key–value compression, normalization, splitting into positional and non-positional components, and the standard attention computation.
- **Programming Languages & Libraries:**
  - Python 3
  - PyTorch (for tensor operations, neural network modules, and automatic differentiation)

---

## Algorithm Principles

### Low-Rank Compression of Queries and KV States

MLA introduces low-rank approximations to reduce the dimension of the full hidden state before computing attention. The main steps are:

1. **Down-Projection:**
   - **Queries:** A linear layer (`q_down_proj`) compresses the input hidden state from the full model dimension to a lower-dimensional latent space.
   - **Keys/Values:** A similar down-projection (`kv_down_proj`) compresses the input, followed by normalization (`kv_down_norm`).

2. **Up-Projection:**
   - After compression, separate up-projection layers (`q_up_proj` for queries and `kv_up_proj` for keys/values) re-expand the latent representations into the dimensions required for multi-head attention.
   - This two-stage process is similar in spirit to low-rank adaptation (LoRA) techniques.

3. **Splitting into Positional and Non-Positional Parts:**
   - The query and key representations are split into:
     - **NoPE part:** Remains unaffected by positional encoding.
     - **RoPE part:** Undergoes rotary positional embedding.
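
As a rough sketch of this two-stage pipeline, the snippet below wires up the down- and up-projections and the NoPE/RoPE split. All dimension names and sizes are illustrative assumptions, not the repository's configuration:

```py
import torch
import torch.nn as nn

# Hypothetical dimensions, chosen only for illustration.
hidden_dim, q_lora_rank, kv_lora_rank = 1024, 256, 128
num_heads, nope_dim, rope_dim, v_dim = 8, 64, 32, 64

# Stage 1: down-project the hidden state into a small latent space.
q_down_proj = nn.Linear(hidden_dim, q_lora_rank, bias=False)
kv_down_proj = nn.Linear(hidden_dim, kv_lora_rank + rope_dim, bias=False)

# Stage 2: up-project the latent vectors to the per-head dimensions.
q_up_proj = nn.Linear(q_lora_rank, num_heads * (nope_dim + rope_dim), bias=False)
kv_up_proj = nn.Linear(kv_lora_rank, num_heads * (nope_dim + v_dim), bias=False)

x = torch.randn(2, 16, hidden_dim)               # (batch, seq, hidden)
q = q_up_proj(q_down_proj(x))                    # queries pass through the low-rank bottleneck
c_kv = kv_down_proj(x)                           # compressed KV latent: this is what the cache stores
latent, k_rope = torch.split(c_kv, [kv_lora_rank, rope_dim], dim=-1)
kv = kv_up_proj(latent)                          # keys/values reconstructed on the fly

# Split queries into NoPE and RoPE parts per head.
q_nope, q_rope = torch.split(
    q.view(2, 16, num_heads, nope_dim + rope_dim),
    [nope_dim, rope_dim], dim=-1)
```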

### Decoupled Rotary Positional Embedding

Rotary Positional Embedding (RoPE) injects positional information directly into the attention computation by rotating half of the embedding dimensions. In MLA:

- **Caching and Efficiency:**
The `DeepSeekV2RotaryEmbedding` module precomputes and caches cosine and sine values for efficiency.

- **Application:**
The function `apply_rotary_pos_emb` applies these precomputed values to the “rope” part of the queries and keys.

- **Decoupling:**
By separating the positional (rope) component from the non-positional (nope) component, MLA ensures that the low-rank compression is not disrupted by position-dependent rotations.
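
A minimal sketch of the cached-RoPE idea described above, using the standard rotary-embedding formulation; the class name `RotaryEmbeddingCache` and exact buffer names are illustrative, not the repository's classes:

```py
import torch
import torch.nn as nn

class RotaryEmbeddingCache(nn.Module):
    """Precomputes cos/sin tables once so long-sequence decoding can reuse them."""
    def __init__(self, dim, max_len=4096, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).float()
        freqs = torch.outer(t, inv_freq)            # (max_len, dim/2)
        emb = torch.cat([freqs, freqs], dim=-1)     # (max_len, dim)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, seq_len):
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # Index the cached tables by position, then rotate only the "rope" slices.
    cos = cos[position_ids].unsqueeze(1)            # (batch, 1, seq, dim)
    sin = sin[position_ids].unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```

Because the tables are built once, each decoding step reduces to an index lookup and a pair of elementwise multiplies.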

### Attention Computation

After re-projecting and splitting, the MLA layer computes attention as follows:

1. **Concatenation:**
The “nope” and “rope” parts of the queries and keys are concatenated to form the final query and key tensors.

2. **Scaled Dot-Product:**
Attention weights are computed by performing a scaled dot-product between the query and the transposed key, with scaling by the square root of the query dimension.

3. **Masking and Softmax:**
An optional causal mask (for autoregressive generation) is applied to the attention weights, followed by softmax normalization.

4. **Output Projection:**
The weighted sum over the value states is computed and passed through an output linear projection (`out_proj`) to produce the final attention output.
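
As a small illustration of step 3, a causal mask can be built once with `torch.tril` and applied to the attention scores before the softmax. This is a generic sketch, independent of the MLA-specific code shown later:

```py
import torch

seq_len = 8
# Lower-triangular mask: position i may attend only to positions <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = torch.randn(1, 4, seq_len, seq_len)              # (batch, heads, q_len, k_len)
scores = scores.masked_fill(~causal_mask, float("-inf"))  # block future positions
attn = torch.softmax(scores, dim=-1)
```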

---

## MLA vs. Standard MHA

| Feature | Standard MHA | DeepSeek MLA |
|-----------------------------|----------------------------------------|---------------------------------------------------------------|
| **KV Cache Storage** | Full keys and values per head | Compressed latent KV vectors (much smaller cache) |
| **Projection Strategy** | Direct projection to Q, K, V | Two-stage (down-projection then up-projection) for low-rank approximation |
| **Positional Encoding** | Added directly (e.g., sinusoidal) | Decoupled RoPE applied only on a split of the representation |
| **Memory Footprint** | High (grows with full dimensionality) | Reduced (low-rank latent compression cuts down storage needs) |
| **Inference Efficiency** | Slower for long sequences due to large KV cache | Faster inference with reduced memory transfers and computational overhead |
| **Expressiveness** | High, but with higher resource costs | Maintains (or improves) performance with fewer activated parameters |

*DeepSeek’s MLA achieves a favorable balance between memory efficiency and model performance, making it especially effective for long-context language models.*
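
As a back-of-the-envelope illustration of the KV-cache row in the table, the per-token cache sizes can be compared directly. The numbers below are assumed purely for illustration and are not the DeepSeek-V2 configuration:

```py
# Illustrative per-token KV cache sizes (element counts), with assumed dimensions.
num_heads, head_dim = 32, 128        # assumed MHA configuration
kv_lora_rank, rope_dim = 512, 64     # assumed MLA latent sizes

mha_cache_per_token = 2 * num_heads * head_dim   # full K and V for every head
mla_cache_per_token = kv_lora_rank + rope_dim    # one compressed latent + one shared rope key

print(mha_cache_per_token, mla_cache_per_token)  # 8192 vs 576 elements -> roughly 14x smaller
```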

---

## Additional Engineering Details

- **RMS Norm Variant:**
The customized RMS norm (`DeepSeekV2RMSNorm`) improves training stability by computing the variance over the last dimension and applying a learnable scaling factor; a minimal sketch appears after this list.

- **Optimized Positional Embeddings:**
Caching cosine and sine values in the rotary embedding module avoids redundant computation, which is especially important for long sequences.

- **LoRA-Inspired Projections:**
The two-stage (down- and up-projection) approach reduces the number of parameters stored in the KV cache and allows dynamic reconstruction of high-dimensional representations.

- **Dropout and Masking:**
Standard dropout is applied to attention weights, and causal masking ensures the autoregressive property is maintained despite latent compression.

- **Modular Design:**
The implementation is modular, allowing individual components (e.g., normalization, rotary embeddings, projection layers) to be easily adapted or replaced for different hardware constraints or alternative compression strategies.
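
Below is a minimal sketch of such an RMS norm layer, following the description above (per-feature variance over the last dimension, learnable scale). The class name is illustrative, and the float32 upcast is a common stability choice rather than a confirmed detail of the repository's `DeepSeekV2RMSNorm`:

```py
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """Minimal RMS norm: rescale by the root mean square of the last dimension."""
    def __init__(self, hidden_dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_dim))  # learnable per-feature scale
        self.eps = eps

    def forward(self, x):
        dtype = x.dtype
        x = x.float()                                       # accumulate in float32
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(dtype)
```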

## Usage Example

Below is an outline of how the MLA module is used in a forward pass:

```py
def forward(self, hidden_states, position_ids, attention_mask=None):
    # hidden_states: (batch_size, seq_len, hidden_dim)
    bsz, q_len, _ = hidden_states.size()

    # 1. Query compression: down-project, normalize, and up-project.
    q = self.q_down_proj(hidden_states)
    q = self.q_down_norm(q)
    q = self.q_up_proj(q)
    # Reshape for multi-head attention: (batch, num_heads, seq_len, q_head_dim)
    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)

    # Split the query into non-positional (NoPE) and positional (RoPE) parts.
    q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

    # 2. Key-value compression: down-project, split off the shared RoPE key, normalize, up-project.
    c_kv = self.kv_down_proj(hidden_states)
    c_kv, k_rope = torch.split(c_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    k_rope = k_rope.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)

    kv = self.kv_down_norm(c_kv)
    kv = self.kv_up_proj(kv)
    kv = kv.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2)
    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    # 3. Apply rotary positional embedding (RoPE) to the positional parts only.
    seq_len_kv = value_states.size(2)
    cos, sin = self.rotary_emb(value_states, seq_len=seq_len_kv)
    q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos, sin, position_ids)

    # 4. Concatenate both parts and compute scaled dot-product attention.
    query_states = torch.concat([q_nope, q_rope], dim=-1)
    key_states = torch.concat([k_nope, k_rope.expand(-1, self.num_heads, -1, -1)], dim=-1)
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
    attn_weights = attn_weights / math.sqrt(self.q_head_dim)

    # 5. Optional causal mask, then softmax and dropout.
    if attention_mask is not None:
        attn_weights = attn_weights.masked_fill(attention_mask == 0, float("-inf"))
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

    # 6. Weighted sum over the values, then the final output projection.
    output = torch.matmul(attn_weights, value_states)
    output = output.transpose(1, 2).reshape(bsz, q_len, -1)
    output = self.out_proj(output)
    return output, attn_weights
```

This pseudocode illustrates the core steps in MLA and shows how each module contributes to reducing KV cache size while maintaining effective attention computation.

---

## Conclusion

DeepSeek’s Multi-Head Latent Attention (MLA) introduces an innovative twist on standard multi-head attention by compressing key–value representations into low-dimensional latent vectors. This approach offers:

- **Lower Memory Usage:** Significantly reduced KV cache size is critical for efficient long-context inference.
- **Efficient Computation:** Down- and up-projection layers (inspired by low-rank adaptation) minimize computational overhead.
- **Enhanced Flexibility:** Decoupling positional information using RoPE ensures the model leverages both positional and semantic cues effectively.

By balancing these aspects, MLA achieves performance that is competitive with (or superior to) traditional MHA while substantially reducing resource requirements—a crucial advantage for scaling large language models.

## References

- [LLMs Zero to Hero](https://github.com/bbruceyuan/LLMs-Zero-to-Hero)

---

# Part 2. MiniGPT from Scratch: Technical Documentation

This document explains a simplified implementation of the GPT architecture using PyTorch. The code is designed for educational purposes, demonstrating the core components of the Transformer-based language model. It covers the overall architecture, Transformer principles, a comparison with state-of-the-art GPT models, trade-offs, pros and cons, and additional technical details relevant to LLM engineering.

- **Built a Transformer Architecture from Scratch**: Independently designed and implemented a miniGPT model, showcasing a deep understanding of multi-head self-attention, feed-forward networks, residual connections, and layer normalization.
- **Extensive PyTorch Expertise**: Developed modular components including token and positional embeddings, Transformer blocks, and output projection layers using PyTorch, demonstrating robust deep learning engineering skills.
- **Efficient Data Pre-processing & Training Pipeline**: Engineered a data handling system with JSONL input and leveraged the tiktoken library for GPT-2 vocabulary encoding, facilitating effective loading and segmentation of large-scale text data.
- **Autoregressive Text Generation**: Implemented an autoregressive generation mechanism that uses probability sampling for sequential text generation, validating the model’s capability in natural language generation tasks.
- **Advanced Learning Rate Scheduling & Optimization**: Utilized a cosine annealing learning rate scheduler and integrated checkpoint saving to stabilize training, improve convergence, and enhance model generalization.
- **In-depth Comparative Analysis**: Conducted detailed comparisons between miniGPT and state-of-the-art GPT models regarding architecture, performance, and resource efficiency, reflecting a comprehensive understanding of design trade-offs.
- **Modular & Scalable Software Engineering**: Emphasized clear, maintainable, and extensible code design, illustrating strong software engineering practices and mastery over complex LLM system implementations.

---

## Table of Contents

- [Overview](#overview)
- [Architecture & Technology Stack](#architecture--technology-stack)
- [Transformer Algorithm Principles](#transformer-algorithm-principles)
- [Comparison: MiniGPT vs. Latest GPT Models](#comparison-minigpt-vs-latest-gpt-models)
- [Trade-offs, Pros and Cons](#trade-offs-pros-and-cons)
- [Technical Implementation Details](#technical-implementation-details)
- [Training & Data Pipeline](#training--data-pipeline)
- [MiniGPT Summary](#minigpt-summary)

---

## Overview

This miniGPT implementation is a lightweight, from-scratch version of a GPT-style language model. It is intended to help users understand the building blocks of large language models (LLMs) by:

- Demonstrating the core Transformer architecture.
- Implementing key components like multi-head self-attention, feed-forward networks, and positional embeddings.
- Providing a training loop with dataset handling and checkpoint saving.

The code prioritizes clarity and educational value over the raw performance and scale of production models like GPT-3 or GPT-4.

---

## Architecture & Technology Stack

The implementation is based on the following major components:

- **PyTorch**: Used as the primary deep learning framework to implement the neural network modules, training loops, and optimization routines.
- **Dataclass for Configuration**: `GPTConfig` stores hyperparameters such as block size, batch size, number of layers, heads, embedding dimensions, dropout rate, and vocabulary size.
- **Embedding Layers**:
  - **Token Embeddings**: Convert token indices into dense vectors.
  - **Positional Embeddings**: Encode the order of tokens in the sequence.
- **Transformer Blocks**: Composed of:
  - **Multi-Head Self-Attention**: Allows the model to focus on different positions in the input sequence simultaneously.
  - **Feed-Forward Network (FFN)**: A two-layer MLP that processes each position independently.
  - **Residual Connections and Layer Normalization**: Improve gradient flow and stabilize training.
- **Output Projection**: A linear layer maps the final embeddings back to the vocabulary space for language modeling.
- **Learning Rate Scheduler**: Uses cosine annealing to adjust the learning rate during training.
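
A sketch of what such a configuration dataclass can look like; the defaults below use the sizes mentioned later in the comparison section (6 layers, 12 heads, 768-dimensional embeddings), but the exact fields of the repository's `GPTConfig` may differ:

```py
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 512      # maximum context length
    batch_size: int = 12
    n_layer: int = 6           # number of Transformer blocks
    n_head: int = 12           # attention heads per block
    n_embd: int = 768          # embedding dimension
    dropout: float = 0.1
    vocab_size: int = 50257    # GPT-2 BPE vocabulary (tiktoken)
    head_size: int = n_embd // n_head  # per-head dimension (768 / 12 = 64)
```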

---

## Transformer Algorithm Principles

The Transformer architecture is the foundation of GPT models. Key principles include:

- **Self-Attention Mechanism**:
  - **Query, Key, and Value Projections**: Each input is projected into three vectors. The attention score is computed by taking the dot product between the query and key vectors, then scaled by the square root of the head dimension.
  - **Causal Masking**: A lower-triangular mask (`torch.tril`) is applied to ensure that predictions for a given token only depend on previous tokens.
  - **Multi-Head Attention**: Multiple attention heads run in parallel, each learning to capture different relationships in the data. The outputs are concatenated and projected back into the model’s embedding space.

- **Feed-Forward Network (FFN)**:
  - Consists of two linear transformations separated by a non-linear activation (GELU), along with dropout for regularization.

- **Positional Encoding**:
  - Positional embeddings are added to token embeddings to provide the model with a notion of token order, compensating for the permutation-invariance of self-attention.

- **Layer Normalization & Residual Connections**:
  - Both are used to stabilize training and facilitate deeper network architectures by normalizing inputs and preserving information through skip connections.
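
The following sketch ties these pieces together for a single causal attention head; the class and argument names are illustrative rather than the repository's exact `SingleHeadAttention`:

```py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttentionHead(nn.Module):
    def __init__(self, n_embd, head_size, block_size, dropout=0.1):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask so each token attends only to earlier tokens.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, _ = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        scores = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))   # scaled dot-product
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        attn = self.dropout(F.softmax(scores, dim=-1))
        return attn @ v
```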

---

## Comparison: MiniGPT vs. Latest GPT Models

### Similarities

- **Architecture**: Both use the Transformer decoder architecture with self-attention mechanisms.
- **Core Components**: Token and positional embeddings, multi-head attention, FFN, layer normalization, and residual connections.
- **Autoregressive Generation**: Text is generated token by token based on previous context.

### Differences

- **Scale**:
  - **MiniGPT**: Configured with fewer layers (e.g., 6 layers), heads (12), and a smaller embedding dimension (768). It has a significantly lower number of parameters, making it ideal for educational purposes.
  - **Latest GPT Models**: Feature dozens of layers and hundreds of billions of parameters, which allows them to capture more complex patterns and generalize across a wide range of tasks.
- **Performance**:
  - MiniGPT is not optimized for state-of-the-art performance and serves as a simplified model. The latest GPT models exhibit superior language understanding and generation capabilities due to their larger scale and extensive pretraining.
- **Training Data & Objectives**:
  - Latest models are often trained on vast, diverse corpora with advanced techniques like reinforcement learning from human feedback (RLHF), which are not part of the miniGPT implementation.

---

## Trade-offs, Pros and Cons

### Pros

- **Simplicity**:
  - The code is straightforward and modular, making it easy to understand the fundamental components of GPT-style models.
- **Educational Value**:
  - Ideal for learning and experimentation with Transformer-based language models.
- **Resource Efficiency**:
  - Requires significantly fewer computational resources compared to full-scale GPT models, enabling quick prototyping and experimentation on modest hardware.

### Cons

- **Limited Capacity**:
  - The reduced number of layers, heads, and smaller embedding size limit its ability to capture complex language patterns.
- **Performance**:
  - Due to its simplicity, it will not perform at the same level as the latest GPT models on complex tasks.
- **Scalability**:
  - Techniques and optimizations used in production-level LLMs (such as model parallelism, advanced learning rate schedules, and massive distributed training) are not included.

---

## Technical Implementation Details

### Core Modules

- **`GPTConfig`**:
  - Stores model hyperparameters including block size, number of layers/heads, embedding dimensions, dropout rate, and vocabulary size.

- **Attention Mechanisms**:
  - **`SingleHeadAttention`**: Implements one attention head with key, query, and value linear projections. Applies a causal mask to ensure autoregressive behavior.
  - **`MultiHeadAttention`**: Combines multiple `SingleHeadAttention` modules and projects the concatenated outputs back to the embedding space.

- **Feed-Forward Network (FFN)**:
  - A simple MLP with a GELU activation function and dropout.

- **Transformer Block**:
  - Combines multi-head attention and FFN with residual connections and layer normalization.

- **Overall GPT Model**:
  - Consists of an embedding layer (token and position), a stack of Transformer blocks, a final layer normalization, and a linear output head projecting to the vocabulary size.
  - Includes a text generation function (`generate`) that implements autoregressive sampling.
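
A sketch of the autoregressive sampling loop that such a `generate` function typically implements. It assumes the model returns `(logits, loss)` with logits of shape `(batch, seq, vocab)`; that signature is an assumption for illustration, not a confirmed detail of the repository:

```py
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, idx, max_new_tokens, block_size):
    # idx: (batch, seq) tensor of token ids; appends one sampled token per step.
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]               # crop to the context window
        logits, _ = model(idx_cond)                   # assumed (logits, loss) return
        probs = F.softmax(logits[:, -1, :], dim=-1)   # distribution over the next token
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return idx
```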

### Additional Implementation Aspects

- **Weight Initialization**:
  - Linear and embedding layers are initialized using a normal distribution with a mean of 0 and standard deviation of 0.02.

- **Dropout**:
  - Applied in attention and FFN layers to mitigate overfitting.

- **Dataset Handling**:
  - **`MyDataset`**:
    - Reads a JSONL file containing text data.
    - Uses the `tiktoken` library to tokenize input text with GPT-2's vocabulary.
    - Segments the tokenized text into chunks that match the block size, appending an end-of-text token as needed.

- **Training Pipeline**:
  - Utilizes PyTorch’s `DataLoader` for batching and shuffling.
  - Includes a training loop with loss computation, backpropagation, and learning rate scheduling (cosine annealing).
  - Supports evaluation on a validation set and checkpoint saving after each epoch.
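
A sketch of this dataset handling, assuming one JSON object with a `"text"` field per line; the class name and block-segmentation details are illustrative, not a copy of the repository's `MyDataset`:

```py
import json
import tiktoken
import torch
from torch.utils.data import Dataset

class JsonlTextDataset(Dataset):
    def __init__(self, path, block_size=512):
        enc = tiktoken.get_encoding("gpt2")
        eot = enc.eot_token                          # GPT-2 end-of-text token id
        ids = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                ids.extend(enc.encode(json.loads(line)["text"]) + [eot])
        # Segment the token stream into fixed-length (input, target) blocks.
        self.blocks = [ids[i:i + block_size + 1]
                       for i in range(0, len(ids) - block_size, block_size)]

    def __len__(self):
        return len(self.blocks)

    def __getitem__(self, i):
        chunk = torch.tensor(self.blocks[i], dtype=torch.long)
        return chunk[:-1], chunk[1:]                 # next-token prediction pair
```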

---

## Training & Data Pipeline

1. **Data Preparation**:
   - The dataset is expected to be in a JSONL format where each line is a JSON object containing a `"text"` field.
   - The text is tokenized using GPT-2’s encoding (via `tiktoken`) and segmented into fixed-length sequences (blocks).

2. **Training Loop**:
   - The training function processes batches of data, computes the cross-entropy loss between predictions and targets, and updates model weights.
   - A cosine annealing learning rate scheduler adjusts the learning rate smoothly over epochs.
   - Periodic logging helps track training progress, and checkpoints are saved after each epoch for future restoration.

3. **Text Generation**:
   - The model’s `generate` function produces new text tokens sequentially by autoregressively sampling from the probability distribution of the next token.
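
A condensed sketch of such a training loop with cosine annealing and per-epoch checkpointing; the hyperparameters, the model's `(logits, loss)` return signature, and the checkpoint file names are placeholders:

```py
import torch
import torch.nn.functional as F

def train(model, train_loader, epochs=3, lr=3e-4, device="cuda"):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    for epoch in range(epochs):
        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)                     # assumed (logits, loss) signature
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 100 == 0:
                print(f"epoch {epoch} step {step} loss {loss.item():.4f}")
        scheduler.step()                             # smooth cosine decay per epoch
        torch.save({"model": model.state_dict(), "epoch": epoch},
                   f"checkpoint_epoch_{epoch}.pt")
```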

---

## MiniGPT Summary

This miniGPT implementation is an excellent starting point for understanding the inner workings of GPT and Transformer-based language models. While it lacks the scale and complexity of the latest GPT models, it offers valuable insights into:

- The architecture and core components of Transformer models.
- How multi-head attention and feed-forward networks work together.
- The trade-offs between model simplicity and performance.

By exploring and experimenting with this code, developers can gain foundational knowledge that paves the way for more advanced LLM engineering projects.

---

*References for further reading on Transformer and GPT architectures can be found in the seminal paper "[Attention Is All You Need](https://arxiv.org/abs/1706.03762)" and subsequent literature on GPT models.*