https://github.com/ginozza/shrew
A polyglot neural engine for modular, language-agnostic deep learning with decoupled execution.
deep-learning dls gpu machine-learning python rust tensor transformer
Last synced: 4 months ago
- Host: GitHub
- URL: https://github.com/ginozza/shrew
- Owner: ginozza
- License: other
- Created: 2026-02-14T08:17:51.000Z (4 months ago)
- Default Branch: master
- Last Pushed: 2026-02-18T03:44:22.000Z (4 months ago)
- Last Synced: 2026-02-18T07:14:46.381Z (4 months ago)
- Topics: deep-learning, dls, gpu, machine-learning, python, rust, tensor, transformer
- Language: Rust
- Homepage: https://shrew.ink
- Size: 1.05 MB
- Stars: 1
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- Changelog: CHANGELOG.md
- Contributing: CONTRIBUTING.md
- Funding: .github/funding.yml
- License: LICENSE
- Code of conduct: CODE_OF_CONDUCT.md
- Security: SECURITY.md
- Notice: NOTICE
README
---
Shrew is a modular deep learning framework written from scratch in Rust. It provides a tensor system with automatic differentiation, neural network layers, optimizers, a CUDA GPU backend, and a declarative intermediate representation (`.sw`) for language-agnostic model specification.
Models defined in `.sw` files can be trained from Python (via PyO3 bindings) or Rust, and the same file is deployed across platforms without transpiling it to another format.
## Components
| Crate | Description |
|-------|-------------|
| **shrew-core** | `Tensor`, `Shape`, `DType`, `Layout`, `Backend` trait, reverse-mode autograd, dynamic symbolic shapes |
| **shrew-cpu** | CPU backend: SIMD matmul via `gemm` (AVX2/AVX-512/FMA), parallel ops via `rayon`, broadcasting |
| **shrew-cuda** | NVIDIA GPU backend: cuBLAS matmul, custom PTX kernels, memory pool, mixed-precision F16/BF16 |
| **shrew-nn** | Neural network layers: Linear, Conv1d/2d, RNN/LSTM/GRU, MultiHeadAttention, Transformer, BatchNorm, LayerNorm, losses |
| **shrew-optim** | Optimizers (SGD, Adam, AdamW, RAdam, RMSProp), LR schedulers, gradient clipping, EMA |
| **shrew-ir** | `.sw` format: lexer, parser, AST, Graph IR, lowering, validation, shape inference, optimization passes |
| **shrew-data** | `Dataset` trait, `DataLoader`, MNIST, image transforms, async prefetch loader |
| **shrew** | Facade crate: executor, JIT compiler, trainer, distributed training, quantization, ONNX, profiling, checkpoints |
| **shrew-python** | Python bindings via PyO3 with NumPy interop |
| **shrew-cli** | CLI tools: `shrew dump`, `validate`, `bench`, `info` |
## Technical Features
### Backend-Agnostic Tensor System
`Tensor` is generic over `Backend`. The same tensor code runs on CPU and GPU without changes. Supported dtypes: `F16`, `BF16`, `F32`, `F64`, `U8`, `U32`, `I64`.
```rust
use shrew::prelude::*;

fn main() -> shrew::Result<()> {
    let dev = CpuDevice;
    let a = CpuTensor::randn((3, 4), DType::F32, &dev)?;
    let b = CpuTensor::randn((4, 5), DType::F32, &dev)?;
    let c = a.matmul(&b)?; // [3,4] × [4,5] → [3,5]
    assert_eq!(c.dims(), &[3, 5]);
    Ok(())
}
```
Operations: `add`, `sub`, `mul`, `div` (with NumPy-style broadcasting), `neg`, `abs`, `exp`, `log`, `sqrt`, `square`, `sin`, `cos`, `relu`, `sigmoid`, `tanh`, `gelu`, `silu`, `softmax`, `log_softmax`, `matmul`, `reshape`, `transpose`, `narrow`, `unsqueeze`, `expand`, `cat`, `chunk`, `index_select`, `sum`, `mean`, `max`, `min`, `argmax`, `argmin`, `var`, comparisons (`eq`, `ne`, `gt`, `ge`, `lt`, `le`), `to_dtype`.
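NumPy-style broadcasting aligns two shapes from the trailing dimension: each pair of sizes must be equal or one of them must be 1, and missing leading dimensions count as 1. A standalone sketch of that shape rule (the `broadcast_shape` helper is hypothetical, not part of shrew's API):

```rust
/// Returns the NumPy-style broadcast shape of `a` and `b`, or `None` if the
/// shapes are incompatible.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = vec![0; rank];
    for i in 0..rank {
        // Left-pad the shorter shape with 1s so both have `rank` dims.
        let da = if i < rank - a.len() { 1 } else { a[i - (rank - a.len())] };
        let db = if i < rank - b.len() { 1 } else { b[i - (rank - b.len())] };
        out[i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None, // e.g. 3 vs 4: not broadcastable
        };
    }
    Some(out)
}

fn main() {
    // The [3,1] + [1,2] case used in Getting Started.
    println!("{:?}", broadcast_shape(&[3, 1], &[1, 2]));
}
```

For `[3,1] + [1,2]`, the 3 pairs with a 1 and the 2 pairs with a 1, giving `[3,2]`.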
### Reverse-Mode Automatic Differentiation
Eager autograd: every operation records itself in the computational graph as it executes. `backward()` topologically sorts that graph and applies the chain rule from the loss back to the inputs. Gradient paths cover all binary/unary ops, reductions, `matmul`, `reshape`, `transpose`, `narrow`, `affine`, `contiguous`, `cat`, and `index_select`.
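```rust
use shrew::prelude::*;

fn main() -> shrew::Result<()> {
    let dev = CpuDevice;
    // Mark `w` as a variable so `backward()` computes a gradient for it.
    let w = CpuTensor::randn((3, 3), DType::F64, &dev)?.set_variable();
    let x = CpuTensor::randn((2, 3), DType::F64, &dev)?;
    let loss = x.matmul(&w)?.sum_all()?;
    let grads = loss.backward()?;
    let dw = grads.get(&w).unwrap(); // ∂loss/∂w, shape [3,3]
    Ok(())
}
```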
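For the example above the gradient has a closed form: `loss` sums every entry of `x·w`, so ∂loss/∂w[i][j] = Σ_r x[r][i], i.e. `dw = xᵀ · 1`. The standalone sketch below (plain Rust arrays, no shrew types; `loss` and `analytic_grad` are illustrative helpers) checks that closed form against central finite differences, which is what `grads.get(&w)` should agree with for this graph:

```rust
// loss = sum of all entries of x · w, for x: [2,3] and w: [3,3].
fn loss(x: &[[f64; 3]; 2], w: &[[f64; 3]; 3]) -> f64 {
    let mut total = 0.0;
    for r in 0..2 {
        for j in 0..3 {
            for k in 0..3 {
                total += x[r][k] * w[k][j];
            }
        }
    }
    total
}

// Closed form: d loss / d w[i][j] = sum_r x[r][i], independent of j and of w.
fn analytic_grad(x: &[[f64; 3]; 2]) -> [[f64; 3]; 3] {
    let mut g = [[0.0; 3]; 3];
    for i in 0..3 {
        let col_sum: f64 = (0..2).map(|r| x[r][i]).sum();
        for j in 0..3 {
            g[i][j] = col_sum;
        }
    }
    g
}

fn main() {
    let x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let w = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]];
    let g = analytic_grad(&x);
    // Central finite differences agree with the closed form.
    let eps = 1e-6;
    for i in 0..3 {
        for j in 0..3 {
            let (mut wp, mut wm) = (w, w);
            wp[i][j] += eps;
            wm[i][j] -= eps;
            let numeric = (loss(&x, &wp) - loss(&x, &wm)) / (2.0 * eps);
            assert!((numeric - g[i][j]).abs() < 1e-6);
        }
    }
    println!("{:?}", g);
}
```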
### Neural Network Layers
All layers implement `Module::forward()` and are generic over `Backend`:
| Category | Layers |
|----------|--------|
| Dense | `Linear` |
| Convolution | `Conv1d`, `Conv2d`, `MaxPool2d`, `AvgPool2d`, `AdaptiveAvgPool2d` |
| Recurrent | `RNNCell`, `RNN`, `LSTMCell`, `LSTM`, `GRUCell`, `GRU` |
| Attention | `MultiHeadAttention`, `TransformerBlock` |
| Normalization | `BatchNorm2d`, `LayerNorm`, `GroupNorm`, `RMSNorm` |
| Embedding | `Embedding` |
| Regularization | `Dropout` |
| Utility | `Flatten`, `Sequential` |
| Activations | `ReLU`, `GeLU`, `SiLU`, `LeakyReLU`, `ELU`, `Mish` |
| Losses | `mse_loss`, `cross_entropy_loss`, `l1_loss`, `smooth_l1_loss`, `bce_loss`, `bce_with_logits_loss`, `nll_loss` |
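The shared `Module::forward()` interface is what lets containers such as `Sequential` compose layers. A toy illustration of that pattern, in which a plain `Vec<f64>` stands in for a tensor: the trait, `Scale`, `ReLU` and this `Sequential` are hypothetical simplifications (shrew's real `Module` is generic over `Backend`, and its `forward` returns a `Result`, as the `?` in the examples shows).

```rust
// Hypothetical, simplified stand-in for a Module trait.
trait Module {
    fn forward(&self, x: &[f64]) -> Vec<f64>;
}

struct Scale(f64);
impl Module for Scale {
    fn forward(&self, x: &[f64]) -> Vec<f64> {
        x.iter().map(|&v| v * self.0).collect()
    }
}

struct ReLU;
impl Module for ReLU {
    fn forward(&self, x: &[f64]) -> Vec<f64> {
        x.iter().map(|&v| v.max(0.0)).collect()
    }
}

// A container that is itself a Module: forward threads the input through each layer in order.
struct Sequential {
    layers: Vec<Box<dyn Module>>,
}
impl Module for Sequential {
    fn forward(&self, x: &[f64]) -> Vec<f64> {
        self.layers.iter().fold(x.to_vec(), |h, layer| layer.forward(&h))
    }
}

fn main() {
    let model = Sequential { layers: vec![Box::new(Scale(2.0)), Box::new(ReLU)] };
    println!("{:?}", model.forward(&[-1.0, 0.5, 3.0]));
}
```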
### Optimizers and Schedulers
| Optimizers | Schedulers |
|------------|-----------|
| `SGD` (momentum, weight decay) | `StepLR`, `ExponentialLR`, `LinearLR` |
| `Adam`, `AdamW`, `RAdam` | `CosineAnnealingLR`, `CosineWarmupLR` |
| `RMSProp` | `ReduceLROnPlateau` |
Utilities: `clip_grad_norm`, `clip_grad_value`, `grad_norm`, `GradAccumulator`, `EMA`.
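The README does not show the optimizer internals. For reference, the standard AdamW update (Adam with weight decay applied directly to the parameters rather than folded into the gradient) can be sketched on a flat `Vec<f64>`. The struct and field names are hypothetical, and β1 = 0.9, β2 = 0.999, ε = 1e-8 are the common defaults, not confirmed shrew defaults.

```rust
/// Minimal AdamW sketch over a flat parameter vector.
struct AdamW {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
    weight_decay: f64,
    m: Vec<f64>, // first-moment estimate
    v: Vec<f64>, // second-moment estimate
    t: i32,      // step counter for bias correction
}

impl AdamW {
    fn new(n: usize, lr: f64, weight_decay: f64) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay,
            m: vec![0.0; n],
            v: vec![0.0; n],
            t: 0,
        }
    }

    fn step(&mut self, params: &mut [f64], grads: &[f64]) {
        self.t += 1;
        let bc1 = 1.0 - self.beta1.powi(self.t);
        let bc2 = 1.0 - self.beta2.powi(self.t);
        for i in 0..params.len() {
            let g = grads[i];
            // Exponential moving averages of the gradient and its square.
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
            let m_hat = self.m[i] / bc1;
            let v_hat = self.v[i] / bc2;
            // Decoupled weight decay: shrink the parameter itself, not the gradient.
            params[i] -= self.lr * (m_hat / (v_hat.sqrt() + self.eps) + self.weight_decay * params[i]);
        }
    }
}

fn main() {
    let mut params = vec![1.0, -2.0];
    let mut opt = AdamW::new(2, 0.1, 0.0);
    opt.step(&mut params, &[0.5, -0.5]);
    println!("{:?}", params);
}
```

On the first step the bias-corrected ratio `m_hat / sqrt(v_hat)` is ±1, so each parameter moves by roughly `lr` against the sign of its gradient.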
### CUDA GPU Backend
Feature-gated backend using `cudarc`. cuBLAS for matrix multiplication, custom PTX kernels for elementwise, reduction, broadcast, and cast operations. Includes a memory pool with allocation reuse.
```bash
cargo build -p shrew --features cuda
```
Mixed-precision training: `MixedPrecisionTrainer` with dynamic loss scaling, automatic F32↔F16/BF16 casting via `to_dtype`.
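Dynamic loss scaling commonly works like this: multiply the loss by a large scale so small FP16 gradients do not underflow, then divide the gradients by the same scale. If any gradient overflows to inf/NaN, skip the step and halve the scale; after a run of clean steps, double it. A sketch of that bookkeeping follows. `LossScaler` is hypothetical, and the initial scale of 65536 and growth interval of 2000 match PyTorch's `GradScaler` defaults, not documented shrew values.

```rust
/// Hypothetical dynamic loss scaler (not shrew's MixedPrecisionTrainer).
struct LossScaler {
    scale: f32,
    growth_interval: u32,
    good_steps: u32,
}

impl LossScaler {
    fn new() -> Self {
        Self { scale: 65536.0, growth_interval: 2000, good_steps: 0 }
    }

    /// Takes gradients of the *scaled* loss. Returns the unscaled gradients,
    /// or `None` if the optimizer step should be skipped due to overflow.
    fn unscale_and_update(&mut self, scaled_grads: &[f32]) -> Option<Vec<f32>> {
        if scaled_grads.iter().any(|g| !g.is_finite()) {
            self.scale /= 2.0; // back off after an overflow
            self.good_steps = 0;
            return None;
        }
        let grads: Vec<f32> = scaled_grads.iter().map(|g| g / self.scale).collect();
        self.good_steps += 1;
        if self.good_steps == self.growth_interval {
            self.scale *= 2.0; // grow again after enough clean steps
            self.good_steps = 0;
        }
        Some(grads)
    }
}

fn main() {
    let mut scaler = LossScaler::new();
    // An overflowed step is skipped and the scale halves.
    assert!(scaler.unscale_and_update(&[1.0, f32::INFINITY]).is_none());
    println!("scale after overflow: {}", scaler.scale);
    // A finite step is divided by the current scale.
    println!("{:?}", scaler.unscale_and_update(&[32768.0, -65536.0]));
}
```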
### `.sw` Intermediate Representation
Declarative, text-based model specification that separates model architecture from runtime execution:
```sw
@model { name: "TinyGPT"; }
@config {
d_model: 256;
n_heads: 4;
d_ff: 256 * 4; // constant folding → 1024
}
@graph Forward {
input tokens: Tensor<[Batch, SeqLen], i64>;
param wte: Tensor<[50257, 256], f32> { init: "normal(0, 0.02)"; };
param wpe: Tensor<[512, 256], f32> { init: "normal(0, 0.02)"; };
node tok_emb { op: embedding(wte, tokens); };
node pos_emb { op: embedding(wpe, positions); };
node h { op: tok_emb + pos_emb; };
node tf_out { op: repeat(4) { transformer_block(h, n_heads: 4); }; };
node ln_out { op: layer_norm(tf_out, ln_w, ln_b, eps: 1e-5); };
node logits { op: matmul(ln_out, transpose(wte)); };
output logits;
}
@training {
loss: cross_entropy;
optimizer: { type: "AdamW"; lr: 3e-4; weight_decay: 0.1; }
epochs: 20;
batch_size: 64;
}
```
Pipeline: source → **Lexer** → tokens → **Parser** → AST → **Lowering** → Graph IR → **Validate** → **Shape inference** → **Optimize** (DCE, CSE, constant folding, operator fusion, identity elimination).
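The `@config` block above annotates `d_ff: 256 * 4` with "constant folding → 1024". As a minimal sketch of what such a pass does, here is a toy `Expr` type (hypothetical, not shrew's Graph IR) with a fold that evaluates literal-only subexpressions and keeps anything that depends on a symbol such as `Batch`:

```rust
// Toy expression tree standing in for `.sw` config/shape expressions.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Int(i64),
    Var(String),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

/// Constant folding: fold children first, then evaluate the node if both
/// operands became literals; otherwise rebuild it with the folded children.
fn fold(e: &Expr) -> Expr {
    match e {
        Expr::Add(a, b) => match (fold(a), fold(b)) {
            (Expr::Int(x), Expr::Int(y)) => Expr::Int(x + y),
            (fa, fb) => Expr::Add(Box::new(fa), Box::new(fb)),
        },
        Expr::Mul(a, b) => match (fold(a), fold(b)) {
            (Expr::Int(x), Expr::Int(y)) => Expr::Int(x * y),
            (fa, fb) => Expr::Mul(Box::new(fa), Box::new(fb)),
        },
        other => other.clone(),
    }
}

fn main() {
    // d_ff: 256 * 4
    let d_ff = Expr::Mul(Box::new(Expr::Int(256)), Box::new(Expr::Int(4)));
    println!("{:?}", fold(&d_ff));
    // Batch * (2 + 3): only the literal part folds.
    let e = Expr::Mul(
        Box::new(Expr::Var("Batch".into())),
        Box::new(Expr::Add(Box::new(Expr::Int(2)), Box::new(Expr::Int(3)))),
    );
    println!("{:?}", fold(&e));
}
```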
### JIT Compilation
`JitExecutor` compiles IR graphs into a flat instruction tape with pre-allocated memory slots and value lifetime tracking, so the graph is not re-interpreted at runtime.
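```rust
use shrew::prelude::*;
use shrew::exec::jit::load_jit;
// `sw_source`, `config` and `inputs` are set up as in the `load_program`
// example under Getting Started.
let executor = load_jit::<CpuBackend>(sw_source, CpuDevice, config)?;
let result = executor.run("Forward", &inputs)?;
```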
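The general idea behind pre-allocated slots with lifetime tracking can be sketched in two passes: record the last instruction that reads each value, then walk the tape assigning slots and returning a slot to a free list once its value is dead. The `Instr` type and `assign_slots` function are hypothetical and not taken from shrew's JIT. This version never lets an output share a slot with one of its own inputs, so no in-place aliasing occurs.

```rust
use std::collections::HashMap;

/// One tape instruction: writes value `out`, reads values `inputs`.
struct Instr {
    out: usize,
    inputs: Vec<usize>,
}

/// Returns (value id -> slot, number of slots allocated).
fn assign_slots(tape: &[Instr], outputs: &[usize]) -> (HashMap<usize, usize>, usize) {
    // Lifetime tracking: index of the last instruction reading each value.
    let mut last_use: HashMap<usize, usize> = HashMap::new();
    for (i, ins) in tape.iter().enumerate() {
        for &v in &ins.inputs {
            last_use.insert(v, i);
        }
    }
    let mut slot_of = HashMap::new();
    let mut live: HashMap<usize, usize> = HashMap::new();
    let mut free: Vec<usize> = Vec::new();
    let mut n_slots = 0;
    for (i, ins) in tape.iter().enumerate() {
        // Reuse a dead value's slot if one is free, else allocate a new one.
        let slot = match free.pop() {
            Some(s) => s,
            None => {
                n_slots += 1;
                n_slots - 1
            }
        };
        slot_of.insert(ins.out, slot);
        live.insert(ins.out, slot);
        // Release inputs read for the last time here (graph outputs stay live).
        for &v in &ins.inputs {
            if last_use[&v] == i && !outputs.contains(&v) {
                if let Some(s) = live.remove(&v) {
                    free.push(s);
                }
            }
        }
    }
    (slot_of, n_slots)
}

fn main() {
    // y = relu(x @ w + b): three loads, then matmul, add, relu.
    let tape = vec![
        Instr { out: 0, inputs: vec![] },     // x
        Instr { out: 1, inputs: vec![] },     // w
        Instr { out: 2, inputs: vec![] },     // b
        Instr { out: 3, inputs: vec![0, 1] }, // matmul
        Instr { out: 4, inputs: vec![3, 2] }, // add
        Instr { out: 5, inputs: vec![4] },    // relu
    ];
    let (slots, n) = assign_slots(&tape, &[5]);
    println!("6 values in {} slots: {:?}", n, slots);
}
```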
### Dynamic Symbolic Shapes
`SymDim` (Fixed/Symbolic/Dynamic), `SymbolicShape`, `ShapeEnv`, and `ShapeGuard` bridge symbolic IR shapes with runtime concrete shapes. Supports shape unification, matching, and broadcasting.
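The exact semantics of the three `SymDim` kinds are not spelled out here. One plausible reading, sketched with a hypothetical simplified enum and `unify` function (not shrew's actual types): `Fixed` must match exactly, `Symbolic` binds to the first concrete size it sees and must stay consistent afterwards, and `Dynamic` accepts any size.

```rust
use std::collections::HashMap;

/// Hypothetical mirror of the three dimension kinds named above.
#[derive(Debug, Clone, PartialEq)]
enum SymDim {
    Fixed(usize),     // e.g. 256
    Symbolic(String), // e.g. "Batch"
    Dynamic,          // any size
}

/// Check a runtime size against a declared dimension, binding symbols in `env`.
fn unify(decl: &SymDim, actual: usize, env: &mut HashMap<String, usize>) -> Result<(), String> {
    match decl {
        SymDim::Fixed(n) if *n == actual => Ok(()),
        SymDim::Fixed(n) => Err(format!("expected {n}, got {actual}")),
        SymDim::Symbolic(name) => match env.get(name).copied() {
            Some(bound) if bound != actual => Err(format!("{name} is {bound}, got {actual}")),
            Some(_) => Ok(()),
            None => {
                env.insert(name.clone(), actual); // first sighting binds the symbol
                Ok(())
            }
        },
        SymDim::Dynamic => Ok(()),
    }
}

fn main() {
    let mut env = HashMap::new();
    let tokens = [SymDim::Symbolic("Batch".into()), SymDim::Symbolic("SeqLen".into())];
    // A [8, 128] input binds Batch = 8 and SeqLen = 128.
    for (d, &n) in tokens.iter().zip(&[8usize, 128]) {
        unify(d, n, &mut env).unwrap();
    }
    // A later tensor with Batch = 4 conflicts with that binding.
    println!("{:?}", unify(&tokens[0], 4, &mut env));
}
```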
### Distributed Training
| Component | Description |
|-----------|-------------|
| `DataParallel` | Batch splitting across workers with output concatenation |
| `PipelineParallel` | GPipe-style micro-batch pipelining |
| `MixedPrecisionTrainer` | Dynamic loss scaling for FP16/BF16 |
| `reduce_gradients` | All-reduce gradient synchronization |
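Conceptually, data parallelism splits a batch into shards, computes a gradient per shard, and averages the results; an all-reduce (sum) followed by division by the worker count produces that average on every worker. The single-threaded sketch below uses a one-parameter model `y = w * x` with MSE loss. `split_batch`, `grad` and `reduce_mean` are hypothetical helpers, and no shrew API or communication layer is involved.

```rust
/// Split a batch into at most `workers` contiguous shards.
fn split_batch<T: Clone>(batch: &[T], workers: usize) -> Vec<Vec<T>> {
    let chunk = batch.len().div_ceil(workers).max(1);
    batch.chunks(chunk).map(|c| c.to_vec()).collect()
}

/// Gradient of the mean squared error of y = w * x over one shard of (x, target) pairs.
fn grad(w: f64, shard: &[(f64, f64)]) -> f64 {
    shard.iter().map(|&(x, t)| 2.0 * (w * x - t) * x).sum::<f64>() / shard.len() as f64
}

/// What an all-reduce followed by division by the worker count yields.
fn reduce_mean(grads: &[f64]) -> f64 {
    grads.iter().sum::<f64>() / grads.len() as f64
}

fn main() {
    let batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 5.0), (4.0, 9.0)];
    let w = 1.5;
    let shards = split_batch(&batch, 2);
    let per_worker: Vec<f64> = shards.iter().map(|s| grad(w, s)).collect();
    // With equal-sized shards, the averaged gradient equals the full-batch gradient.
    println!("{} vs {}", reduce_mean(&per_worker), grad(w, &batch));
}
```

If shards have unequal sizes, a plain mean of per-shard means no longer equals the full-batch gradient; weighting by shard size fixes that.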
### Quantization
INT8/INT4 post-training quantization (symmetric/asymmetric, per-tensor/per-channel). `QuantizedLinear` for dequantize-on-the-fly inference.
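For the symmetric per-tensor INT8 case, one scale maps the largest absolute value to 127 and the zero-point is fixed at 0. The standalone sketch below uses hypothetical function names and is not shrew's `QuantConfig` API. Per-channel quantization repeats this with one scale per output row, and asymmetric schemes add a zero-point so the range need not be centred on 0.

```rust
/// Symmetric per-tensor INT8 quantization. Returns (quantized values, scale).
fn quantize_symmetric_i8(x: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    // Map [-max_abs, max_abs] onto [-127, 127]; avoid dividing by zero for an all-zero tensor.
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let q: Vec<i8> = x
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

/// Dequantize-on-the-fly: recover approximate f32 values.
fn dequantize(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}

fn main() {
    let w = [0.5f32, -1.27, 0.01, 1.0];
    let (q, scale) = quantize_symmetric_i8(&w);
    let back = dequantize(&q, scale);
    println!("q = {:?}, scale = {}", q, scale);
    // Round-to-nearest keeps each element within half a quantization step.
    for (a, b) in w.iter().zip(&back) {
        assert!((a - b).abs() <= scale / 2.0 + 1e-6);
    }
}
```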
### ONNX Interop
Export/import ONNX models (opset 17) with a built-in minimal protobuf encoder/decoder (zero external dependencies).
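ONNX models are Protocol Buffers messages, so a minimal encoder mainly needs two wire-format primitives: base-128 varints and field keys (`field_number << 3 | wire_type`). A sketch of both, following the protobuf encoding specification (function names are hypothetical; this is not shrew's encoder):

```rust
/// Base-128 varint: 7 payload bits per byte, least-significant group first,
/// high bit set on every byte except the last.
fn encode_varint(mut v: u64, out: &mut Vec<u8>) {
    while v >= 0x80 {
        out.push((v as u8 & 0x7F) | 0x80);
        v >>= 7;
    }
    out.push(v as u8);
}

/// Decode one varint; returns (value, bytes consumed), or None if truncated.
fn decode_varint(bytes: &[u8]) -> Option<(u64, usize)> {
    let mut v = 0u64;
    for (i, &b) in bytes.iter().enumerate().take(10) {
        v |= u64::from(b & 0x7F) << (7 * i);
        if b & 0x80 == 0 {
            return Some((v, i + 1));
        }
    }
    None
}

/// Encode a varint field: key with wire type 0, then the value.
fn encode_varint_field(field: u32, value: u64, out: &mut Vec<u8>) {
    encode_varint(u64::from(field) << 3, out);
    encode_varint(value, out);
}

fn main() {
    let mut buf = Vec::new();
    encode_varint(300, &mut buf); // the spec's classic example: 300 -> AC 02
    println!("{:02X?}", buf);
    let mut msg = Vec::new();
    encode_varint_field(1, 150, &mut msg); // field 1 = 150 -> 08 96 01
    println!("{:02X?}", msg);
}
```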
### Profiling
`Profiler` with named timing events, `MemoryTracker`, `ModelSummary`, `benchmark_forward`, `benchmark_forward_backward`.
### Serialization
| Format | Description |
|--------|-------------|
| `.shrew` | Native binary checkpoint (`save_tensors` / `load_tensors`) |
| Safetensors | HuggingFace-compatible (`save_safetensors` / `load_safetensors`) |
| ONNX | Open Neural Network Exchange (`export_weights` / `load_onnx_weights`) |
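The safetensors layout is public: an 8-byte little-endian header length, a JSON header mapping each tensor name to its `dtype`, `shape` and `data_offsets` (relative to the start of the data section), then the raw bytes. A std-only sketch that writes a single F32 tensor follows. `to_safetensors` is a hypothetical helper that assumes the name needs no JSON escaping. shrew's `save_safetensors` may differ, for example in metadata or header padding.

```rust
/// Serialize one F32 tensor as [u64 LE header length][JSON header][LE data].
fn to_safetensors(name: &str, shape: &[usize], data: &[f32]) -> Vec<u8> {
    let nbytes = data.len() * 4;
    let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
    // data_offsets are [begin, end) within the data section.
    let header = format!(
        r#"{{"{}":{{"dtype":"F32","shape":[{}],"data_offsets":[0,{}]}}}}"#,
        name,
        dims.join(","),
        nbytes
    );
    let mut out = Vec::with_capacity(8 + header.len() + nbytes);
    out.extend_from_slice(&(header.len() as u64).to_le_bytes());
    out.extend_from_slice(header.as_bytes());
    for v in data {
        out.extend_from_slice(&v.to_le_bytes());
    }
    out
}

fn main() {
    let bytes = to_safetensors("w", &[2, 2], &[1.0, 2.0, 3.0, 4.0]);
    let n = u64::from_le_bytes(bytes[..8].try_into().unwrap()) as usize;
    println!("header: {}", std::str::from_utf8(&bytes[8..8 + n]).unwrap());
    println!("data bytes: {}", bytes.len() - 8 - n);
}
```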
## Installation
### Rust
```toml
# Cargo.toml
[dependencies]
shrew = "0.1"
```
With CUDA support:
```toml
[dependencies]
shrew = { version = "0.1", features = ["cuda"] }
```
Or directly from GitHub:
```toml
[dependencies]
shrew = { git = "https://github.com/ginozza/shrew" }
```
### CLI
```bash
cargo install shrew-cli
```
Or from source:
```bash
cargo install --git https://github.com/ginozza/shrew shrew-cli
```
This installs the `shrew` binary with commands: `dump`, `validate`, `bench`, `info`.
### Python
```bash
pip install shrew-python
```
Or build from source:
```bash
git clone https://github.com/ginozza/shrew
cd shrew
pip install maturin
maturin develop --release
```
```python
import shrew_python as shrew
t = shrew.tensor([1.0, 2.0, 3.0])
print(t)
```
### From Source (full workspace)
```bash
git clone https://github.com/ginozza/shrew
cd shrew
cargo build --workspace
cargo test --workspace
```
Requirements:
- Rust 1.75+ (edition 2021)
- Python 3.9+ (for Python bindings)
- NVIDIA CUDA Toolkit (for GPU backend only)
## Getting Started
### Tensor operations and autograd
```rust
use shrew::prelude::*;

fn main() -> shrew::Result<()> {
    let dev = CpuDevice;

    // Broadcasting: [3,1] + [1,2] → [3,2]
    let a = CpuTensor::from_f64_slice(&[1.0, 2.0, 3.0], (3, 1), DType::F64, &dev)?;
    let b = CpuTensor::from_f64_slice(&[10.0, 20.0], (1, 2), DType::F64, &dev)?;
    let c = a.add(&b)?;
    assert_eq!(c.dims(), &[3, 2]);

    // Transformer forward pass
    let block = TransformerBlock::<CpuBackend>::new(64, 4, 256, true, DType::F64, &dev)?;
    let x = CpuTensor::rand((2, 10, 64), DType::F64, &dev)?;
    let y = block.forward(&x)?; // [2,10,64]
    assert_eq!(y.dims(), &[2, 10, 64]);

    // Autograd
    let w = CpuTensor::rand((3, 3), DType::F64, &dev)?.set_variable();
    let input = CpuTensor::rand((2, 3), DType::F64, &dev)?;
    let loss = input.matmul(&w)?.sum_all()?;
    let grads = loss.backward()?;
    assert!(grads.get(&w).is_some());

    Ok(())
}
```
### Executing a `.sw` model
```rust
use std::collections::HashMap;

use shrew::prelude::*;
use shrew::exec::{load_program, RuntimeConfig};

fn main() -> shrew::Result<()> {
    let src = r#"
@model { name: "MLP"; }
@graph Forward {
    input x: Tensor<[2, 4], f32>;
    param w: Tensor<[4, 3], f32> { init: "normal(0, 0.1)"; };
    node out { op: softmax(matmul(x, w)); };
    output out;
}
"#;
    let config = RuntimeConfig::default().with_dtype(DType::F32);
    let exec = load_program::<CpuBackend>(src, CpuDevice, config)?;

    // Bind the graph input `x` by name.
    let x = CpuTensor::rand((2, 4), DType::F32, &CpuDevice)?;
    let mut inputs = HashMap::new();
    inputs.insert("x".to_string(), x);

    let result = exec.run("Forward", &inputs)?;
    let probs = result.get("out").unwrap();
    assert_eq!(probs.dims(), &[2, 3]);
    Ok(())
}
```
### Quantization
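```rust
use shrew::prelude::*;

fn main() -> shrew::Result<()> {
    let model = Linear::<CpuBackend>::new(256, 128, true, DType::F32, &CpuDevice)?;
    // INT8 weights with one scale per output channel.
    let config = QuantConfig::int8_per_channel();
    let quantized = quantize_named_parameters::<CpuBackend>(&model, &config)?;
    Ok(())
}
```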
### Benchmarking
```rust
use shrew::prelude::*;

fn main() -> shrew::Result<()> {
    let model = Linear::<CpuBackend>::new(512, 256, true, DType::F32, &CpuDevice)?;
    let result = benchmark_forward(
        &model,
        // Closure that builds a [32, 512] input batch
        || Tensor::<CpuBackend>::rand((32, 512), DType::F32, &CpuDevice).unwrap(),
        32, 5, 100,
    )?;
    println!("{}", result);
    Ok(())
}
```
## Build & Test
```bash
cargo build --workspace # Build all crates
cargo test --workspace # Run all tests (600+)
cargo clippy --workspace # Lint
cargo fmt --all --check # Format check
cargo doc --workspace --no-deps # Generate documentation
```
### Examples
```bash
cargo run -p example-linear-regression
cargo run -p example-mlp-xor
cargo run -p mnist-example # Requires MNIST data download
cargo run -p mnist-cnn-example
cargo run --release -p char-gpt-example # Char-level GPT on Shakespeare
cargo run -p example-rnn-sequence
cargo run --release -p example-bench-ops # CPU performance benchmarks
```
### CPU Performance (release mode)
| Operation | Size | Time |
|-----------|------|------|
| matmul | 256×256 × 256×256 | ~370 µs |
| matmul | 512×512 × 512×512 | ~3.5 ms |
| add | 1M elements | ~1.3 ms |
| linear forward | [64,512]×[512,512]+bias | ~3.9 ms |
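To put the matmul timings in context, they can be converted to throughput using the usual count of 2·m·k·n floating-point operations for an [m,k] × [k,n] product (one multiply and one add per term). The sketch below applies that arithmetic to the table's approximate figures, giving roughly 91 GFLOP/s for 256³ and 77 GFLOP/s for 512³. The benchmark machine is not stated, so treat these only as a relative indication.

```rust
/// GFLOP/s for an [m,k] × [k,n] matmul that took `seconds`.
fn gflops(m: u64, k: u64, n: u64, seconds: f64) -> f64 {
    (2 * m * k * n) as f64 / seconds / 1e9
}

fn main() {
    println!("256x256 matmul: {:.1} GFLOP/s", gflops(256, 256, 256, 370e-6));
    println!("512x512 matmul: {:.1} GFLOP/s", gflops(512, 512, 512, 3.5e-3));
}
```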
## Dependencies
| Crate | Purpose |
|-------|---------|
| `gemm` | SIMD-accelerated matmul (auto AVX2/AVX-512/FMA) |
| `rayon` | Parallel iteration for batched ops |
| `half` | F16/BF16 with num-traits |
| `cudarc` | CUDA driver/runtime, cuBLAS (optional) |
| `pyo3` / `numpy` | Python bindings (optional) |
| `num-traits` | Numeric trait bounds |
| `rand` / `rand_distr` | Random initialization |
| `thiserror` | Error types |
| `serde_json` | Checkpoint metadata |
## Contributing
Bug fixes are welcome without prior discussion. For new features or architectural changes, please open an issue first. See the [CHANGELOG](CHANGELOG.md) for release history.
## License
Apache-2.0. See [LICENSE](LICENSE) for details.