Leveraging torch.compile to accelerate cross-encoder inference
- Host: GitHub
- URL: https://github.com/shreyansh26/accelerating-cross-encoder-inference
- Owner: shreyansh26
- Created: 2025-02-21T16:35:51.000Z (3 months ago)
- Default Branch: main
- Last Pushed: 2025-03-02T15:53:59.000Z (3 months ago)
- Last Synced: 2025-03-02T16:34:38.050Z (3 months ago)
- Topics: cross-encoder, inference-optimization, jina, mlsys, torch-compile
- Language: Python
- Homepage: https://shreyansh26.github.io/post/2025-03-02_cross-encoder-inference-torch-compile/
- Size: 3.83 MB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# Accelerating Cross Encoder inference with torch.compile
## Overview
This project demonstrates optimizing the inference of a Cross Encoder model, namely [jinaai/jina-reranker-v2-base-multilingual](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual), by leveraging torch.compile. The scripts compare baseline performance against a torch.compile-optimized version using a custom padding approach.

**Blog post describing the approach in detail** - https://shreyansh26.github.io/post/2025-03-02_cross-encoder-inference-torch-compile/
## Setup
- Python 3.8+
- PyTorch with CUDA support
- Sentence Transformers library
- Model: jinaai/jina-reranker-v2-base-multilingual

## Scripts
### bench_basic.py
Runs a baseline benchmark with the standard CrossEncoder and Flash Attention enabled; a sketch of such a baseline call is shown below.
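For reference, a baseline scoring call with the Sentence Transformers `CrossEncoder` looks roughly like the following. This is an illustrative sketch, not the repo's exact script; `trust_remote_code=True` is assumed because the Jina reranker ships custom modeling code, and it requires a reasonably recent sentence-transformers version.

```python
# Minimal baseline sketch (not the repo's exact bench_basic.py): score
# query-document pairs with the stock Sentence Transformers CrossEncoder.
from sentence_transformers import CrossEncoder

model = CrossEncoder(
    "jinaai/jina-reranker-v2-base-multilingual",
    trust_remote_code=True,  # the Jina reranker uses custom modeling code
    device="cuda",
)

pairs = [
    ("What is torch.compile?", "torch.compile JIT-compiles PyTorch programs."),
    ("What is torch.compile?", "Flash Attention is a fused attention kernel."),
]

# predict() handles tokenization, batching, and the forward pass internally.
scores = model.predict(pairs, batch_size=64)
print(scores)  # higher score = more relevant pair
```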
### bench_torch_compile.py

Focuses on the torch.compile approach, combining the custom padding scheme with torch.compile optimizations.

### bench_combined.py
Compares the baseline with the torch.compile-optimized version.
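All three scripts time GPU inference. The harness below is a hypothetical sketch (not the repo's exact benchmarking code) of the standard pattern: warmup iterations first, so CUDA initialization and torch.compile compilation are excluded from the measurement, then `torch.cuda.synchronize()` around the timed region so that asynchronously queued kernels are fully counted.

```python
# Hypothetical timing harness illustrating how GPU benchmarks of this kind
# are usually measured; the repo's scripts may differ in detail.
import statistics
import time

import torch

def benchmark(fn, n_warmup: int = 3, n_iters: int = 10) -> tuple[float, float]:
    for _ in range(n_warmup):       # warmup absorbs CUDA init + compile time
        fn()
    torch.cuda.synchronize()
    times = []
    for _ in range(n_iters):
        start = time.perf_counter()
        fn()
        torch.cuda.synchronize()    # wait for all queued kernels to finish
        times.append(time.perf_counter() - start)
    return statistics.mean(times), statistics.stdev(times)

# Usage: mean_s, std_s = benchmark(lambda: model.predict(pairs, batch_size=64))
```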
## Implementation Details

- **Batching with custom padding**: The custom `DynamicCrossEncoder` pads tokenized inputs up to a bucket length (the next multiple of 16), reducing the number of distinct input shapes that torch.compile has to capture.
- **Sorted Inputs**: Sorting the inputs by length before batching keeps the sequences within a batch at similar lengths, so fewer padding tokens need to be processed.
- **torch.compile**: The model's forward function is compiled with `torch.compile` using the `inductor` backend, enabling dynamic shape handling and reducing latency (see the sketch after this list).
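A minimal sketch combining these three ideas is below. It is illustrative rather than the repo's actual `DynamicCrossEncoder`: the bucketing here reuses the Hugging Face tokenizer's `pad_to_multiple_of` argument, and the dtype and default batch size are assumptions.

```python
# Illustrative sketch of bucket padding, length-sorted batching, and
# torch.compile together; not the repo's exact implementation.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_ID = "jinaai/jina-reranker-v2-base-multilingual"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, trust_remote_code=True
).cuda().eval()

# Compile the forward pass with the inductor backend; dynamic=True lets one
# compiled graph serve the (bucketed) range of sequence lengths instead of
# recompiling for every new shape.
model.forward = torch.compile(model.forward, backend="inductor", dynamic=True)

@torch.inference_mode()
def score(pairs: list[tuple[str, str]], batch_size: int = 64) -> list[float]:
    # Sort by rough text length so each batch holds similar-length sequences
    # (minimizing padding); remember the original order for the output.
    order = sorted(range(len(pairs)), key=lambda i: len(pairs[i][0]) + len(pairs[i][1]))
    scores = [0.0] * len(pairs)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        enc = tokenizer(
            [pairs[i][0] for i in idx],
            [pairs[i][1] for i in idx],
            padding=True,
            pad_to_multiple_of=16,  # bucket padded lengths to multiples of 16
            truncation=True,
            return_tensors="pt",
        ).to("cuda")
        logits = model(**enc).logits.squeeze(-1)
        for i, s in zip(idx, logits.float().tolist()):
            scores[i] = s
    return scores
```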
## Speedup Analysis
The torch.compile-optimized version shows significant speedups over the baseline (batch size 64, H100 GPU):
| Setup | Sorted Inputs (s) | Unsorted Inputs (s) |
| ------------------------------------------ | ----------------------- | ------------------------ |
| Base (with Flash Attention) | 0.2658 ± 0.0119 | 0.2961 ± 0.0089 |
| torch.compile                               | 0.2089 ± 0.0196          | 0.2595 ± 0.0077          |

This corresponds to roughly a 21% reduction in inference latency for sorted inputs (1 − 0.2089/0.2658 ≈ 0.21) and about a 12% reduction for unsorted inputs (1 − 0.2595/0.2961 ≈ 0.12).
## How to Run
1. Ensure your environment has CUDA and the required libraries installed:
- `pip install sentence-transformers torch`
2. Execute the benchmark scripts:
- `CUDA_VISIBLE_DEVICES=0 python bench_basic.py`
- `CUDA_VISIBLE_DEVICES=0 python bench_torch_compile.py`
- `CUDA_VISIBLE_DEVICES=0 python bench_combined.py`