https://github.com/naidezhujimo/yingret
- Host: GitHub
- URL: https://github.com/naidezhujimo/yingret
- Owner: naidezhujimo
- Created: 2025-04-05T06:53:28.000Z (about 2 months ago)
- Default Branch: main
- Last Pushed: 2025-04-05T06:54:31.000Z (about 2 months ago)
- Last Synced: 2025-04-09T19:07:55.924Z (about 1 month ago)
- Language: Python
- Size: 5.86 KB
- Stars: 1
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# RetNet: Efficient Sequence Modeling with Dual-Paradigm Attention
RetNet is a novel neural architecture that combines **parallel computation for training** and **recurrent computation for inference**, featuring a retention mechanism with exponential decay. This implementation demonstrates efficient sequence modeling through hybrid computation modes and hardware-aware optimizations.
## Key Features
- 🌀 **Dual Computation Modes**:
  - **Parallel Mode**: Full-sequence attention for training
  - **Recurrent Mode**: O(1) memory inference
- 📉 **Exponential Decay Matrix**: Position-aware attention decay
- 🧠 **Enhanced Value Projection**: Optional double-dimensional V vectors
- ⚖️ **Group Normalization**: Head-wise normalization instead of LayerNorm
- 🧩 **Modular Design**: Plug-and-play RetNet blocks

## Installation
```bash
git clone https://github.com/yourusername/retnet.git
cd retnet
pip install torch
```

## Usage
### Basic Model Initialization
```python
from model import RetNet

model = RetNet(
    n_layers=6,
    d_model=512,
    n_heads=8,
    vocab_size=32000
).cuda()

# x: token ids of shape [B, L]
# Training mode (parallel computation)
output = model(x, mode='parallel')

# Inference mode (recurrent computation)
output = model(x, mode='recurrent')
```

### Retention Mechanism Configuration
```python
retention = Retention(
d_model=512,
n_heads=8,
double_v_dim=True # Enable enhanced value projection
)

# Generate decay matrix for sequence length 1024
D = retention.get_decay_matrix(1024, device='cuda')  # [8, 1024, 1024]
```
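For intuition about what `get_decay_matrix` returns, here is a minimal, self-contained sketch of a causal per-head exponential decay matrix. It assumes the form D[h, n, m] = γ_h^(n−m) for n ≥ m and 0 otherwise, with fixed γ_h values taken from the original RetNet paper; since this repo documents learned decay parameters, treat the exact values as illustrative only.

```python
import torch

def decay_matrix_sketch(n_heads: int, seq_len: int, device: str = 'cpu') -> torch.Tensor:
    """Sketch of a causal exponential-decay matrix D with shape [H, L, L]."""
    # Fixed per-head decay rates from the RetNet paper: gamma_h = 1 - 2^(-5 - h)
    gammas = 1 - 2.0 ** (-5 - torch.arange(n_heads, device=device, dtype=torch.float32))
    n = torch.arange(seq_len, device=device)
    diff = n[:, None] - n[None, :]                      # (n - m) for every query/key pair
    D = gammas[:, None, None] ** diff.clamp(min=0)      # gamma_h^(n - m), shape [H, L, L]
    return D * (diff >= 0)                              # zero out future positions (m > n)

D = decay_matrix_sketch(n_heads=8, seq_len=1024)        # [8, 1024, 1024], cf. get_decay_matrix
```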
## Model Architecture

| Component            | Specification                    |
|----------------------|----------------------------------|
| Hidden Dimension     | 512                              |
| Attention Heads      | 8                                |
| FFN Intermediate Dim | 2048                             |
| Default Layers       | 6                                |
| Value Dimension      | 1024 (when `double_v_dim=True`)  |

## Core Implementations
### Retention Mechanism
```python
def forward_parallel(self, Q, K, V):
    B, L, _ = Q.shape
    H, head_dim = self.n_heads, self.head_dim           # attribute names assumed
    # Multi-head splitting: [B, L, H*head_dim] -> [B, H, L, head_dim]
    Q = Q.view(B, L, H, head_dim).transpose(1, 2)
    K = K.view(B, L, H, head_dim).transpose(1, 2)
    # Compute decayed attention: scaled scores masked by the decay matrix D [H, L, L]
    D = self.get_decay_matrix(L, device=Q.device)
    attn = (Q @ K.transpose(-2, -1)) * self.scale       # [B, H, L, L]
    attn = attn * D.unsqueeze(0)                        # apply head-specific decay
    # ... followed by attn @ V (reshaped the same way) and head-wise GroupNorm
```

### Recurrent Mode
```python
def forward_recurrent(self, Q, K, V):
    # Qt, Kt, Vt: projections for the current step [B, H, D]; state carries [B, H, D, D_v]
    # State maintenance: decay the previous state and add the K-V outer product
    state = gamma * state + torch.einsum('bhd,bhe->bhde', Kt, Vt)
    # Output computation: read the state with the current query (O(1) memory per step)
    output = torch.einsum('bhd,bhde->bhe', Qt, state)
```
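To see why the two modes agree, here is a self-contained single-head sketch, independent of this repo's classes (the function names are illustrative and the 1/√d scale is omitted), showing that the recurrent state update reproduces the decayed parallel computation:

```python
import torch

def parallel_retention(Q, K, V, gamma):
    """Decayed causal scores computed all at once: O(L^2) memory."""
    L = Q.shape[0]
    idx = torch.arange(L)
    diff = idx[:, None] - idx[None, :]
    D = (gamma ** diff.clamp(min=0)) * (diff >= 0)              # [L, L] decay mask
    return ((Q @ K.T) * D) @ V                                  # [L, d_v]

def recurrent_retention(Q, K, V, gamma):
    """Same result computed token by token with a fixed-size [d, d_v] state: O(1) memory."""
    state = torch.zeros(Q.shape[1], V.shape[1])
    outputs = []
    for t in range(Q.shape[0]):
        state = gamma * state + K[t][:, None] @ V[t][None, :]   # decay old state, add K-V outer product
        outputs.append(Q[t] @ state)                            # read state with the current query
    return torch.stack(outputs)

Q, K, V = (torch.randn(16, 8) for _ in range(3))
out_par = parallel_retention(Q, K, V, gamma=0.9)
out_rec = recurrent_retention(Q, K, V, gamma=0.9)
print(torch.allclose(out_par, out_rec, atol=1e-4))              # True
```

The recurrent state never grows with sequence length, which is where the O(1) memory figure in the performance table below comes from.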
## Training Configuration

- **Normalization**: GroupNorm instead of LayerNorm
- **Initialization** (see the sketch after this list):
  - Xavier for linear layers
  - Learned decay parameters (γ)
- **Value Projection**:
  - Default 2x dimension expansion
  - Disable with `double_v_dim=False`
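Put together, the configuration above could look roughly like the following. `RetentionBlock`, the attribute names, and the sigmoid-parameterized learned γ are assumptions for illustration, not this repo's exact code:

```python
import torch
import torch.nn as nn

class RetentionBlock(nn.Module):
    """Hypothetical block illustrating the configuration choices above."""
    def __init__(self, d_model=512, n_heads=8, double_v_dim=True):
        super().__init__()
        v_dim = 2 * d_model if double_v_dim else d_model    # optional 2x value expansion
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, v_dim)
        # Head-wise normalization instead of LayerNorm
        # (nn.GroupNorm expects [N, C, ...], so inputs are transposed before normalization)
        self.norm = nn.GroupNorm(num_groups=n_heads, num_channels=v_dim)
        # Learned per-head decay parameters; sigmoid keeps gamma in (0, 1)
        self.decay_logits = nn.Parameter(torch.zeros(n_heads))
        # Xavier initialization for linear layers
        for lin in (self.q_proj, self.k_proj, self.v_proj):
            nn.init.xavier_uniform_(lin.weight)
            nn.init.zeros_(lin.bias)

    @property
    def gamma(self):
        return torch.sigmoid(self.decay_logits)
```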
## Performance

| Mode      | Memory Complexity | Typical Use Case     |
|-----------|-------------------|----------------------|
| Parallel  | O(L²)             | Training             |
| Recurrent | O(1)              | Inference/Deployment |

## License

[MIT License](LICENSE) - Open for academic and commercial use.

---
**Note**: For production deployment:
1. Add mixed-precision training support (see the sketch below)
2. Implement gradient checkpointing
3. Add cross-sequence batch support
4. Consider CUDA kernel optimization for recurrent mode
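For the first item, a minimal sketch of how mixed-precision training could be wired in with `torch.cuda.amp`; the optimizer, loss, and data loop are placeholders, not part of this repo:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for x, targets in dataloader:                     # placeholder data loop
    optimizer.zero_grad(set_to_none=True)
    with autocast():                              # run the forward pass in reduced precision
        logits = model(x.cuda(), mode='parallel')
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.cuda().view(-1))
    scaler.scale(loss).backward()                 # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()
```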