An open API service indexing awesome lists of open source software.

https://github.com/dkobylianskii/torch-lap-cuda

A fast CUDA implementation of the Linear Assignment Problem (LAP) solver for PyTorch.
https://github.com/dkobylianskii/torch-lap-cuda

cuda python pytorch

Last synced: about 2 months ago
JSON representation

A fast CUDA implementation of the Linear Assignment Problem (LAP) solver for PyTorch.

Awesome Lists containing this project

README

          

# CUDA LAP Solver
[![PyPI version](https://badge.fury.io/py/torch-lap-cuda.svg)](https://badge.fury.io/py/torch-lap-cuda)
[![Downloads](https://static.pepy.tech/badge/torch-lap-cuda)](https://pepy.tech/project/torch-lap-cuda)
[![License](https://img.shields.io/badge/MIT-blue.svg)](https://opensource.org/licenses/MIT)



Installation |
Usage |
Benchmarks

A fast CUDA implementation of the Linear Assignment Problem (LAP) solver for PyTorch. This project provides GPU-accelerated HyLAC algorithm implementation that can efficiently handle batched inputs.

Based on the HyLAC code https://github.com/Nagi-Research-Group/HyLAC/tree/Block-LAP
Please cite the original work if you use this code in your research: https://doi.org/10.1016/j.jpdc.2024.104838

## Features

- Fast CUDA-based implementation of the LAP solver
- Batched processing support for multiple cost matrices
- Seamless integration with PyTorch
- Supports single and double precision types: `torch.int32, torch.int64, torch.float32, torch.float64`

## Requirements

- Python >= 3.9
- CUDA >= 10.0
- PyTorch
- NVIDIA GPU with compute capability >= 7.5

## Installation

To install the package, you can use pip:

```bash
pip install torch-lap-cuda --no-build-isolation
```

You can install the package directly from source:

```bash
git clone https://github.com/dkobylianskii/torch-lap-cuda.git
cd torch-lap-cuda
pip install . --no-build-isolation
```

## Usage

Here's a simple example of how to use the LAP solver:

```python
import torch
from torch_lap_cuda import solve_lap

# Create a random cost matrix (batch_size x N x N)
batch_size = 128
size = 256
cost_matrix = torch.randn((batch_size, size, size), device="cuda")

# Solve the assignment problem
# assignments shape will be (batch_size, size)
# Each batch element contains the column indices for optimal assignment
assignments = solve_lap(cost_matrix)

# Calculate total costs
batch_idxs = torch.arange(batch_size, device=assignments.device).unsqueeze(1)
row_idxs = torch.arange(size, device=assignments.device).unsqueeze(0)
total_cost = cost_matrix[batch_idxs, row_idxs, assignments].sum()
```

The solver also supports 2D inputs for single matrices:

```python
# Single cost matrix (N x N)
cost_matrix = torch.randn((size, size), device="cuda")
assignments = solve_lap(cost_matrix) # Shape: (size,)
```

In case of having multiple GPUs, you can specify the device for lap solver using the `device` argument:

```python
cost_matrix = torch.randn((batch_size, size, size), device="cuda:0")
assignments = solve_lap(cost_matrix, device="cuda:1") # assignments will be on cuda:0
```

## Input Requirements

- Cost matrices must be on a CUDA device
- Input can be either 2D (N x N) or 3D (batch_size x N x N)
- Matrices must be square
- Supports single and double precision types: `torch.int32, torch.int64, torch.float32, torch.float64`

## Benchmarks

Tests were performed on an INTEL(R) XEON(R) GOLD 6530 and NVIDIA A6000 Ada GPU with CUDA 12.5 and PyTorch 2.6.0.

`Scipy (MP)` means multiprocessing version, `Scipy (MT)` means multithreading version, both used 32 processes/threads.

To run the benchmarks, execute:

```bash
python tests/benchmark.py
```

### Benchmark for uniform random distribution:

![Benchmark results for uniform random cost matrices](figs/benchmark_uniform.png)

### Benchmark for normal random distribution:

![Benchmark results for normal random cost matrices](figs/benchmark_normal.png)

### Benchmark for integer random distribution:

![Benchmark results for integer random cost matrices](figs/benchmark_integer.png)

## Testing

To run the test suite:

```bash
pytest tests/
```