https://github.com/lightonai/pylate-rs
PyLate efficient inference engine
- Host: GitHub
- URL: https://github.com/lightonai/pylate-rs
- Owner: lightonai
- License: mit
- Created: 2025-07-01T12:04:17.000Z (3 months ago)
- Default Branch: main
- Last Pushed: 2025-07-04T09:43:06.000Z (3 months ago)
- Last Synced: 2025-07-04T09:45:01.012Z (3 months ago)
- Topics: colbert, late-interaction, python, rust, wasm
- Language: Rust
- Homepage: https://lightonai.github.io/pylate-rs/
- Size: 0 Bytes
- Stars: 4
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# pylate-rs
Efficient Inference for PyLate
## ⭐️ Overview
**pylate-rs** is a high-performance inference engine for [PyLate](https://github.com/lightonai/pylate) models, meticulously crafted in Rust for optimal speed and efficiency.
While model training is handled by PyLate, which supports a variety of late interaction models, `pylate-rs` is engineered to execute these models at high speed.
- **Accelerated Performance**: Experience significantly faster model loading and rapid cold starts, making it ideal for serverless environments and low-latency applications.
- **Lightweight Design**: Built on the [Candle](https://github.com/huggingface/candle) ML framework, `pylate-rs` maintains a minimal footprint suitable for resource-constrained systems like serverless functions and edge computing.
- **Broad Hardware Support**: Optimized for diverse hardware, with dedicated builds for standard CPUs, Intel (MKL), Apple Silicon (Accelerate & Metal), and NVIDIA GPUs (CUDA).
- **Cross-Platform Integration**: Seamlessly integrate `pylate-rs` into your projects with bindings for Python, Rust, and JavaScript/WebAssembly.
For a complete, high-performance multi-vector search pipeline, pair `pylate-rs` with its companion library, [**FastPlaid**](https://github.com/lightonai/fast-plaid), at inference time.
Explore our [**WebAssembly live demo**](https://lightonai.github.io/pylate-rs/).
## 💻 Installation
Install the version of `pylate-rs` that matches your hardware for optimal performance.
### Python
| Target Hardware | Installation Command |
| :----------------------- | :------------------------------------------------------- |
| **Standard CPU** | `pip install pylate-rs` |
| **Apple CPU** (macOS) | `pip install pylate-rs-accelerate` |
| **Intel CPU** (MKL) | `pip install pylate-rs-mkl` |
| **Apple GPU** (M1/M2/M3) | `pip install pylate-rs-metal` |
| **NVIDIA GPU** (CUDA)    | `pip install git+https://github.com/lightonai/pylate-rs` |

CUDA wheels are not yet available on PyPI because of their size, but you can install the CUDA build directly from the repository as shown above (work in progress).
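Whichever wheel you install, the module is imported as `pylate_rs`. As a quick sanity check, you can load a model on CPU and encode a single query (a minimal sketch, assuming the hardware-specific wheels expose the same API and that the Hugging Face Hub is reachable to download `lightonai/GTE-ModernColBERT-v1`):

```python
from pylate_rs import models

# Load a ColBERT checkpoint on CPU just to verify the installation
model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cpu",
)

# Encode a single query; one multi-vector embedding is returned per sentence
embeddings = model.encode(sentences=["sanity check"], is_query=True)
print(f"Installation OK, encoded {len(embeddings)} query")
```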
### Rust
Add `pylate-rs` to your `Cargo.toml` by enabling the feature flag that corresponds to your backend.
| Feature | Target Hardware | Installation Command |
| :----------- | :----------------------- | :------------------------------------------ |
| _(default)_ | **Standard CPU** | `cargo add pylate-rs` |
| `accelerate` | **Apple CPU** (macOS) | `cargo add pylate-rs --features accelerate` |
| `mkl` | **Intel CPU** (MKL) | `cargo add pylate-rs --features mkl` |
| `metal` | **Apple GPU** (M1/M2/M3) | `cargo add pylate-rs --features metal` |
| `cuda` | **NVIDIA GPU** (CUDA) | `cargo add pylate-rs --features cuda` |
## ⚡️ Quick Start
### Python
Get started in just a few lines of Python.
```python
from pylate_rs import models

# Initialize the model for your target device ("cpu", "cuda", or "mps")
model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cuda",
)

# Encode queries and documents
queries_embeddings = model.encode(
    sentences=["What is the capital of France?", "How big is the sun?"],
    is_query=True,
)

documents_embeddings = model.encode(
    sentences=["Paris is the capital of France.", "The sun is a star."],
    is_query=False,
)

# Calculate similarity scores
similarities = model.similarity(queries_embeddings, documents_embeddings)
print(f"Similarity scores:\n{similarities}")

# Use hierarchical pooling to reduce document embedding size and speed up downstream tasks
pooled_documents_embeddings = model.encode(
    sentences=["Paris is the capital of France.", "The sun is a star."],
    is_query=False,
    pool_factor=2,  # Halves the number of token embeddings
)

similarities_pooled = model.similarity(queries_embeddings, pooled_documents_embeddings)
print(f"Similarity scores with pooling:\n{similarities_pooled}")
```
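The `pool_factor` argument controls hierarchical pooling of the document embeddings. To see its effect, you can compare the number of token embeddings per document with and without pooling. This is a minimal sketch that assumes, as in the example above, that `encode` returns one sequence of token vectors per input sentence:

```python
from pylate_rs import models

model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cpu",
)

documents = ["Paris is the capital of France.", "The sun is a star."]

full_embeddings = model.encode(sentences=documents, is_query=False)
pooled_embeddings = model.encode(sentences=documents, is_query=False, pool_factor=2)

for document, full, pooled in zip(documents, full_embeddings, pooled_embeddings):
    # pool_factor=2 should roughly halve the number of token embeddings
    print(f"{document!r}: {len(full)} -> {len(pooled)} token embeddings")
```

Fewer token embeddings per document mean a smaller index and faster downstream scoring.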
### Rust
```rust
use anyhow::Result;
use candle_core::Device;
use pylate_rs::{hierarchical_pooling, ColBERT};

fn main() -> Result<()> {
    // Set the device (e.g., Cpu, Cuda, Metal)
    let device = Device::Cpu;

    // Initialize the model
    let mut model: ColBERT = ColBERT::from("lightonai/GTE-ModernColBERT-v1")
        .with_device(device)
        .try_into()?;

    // Encode queries and documents
    let queries = vec!["What is the capital of France?".to_string()];
    let documents = vec!["Paris is the capital of France.".to_string()];

    let query_embeddings = model.encode(&queries, true)?;
    let document_embeddings = model.encode(&documents, false)?;

    // Calculate similarity
    let similarities = model.similarity(&query_embeddings, &document_embeddings)?;
    println!("Similarity score: {}", similarities.data[0][0]);

    // Use hierarchical pooling
    let pooled_document_embeddings = hierarchical_pooling(&document_embeddings, 2)?;
    let pooled_similarities = model.similarity(&query_embeddings, &pooled_document_embeddings)?;
    println!(
        "Similarity score after hierarchical pooling: {}",
        pooled_similarities.data[0][0]
    );

    Ok(())
}
```
## 📊 Benchmarks
| Device | Backend   | Queries per second | Documents per second | Model loading time (s) |
| :----- | :-------- | :----------------- | :------------------- | :---------------------- |
| cpu    | PyLate    | 350.10             | 32.16                | 2.06                     |
| cpu    | pylate-rs | 386.21 (+10%)      | 42.15 (+31%)         | 0.07 (-97%)              |
| cuda   | PyLate    | 2236.48            | 882.66               | 3.62                     |
| cuda   | pylate-rs | 4046.88 (+81%)     | 976.23 (+11%)        | 1.95 (-46%)              |
| mps    | PyLate    | 580.81             | 103.10               | 1.95                     |
| mps    | pylate-rs | 291.71 (-50%)      | 23.26 (-77%)         | 0.08 (-96%)              |

Benchmarks were run with Python. `pylate-rs` provides significant performance improvements, especially in scenarios that require fast startup. On a Mac, loading a model with the Transformers backend and encoding a single query can take up to 5 seconds, while `pylate-rs` does the same in just 0.11 seconds, making it ideal for low-latency applications. Do not expect `pylate-rs` to be much faster than PyLate when encoding large amounts of content at once, as PyTorch is heavily optimized for that case.
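To reproduce the load-time and throughput comparison on your own machine, a small timing harness like the sketch below is enough. It only uses the `pylate_rs` API shown above; the numbers you get will depend on your hardware and on whether the model files are already cached locally:

```python
import time

from pylate_rs import models

# Time model loading
start = time.perf_counter()
model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cpu",
)
print(f"Model loading time: {time.perf_counter() - start:.2f}s")

# Time query encoding throughput
queries = ["What is the capital of France?"] * 100
start = time.perf_counter()
model.encode(sentences=queries, is_query=True)
elapsed = time.perf_counter() - start
print(f"Queries per second: {len(queries) / elapsed:.2f}")
```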
## 📦 Using Custom Models
`pylate-rs` is compatible with any model saved in the PyLate format, whether from the Hugging Face Hub or a local directory. PyLate itself is compatible with a wide range of models, including those from Sentence Transformers, Hugging Face Transformers, and custom models. Before using `pylate-rs`, ensure your model is saved in the PyLate format; you can easily convert and upload your own models using PyLate.
Pushing a model to the Hugging Face Hub in PyLate format is straightforward. First, install PyLate:
```bash
pip install pylate
```
Then, you can use the following Python code snippet to push your model:
```python
from pylate import models

# Load your model
model = models.ColBERT(model_name_or_path="your-base-model-on-hf")

# Push in PyLate format
model.push_to_hub(
    repo_id="YourUsername/YourModelName",
    private=False,
    token="YOUR_HUGGINGFACE_TOKEN",
)
```
If you want to save a model in PyLate format locally, you can do so with the following code snippet:
```python
from pylate import models

# Load your model
model = models.ColBERT(model_name_or_path="your-base-model-on-hf")

# Save in PyLate format
model.save_pretrained("path/to/save/GTE-ModernColBERT-v1-pylate")
```
An existing set of models compatible with `pylate-rs` is available on the Hugging Face Hub under the [**LightOn**](https://huggingface.co/collections/lightonai/pylate-6862b571946fe88330d65264) namespace.
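Once a model is available in PyLate format, `pylate-rs` can load it from the Hugging Face Hub or from a local directory. A minimal sketch, reusing the hypothetical local path from the snippet above:

```python
from pylate_rs import models

# Load the locally converted PyLate-format model instead of a Hub repository
model = models.ColBERT(
    model_name_or_path="path/to/save/GTE-ModernColBERT-v1-pylate",
    device="cpu",
)

embeddings = model.encode(sentences=["A quick test query."], is_query=True)
```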
## Retrieval pipeline
```bash
pip install pylate-rs fast-plaid
```
Here is sample code for running ColBERT with `pylate-rs` and `fast-plaid`.
```python
import torch
from fast_plaid import search
from pylate_rs import models

model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cpu",  # mps or cuda
)

documents = [
"1st Arrondissement: Louvre, Tuileries Garden, Palais Royal, historic, tourist.",
"2nd Arrondissement: Bourse, financial, covered passages, Sentier, business.",
"3rd Arrondissement: Marais, Musée Picasso, galleries, trendy, historic.",
"4th Arrondissement: Notre-Dame, Marais, Hôtel de Ville, LGBTQ+.",
"5th Arrondissement: Latin Quarter, Sorbonne, Panthéon, student, intellectual.",
"6th Arrondissement: Saint-Germain-des-Prés, Luxembourg Gardens, chic, artistic, cafés.",
"7th Arrondissement: Eiffel Tower, Musée d'Orsay, Les Invalides, affluent, prestigious.",
"8th Arrondissement: Champs-Élysées, Arc de Triomphe, luxury, shopping, Élysée.",
"9th Arrondissement: Palais Garnier, department stores, shopping, theaters.",
"10th Arrondissement: Gare du Nord, Gare de l'Est, Canal Saint-Martin.",
"11th Arrondissement: Bastille, nightlife, Oberkampf, revolutionary, hip.",
"12th Arrondissement: Bois de Vincennes, Opéra Bastille, Bercy, residential.",
"13th Arrondissement: Chinatown, Bibliothèque Nationale, modern, diverse, street-art.",
"14th Arrondissement: Montparnasse, Catacombs, residential, artistic, quiet.",
"15th Arrondissement: Residential, family, populous, Parc André Citroën.",
"16th Arrondissement: Trocadéro, Bois de Boulogne, affluent, elegant, embassies.",
"17th Arrondissement: Diverse, Palais des Congrès, residential, Batignolles.",
"18th Arrondissement: Montmartre, Sacré-Cœur, Moulin Rouge, artistic, historic.",
"19th Arrondissement: Parc de la Villette, Cité des Sciences, canals, diverse.",
"20th Arrondissement: Père Lachaise, Belleville, cosmopolitan, artistic, historic.",
]

# Encoding documents
documents_embeddings = model.encode(
    sentences=documents,
    is_query=False,
    pool_factor=2,  # Let's divide the number of embeddings by 2.
)

# Creating the FastPlaid index
fast_plaid = search.FastPlaid(index="index")

fast_plaid.create(
    documents_embeddings=[torch.tensor(embedding) for embedding in documents_embeddings]
)
```
We can then load the existing index and search for the most relevant documents:
```python
import torch
from fast_plaid import search
from pylate_rs import models

# Re-create the model used at indexing time
model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cpu",  # mps or cuda
)

# Load the existing FastPlaid index
fast_plaid = search.FastPlaid(index="index")

queries = [
"arrondissement with the Eiffel Tower and Musée d'Orsay",
"Latin Quarter and Sorbonne University",
"arrondissement with Sacré-Cœur and Moulin Rouge",
"arrondissement with the Louvre and Tuileries Garden",
"arrondissement with Notre-Dame Cathedral and the Marais",
]

queries_embeddings = model.encode(
    sentences=queries,
    is_query=True,
)

scores = fast_plaid.search(
    queries_embeddings=torch.tensor(queries_embeddings),
    top_k=3,
)

print(scores)
```

## 📝 Citation
If you use `pylate-rs` in your research or project, please cite it as follows:
```bibtex
@misc{PyLate,
title={PyLate: Flexible Training and Retrieval for Late Interaction Models},
author={Chaffin, Antoine and Sourty, Raphaël},
url={https://github.com/lightonai/pylate},
year={2024}
}
```
## WebAssembly
For JavaScript and TypeScript projects, install the WASM package from npm.
```bash
npm install pylate-rs
```
Load the model by fetching the required files from a local path or the Hugging Face Hub.
```javascript
import { ColBERT } from "pylate-rs";

const REQUIRED_FILES = [
  "tokenizer.json",
  "model.safetensors",
  "config.json",
  "config_sentence_transformers.json",
  "1_Dense/model.safetensors",
  "1_Dense/config.json",
  "special_tokens_map.json",
];

async function loadModel(modelRepo) {
  const fetchAllFiles = async (basePath) => {
    const responses = await Promise.all(
      REQUIRED_FILES.map((file) => fetch(`${basePath}/${file}`))
    );
    for (const response of responses) {
      if (!response.ok) throw new Error(`File not found: ${response.url}`);
    }
    return Promise.all(
      responses.map((res) => res.arrayBuffer().then((b) => new Uint8Array(b)))
    );
  };

  try {
    let modelFiles;
    try {
      // Attempt to load from a local `models` directory first
      modelFiles = await fetchAllFiles(`models/${modelRepo}`);
    } catch (e) {
      console.warn(`Local model not found, falling back to Hugging Face Hub.`, e);
      // Fallback to fetching directly from the Hugging Face Hub
      modelFiles = await fetchAllFiles(
        `https://huggingface.co/${modelRepo}/resolve/main`
      );
    }

    const [
      tokenizer,
      model,
      config,
      stConfig,
      dense,
      denseConfig,
      tokensConfig,
    ] = modelFiles;

    // Instantiate the model with the loaded files
    const colbertModel = new ColBERT(
      model,
      dense,
      tokenizer,
      config,
      stConfig,
      denseConfig,
      tokensConfig,
      32
    );

    // You can now use `colbertModel` for encoding
    console.log("Model loaded successfully!");
    return colbertModel;
  } catch (error) {
    console.error("Model Loading Error:", error);
  }
}
```