An open API service indexing awesome lists of open source software.

https://github.com/matt-k-wong/mlx-flash

Flash weight streaming for MLX: run massive models larger than your RAM on Apple Silicon.
https://github.com/matt-k-wong/mlx-flash

apple-silicon large-language-models llm llm-inference lm-studio machine-learning macos memory-optimization metal mlx optimization weight-streaming

Last synced: 9 days ago
JSON representation

Flash weight streaming for MLX: run massive models larger than your RAM on Apple Silicon.

Awesome Lists containing this project

README

          

# mlx-flash ⚡

> **Flash Weight Streaming for MLX** — run models larger than your RAM on Apple Silicon.
> 30B on 16 GB, 70B+ on 32 GB+. **No additional quantisation — uses the model's native precision.**

> **Project Lineage:** This implementation is inspired by Apple Research's paper [*LLM in a Flash* (arXiv 2312.11514)](https://arxiv.org/abs/2312.11514). `mlx-flash` provides a high-quality, production-grade integration layer for the MLX ecosystem, featuring bit-perfect parity and predictive bandwidth scheduling.

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
[![Python 3.11+](https://img.shields.io/badge/Python-3.11%2B-blue.svg)](https://python.org)
[![MLX](https://img.shields.io/badge/MLX-latest-green.svg)](https://github.com/ml-explore/mlx)
[![macOS 13+](https://img.shields.io/badge/macOS-13%2B-lightgrey.svg)](https://apple.com)
[![Tests](https://github.com/matt-k-wong/mlx-flash/actions/workflows/tests.yml/badge.svg)](https://github.com/matt-k-wong/mlx-flash/actions/workflows/tests.yml)

---

## Why Flash Mode?

| Model | Hardware | Mode | Load Time | Result |
|-------|----------|------|-----------|--------|
| **Nemotron-30B (17.8 GB)** | 16GB MacBook Air | Normal | 4.1s | ❌ OOM / Laggy |
| **Nemotron-30B (17.8 GB)** | 16GB MacBook Air | **Flash** | **0.8s** | ✅ Bit-Perfect & Smooth |

`mlx-flash` allows you to run models of **any size** (30B, 70B, even 400B+) on base-spec Macs by streaming weights directly from your SSD.

---

## 🏗️ Architecture: The Holistic Patch

Unlike previous iterations that attempted to re-implement the transformer loop manually, `mlx-flash` now uses a **Holistic Model Patching** architecture.

1. **Deep Tissue Patching**: We wrap the original model's layers in a `StreamingProxy`.
2. **Native Logic Retention**: Because we use the model's own `__call__` method, every nuance of the architecture (RoPE scaling, residual streams, causal masks) is handled natively by the model code.
3. **Execution Interception**: Our proxies intercept the layer execution to force synchronous `mx.eval()` and trigger the **Predictive I/O Scheduler**.

### The Control Loop (MPC-Lite)
We use a **Model Predictive Controller** to maximize tokens/second:
- **Baseline Estabishment**: On the first token ("Cold Start"), we establish a pristine compute baseline.
- **Predictive Prefetch**: We predict the bandwidth demand of Layer N+1 while the GPU is still busy with Layer N.
- **Token Bucket Actuator**: A continuous token bucket smoothly paces SSD reads using micro-sleeps, keeping GPU degradation below 5%.

---

## 🏆 Quality & Bit-Parity

`mlx-flash` is a **zero-compromise** engine. We have proven quality through:

1. **Bit-Perfect Operators**: `TiledLinear` executes identically to `nn.Linear` (fused `mx.addmm`), so the loss delta vs. standard MLX is **exactly 0**. Note: on MLX ≥ 0.31, Metal kernel selection makes block-wise tiled accumulation diverge from native fp16 matmul, so bit-exact mode executes layers whole; sub-layer tiling will return as an opt-in memory mode.
2. **Hybrid KV Cache**: Keeps the most recent **128 tokens in full FP16 precision**, while offloading older context to properly scaled 8-bit quantized disk storage.
3. **Passkey Retrieval**: Verified 100% accuracy on context retrieval tests hidden 1,000+ tokens deep in quantized disk storage.

See [QUALITY.md](docs/QUALITY.md) for the full proof suite.

---

## 🚀 Quick Start

### 1. Install
```bash
pip install git+https://github.com/matt-k-wong/mlx-flash.git
```

> ⚠️ **Do not `pip install mlx-flash`** — the PyPI package by that name is an **unrelated project**. This project is installed from GitHub. Tested against `mlx>=0.31` / `mlx-lm>=0.31`.

### 2. Unified CLI
```bash
# Run any model with 2GB weight residence budget
mlx-flash --model mlx-community/Llama-3.2-1B-Instruct-4bit --ram 2.0 --kv-quant 8
```

### 3. Python Usage
```python
from mlx_flash import FlashConfig, FlashManager

# 1. Load and Patch
manager = FlashManager(FlashConfig(ram_budget_gb=2.0))
model, tokenizer = manager.load("mlx-community/Meta-Llama-3-70B-Instruct-4bit")

# 2. Generate
for segment in model.stream_generate("Tell me a story", max_tokens=100):
print(segment, end="", flush=True)
```

---

## How It Works

```mermaid
graph TD
A[SSD: .safetensors] --"mmap(lazy=True)"--> B[MLX Lazy Arrays]
A --"Predictive Worker"--> P[Token Bucket]
P --"os.pread"--> B

subgraph Model["Native Model Logic"]
Embed --> Proxy1
subgraph Proxy1["StreamingProxy (Layer 1)"]
StartHook --> Dispatch[strategy.execute]
Dispatch --> Eval[mx.eval]
Eval --> EndHook
end
Proxy1 --> Proxy2[...]
Proxy2 --> Norm
Norm --> Head
end
```

---

## Roadmap
- [x] **v0.4.0**: Holistic Model Patching (Bit-Perfect Parity), MPC-Lite Bandwidth Controller, Unified `mlx-flash` CLI, `mlx`/`mlx-lm` 0.31+ compatibility.
- [ ] **v0.5.0**: Asynchronous DAG Scheduler (Zero-latency Python glue).
- [ ] **v0.6.0**: MoE Lookahead Routing for Mixtral/DeepSeek.

---

*Brought to you by ⚡ Flash-Mode Contributors. MIT licensed.*