https://github.com/matt-k-wong/mlx-flash
Flash weight streaming for MLX: run massive models larger than your RAM on Apple Silicon.
apple-silicon large-language-models llm llm-inference lm-studio machine-learning macos memory-optimization metal mlx optimization weight-streaming
Last synced: 9 days ago
- Host: GitHub
- URL: https://github.com/matt-k-wong/mlx-flash
- Owner: matt-k-wong
- License: mit
- Created: 2026-03-20T17:04:42.000Z (3 months ago)
- Default Branch: main
- Last Pushed: 2026-04-01T08:14:49.000Z (3 months ago)
- Last Synced: 2026-04-02T04:39:55.860Z (3 months ago)
- Topics: apple-silicon, large-language-models, llm, llm-inference, lm-studio, machine-learning, macos, memory-optimization, metal, mlx, optimization, weight-streaming
- Language: Python
- Homepage: https://github.com/matt-k-wong/mlx-flash
- Size: 509 KB
- Stars: 73
- Watchers: 0
- Forks: 6
- Open Issues: 0
Metadata Files:
- Readme: README.md
- Changelog: CHANGELOG.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
- Citation: CITATION.cff
- Security: SECURITY.md
- Roadmap: ROADMAP.md
Awesome Lists containing this project
- awesome-mlx - mlx-flash
README
# mlx-flash ⚡
> **Flash Weight Streaming for MLX** — run models larger than your RAM on Apple Silicon.
> 30B on 16 GB, 70B+ on 32 GB+. **No additional quantisation — uses the model's native precision.**
> **Project Lineage:** This implementation is inspired by Apple Research's paper [*LLM in a Flash* (arXiv 2312.11514)](https://arxiv.org/abs/2312.11514). `mlx-flash` provides a high-quality, production-grade integration layer for the MLX ecosystem, featuring bit-perfect parity and predictive bandwidth scheduling.
---
## Why Flash Mode?
| Model | Hardware | Mode | Load Time | Result |
|-------|----------|------|-----------|--------|
| **Nemotron-30B (17.8 GB)** | 16GB MacBook Air | Normal | 4.1s | ❌ OOM / Laggy |
| **Nemotron-30B (17.8 GB)** | 16GB MacBook Air | **Flash** | **0.8s** | ✅ Bit-Perfect & Smooth |
`mlx-flash` allows you to run models of **any size** (30B, 70B, even 400B+) on base-spec Macs by streaming weights directly from your SSD.
---
## 🏗️ Architecture: The Holistic Patch
Unlike previous iterations that attempted to re-implement the transformer loop manually, `mlx-flash` now uses a **Holistic Model Patching** architecture.
1. **Deep Tissue Patching**: We wrap the original model's layers in a `StreamingProxy`.
2. **Native Logic Retention**: Because we use the model's own `__call__` method, every nuance of the architecture (RoPE scaling, residual streams, causal masks) is handled natively by the model code.
3. **Execution Interception**: Our proxies intercept the layer execution to force synchronous `mx.eval()` and trigger the **Predictive I/O Scheduler**.
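The three steps above can be sketched in plain Python. This is a minimal illustration, not the project's actual `StreamingProxy`: the constructor arguments, the `patch_layers` helper, and the `on_start` / `on_end` / `force_eval` hooks are all hypothetical names. `force_eval` stands in for a call such as `mx.eval(out)` so the sketch runs without MLX installed.

```python
from typing import Any, Callable, List, Sequence


class StreamingProxy:
    """Hypothetical wrapper: delegates to the wrapped layer's own __call__."""

    def __init__(
        self,
        index: int,
        layer: Callable[..., Any],
        on_start: Callable[[int], None],
        on_end: Callable[[int], None],
        force_eval: Callable[[Any], None],
    ) -> None:
        self.index = index
        self.layer = layer
        self.on_start = on_start      # e.g. notify an I/O scheduler (assumed)
        self.on_end = on_end          # e.g. release this layer's weights (assumed)
        self.force_eval = force_eval  # stand-in for mx.eval(out)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        self.on_start(self.index)
        out = self.layer(*args, **kwargs)  # the layer's native logic is untouched
        self.force_eval(out)               # materialize the result before moving on
        self.on_end(self.index)
        return out


def patch_layers(
    layers: Sequence[Callable[..., Any]],
    on_start: Callable[[int], None],
    on_end: Callable[[int], None],
    force_eval: Callable[[Any], None],
) -> List[StreamingProxy]:
    """Wrap every layer in a proxy, preserving order."""
    return [
        StreamingProxy(i, layer, on_start, on_end, force_eval)
        for i, layer in enumerate(layers)
    ]
```

Because the proxy only calls the wrapped layer, anything the layer computes internally stays the layer's responsibility; the proxy contributes the hooks and the forced evaluation around it.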
### The Control Loop (MPC-Lite)
We use a **Model Predictive Controller** to maximize tokens/second:
- **Baseline Establishment**: On the first token ("Cold Start"), we establish a pristine compute baseline.
- **Predictive Prefetch**: We predict the bandwidth demand of Layer N+1 while the GPU is still busy with Layer N.
- **Token Bucket Actuator**: A continuous token bucket smoothly paces SSD reads using micro-sleeps, keeping GPU degradation below 5%.
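To make the token bucket idea concrete, here is a minimal sketch of a continuously refilled bucket that paces reads with short sleeps. It is an assumption-laden illustration, not the project's controller: the class name, the parameters (`rate_bytes_per_s`, `capacity_bytes`, `max_sleep`), and the injectable `clock` / `sleep` callables are all hypothetical.

```python
import time
from typing import Callable


class TokenBucket:
    """Tokens are bytes, refilled continuously at `rate_bytes_per_s`."""

    def __init__(
        self,
        rate_bytes_per_s: float,
        capacity_bytes: float,
        clock: Callable[[], float] = time.monotonic,
        sleep: Callable[[float], None] = time.sleep,
    ) -> None:
        self.rate = rate_bytes_per_s
        self.capacity = capacity_bytes
        self.tokens = capacity_bytes  # start with a full bucket
        self.clock = clock
        self.sleep = sleep
        self.last = clock()

    def _refill(self) -> None:
        now = self.clock()
        earned = (now - self.last) * self.rate
        self.tokens = min(self.capacity, self.tokens + earned)
        self.last = now

    def acquire(self, nbytes: int, max_sleep: float = 0.001) -> None:
        """Wait in short sleeps until `nbytes` tokens exist, then spend them."""
        if nbytes > self.capacity:
            raise ValueError("request exceeds bucket capacity")
        self._refill()
        # Small tolerance so float rounding cannot stall the loop.
        while self.tokens + 1e-9 < nbytes:
            deficit = nbytes - self.tokens
            self.sleep(min(max_sleep, deficit / self.rate))
            self._refill()
        self.tokens -= nbytes
```

A reader thread would call `bucket.acquire(len_of_next_read)` before each read. Capping each sleep at `max_sleep` keeps the pacing smooth rather than bursty, and a controller could adjust `rate` between tokens.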
---
## 🏆 Quality & Bit-Parity
`mlx-flash` is a **zero-compromise** engine. We have proven quality through:
1. **Bit-Perfect Operators**: `TiledLinear` executes identically to `nn.Linear` (fused `mx.addmm`), so the loss delta vs. standard MLX is **exactly 0**. Note: on MLX ≥ 0.31, Metal kernel selection makes block-wise tiled accumulation diverge from native fp16 matmul, so bit-exact mode executes layers whole; sub-layer tiling will return as an opt-in memory mode.
2. **Hybrid KV Cache**: Keeps the most recent **128 tokens in full FP16 precision**, while offloading older context to properly scaled 8-bit quantized disk storage.
3. **Passkey Retrieval**: Verified 100% accuracy on retrieval tests with the passkey hidden 1,000+ tokens deep in quantized disk storage.
See [QUALITY.md](docs/QUALITY.md) for the full proof suite.
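As a rough illustration of the hybrid cache split in item 2, the sketch below keeps the newest 128 entries untouched and quantizes older ones to 8-bit codes with a per-entry scale. The scaling scheme shown (symmetric absmax) is an assumption, since the exact scheme is not described here, and the function names are hypothetical. It uses plain Python lists and does not write anything to disk.

```python
from typing import List, Tuple

RECENT_WINDOW = 128  # tokens kept at full precision, per the list above


def quantize_int8(values: List[float]) -> Tuple[List[int], float]:
    """Symmetric absmax quantization into [-127, 127] with a single scale."""
    peak = max((abs(v) for v in values), default=0.0)
    scale = peak / 127.0 if peak > 0 else 1.0
    return [round(v / scale) for v in values], scale


def dequantize_int8(codes: List[int], scale: float) -> List[float]:
    """Recover approximate values; the error is at most about scale / 2 each."""
    return [c * scale for c in codes]


def split_cache(
    entries: List[List[float]], window: int = RECENT_WINDOW
) -> Tuple[List[Tuple[List[int], float]], List[List[float]]]:
    """Return (quantized older entries, full-precision recent entries)."""
    cut = max(0, len(entries) - window)
    older = [quantize_int8(e) for e in entries[:cut]]
    recent = entries[cut:]
    return older, recent
```

Storing one scale per entry keeps the rounding error proportional to that entry's largest value, which is one plausible reading of "properly scaled".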
---
## 🚀 Quick Start
### 1. Install
```bash
pip install git+https://github.com/matt-k-wong/mlx-flash.git
```
> ⚠️ **Do not `pip install mlx-flash`** — the PyPI package by that name is an **unrelated project**. This project is installed from GitHub. Tested against `mlx>=0.31` / `mlx-lm>=0.31`.
### 2. Unified CLI
```bash
# Run any model with a 2 GB weight-residence budget
mlx-flash --model mlx-community/Llama-3.2-1B-Instruct-4bit --ram 2.0 --kv-quant 8
```
### 3. Python Usage
```python
from mlx_flash import FlashConfig, FlashManager
# 1. Load and Patch
manager = FlashManager(FlashConfig(ram_budget_gb=2.0))
model, tokenizer = manager.load("mlx-community/Meta-Llama-3-70B-Instruct-4bit")
# 2. Generate
for segment in model.stream_generate("Tell me a story", max_tokens=100):
    print(segment, end="", flush=True)
```
---
## How It Works
```mermaid
graph TD
    A[SSD: .safetensors] --"mmap(lazy=True)"--> B[MLX Lazy Arrays]
    A --"Predictive Worker"--> P[Token Bucket]
    P --"os.pread"--> B
    subgraph Model["Native Model Logic"]
        Embed --> Proxy1
        subgraph Proxy1["StreamingProxy (Layer 1)"]
            StartHook --> Dispatch[strategy.execute]
            Dispatch --> Eval[mx.eval]
            Eval --> EndHook
        end
        Proxy1 --> Proxy2[...]
        Proxy2 --> Norm
        Norm --> Head
    end
```
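The `os.pread` edge in the diagram refers to positional reads, which fetch a byte range without moving a shared file offset. Below is a minimal sketch of such a read using only the standard library. The `read_range` helper is hypothetical, and how `mlx-flash` consumes the bytes is not shown here. `os.pread` is available on Unix-like systems, including macOS.

```python
import os


def read_range(path: str, offset: int, length: int) -> bytes:
    """Read up to `length` bytes starting at `offset` using positional reads."""
    fd = os.open(path, os.O_RDONLY)
    try:
        chunks = []
        remaining = length
        while remaining > 0:
            # pread may return fewer bytes than requested, so loop until done.
            chunk = os.pread(fd, remaining, offset)
            if not chunk:  # reached end of file early
                break
            chunks.append(chunk)
            offset += len(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)
    finally:
        os.close(fd)
```

Because each call carries its own offset, several worker threads can read different tensors from the same file descriptor without coordinating a seek position.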
---
## Roadmap
- [x] **v0.4.0**: Holistic Model Patching (Bit-Perfect Parity), MPC-Lite Bandwidth Controller, Unified `mlx-flash` CLI, `mlx`/`mlx-lm` 0.31+ compatibility.
- [ ] **v0.5.0**: Asynchronous DAG Scheduler (Zero-latency Python glue).
- [ ] **v0.6.0**: MoE Lookahead Routing for Mixtral/DeepSeek.
---
*Brought to you by ⚡ Flash-Mode Contributors. MIT licensed.*