https://github.com/google-research/fast-soft-sort

Fast Differentiable Sorting and Ranking
=======================================

Differentiable sorting and ranking operations in O(n log n) time and O(n) space.
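For background, here is a condensed summary, in simplified notation, of the quadratic-regularization construction from the paper listed under Reference. It shows where the O(n log n) bound comes from; it is not a description of this library's code, and the paper also covers other variants.

```math
\begin{aligned}
P(z, w) &= \operatorname*{arg\,min}_{\mu \in \mathcal{P}(w)} \tfrac{1}{2} \lVert \mu - z \rVert^2,
\qquad \rho = (n, n-1, \dots, 1), \\
s_\varepsilon(\theta) &= P(\rho / \varepsilon, \theta),
\qquad r_\varepsilon(\theta) = P(-\theta / \varepsilon, \rho), \\
P(z, w) &= z - v_{\sigma^{-1}},
\qquad v = \operatorname*{arg\,min}_{v_1 \ge v_2 \ge \dots \ge v_n} \tfrac{1}{2} \lVert v - (z_\sigma - w) \rVert^2 .
\end{aligned}
```

Here 𝒫(w) is the permutahedron (the convex hull of all permutations of w), σ sorts z in descending order, w is taken in descending order (permuting w does not change 𝒫(w)), and ε plays the role of `regularization_strength`. Computing σ costs O(n log n), and the isotonic regression for v is solved exactly in O(n) by pool adjacent violators (PAV). These formulas use descending order, whereas the examples below are in ascending order.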

Dependencies
------------

* NumPy
* SciPy
* Numba
* TensorFlow (optional)
* PyTorch (optional)
* JAX (optional)

TensorFlow Example
-------------------

```python
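>>> import tensorflow as tf
>>> from fast_soft_sort.tf_ops import soft_rank, soft_sort
>>> values = tf.convert_to_tensor([[5., 1., 2.], [2., 1., 5.]], dtype=tf.float64)
>>> soft_sort(values, regularization_strength=1.0)
<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[1.66666667, 2.66666667, 3.66666667],
       [1.66666667, 2.66666667, 3.66666667]])>
>>> soft_sort(values, regularization_strength=0.1)
<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[1., 2., 5.],
       [1., 2., 5.]])>
>>> soft_rank(values, regularization_strength=2.0)
<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[3.  , 1.25, 1.75],
       [1.75, 1.25, 3.  ]])>
>>> soft_rank(values, regularization_strength=1.0)
<tf.Tensor: shape=(2, 3), dtype=float64, numpy=
array([[3., 1., 2.],
       [2., 1., 3.]])>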
```

JAX Example
-----------

```python
>>> import jax
>>> jax.config.update("jax_enable_x64", True)  # without this, float64 is truncated to float32
>>> import jax.numpy as jnp
>>> from fast_soft_sort.jax_ops import soft_rank, soft_sort
>>> values = jnp.array([[5., 1., 2.], [2., 1., 5.]], dtype=jnp.float64)
>>> print(soft_sort(values, regularization_strength=1.0))
[[1.66666667 2.66666667 3.66666667]
 [1.66666667 2.66666667 3.66666667]]
>>> print(soft_sort(values, regularization_strength=0.1))
[[1. 2. 5.]
 [1. 2. 5.]]
>>> print(soft_rank(values, regularization_strength=2.0))
[[3.   1.25 1.75]
 [1.75 1.25 3.  ]]
>>> print(soft_rank(values, regularization_strength=1.0))
[[3. 1. 2.]
 [2. 1. 3.]]
```
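As a sanity check on the numbers above, here is a dependency-free sketch of the quadratic construction from the paper: a Euclidean projection onto the permutahedron, computed with a sort plus pool adjacent violators. The function names are hypothetical and this is not the library's implementation. Ascending order is obtained here by negating inputs and outputs of the paper's descending operators, which is an assumption about how the examples' convention maps onto the paper's; with `eps` standing in for `regularization_strength`, it reproduces the first row of each example.

```python
from typing import List


def _pav_decreasing(y: List[float]) -> List[float]:
    """Isotonic regression onto v_1 >= v_2 >= ... >= v_n via pool adjacent violators."""
    blocks = []  # each block: [start, end, total]
    for k, yk in enumerate(y):
        blocks.append([k, k + 1, yk])
        # Merge while the previous block's mean is below the last block's mean,
        # since that would break the decreasing constraint.
        while len(blocks) > 1 and (
            blocks[-2][2] / (blocks[-2][1] - blocks[-2][0])
            < blocks[-1][2] / (blocks[-1][1] - blocks[-1][0])
        ):
            _, end, total = blocks.pop()
            blocks[-1][1] = end
            blocks[-1][2] += total
    v = []
    for start, end, total in blocks:
        v.extend([total / (end - start)] * (end - start))
    return v


def _project_permutahedron(z: List[float], w: List[float]) -> List[float]:
    """Euclidean projection of z onto the permutahedron of w: P(z, w) = z - v[sigma^-1]."""
    order = sorted(range(len(z)), key=lambda i: -z[i])  # sigma: sorts z in descending order
    w_desc = sorted(w, reverse=True)
    v = _pav_decreasing([z[i] - wk for i, wk in zip(order, w_desc)])
    out = [0.0] * len(z)
    for k, i in enumerate(order):
        out[i] = z[i] - v[k]  # undo the sort
    return out


def soft_sort_ascending(theta: List[float], eps: float = 1.0) -> List[float]:
    """Ascending quadratic soft sort, taken here as -s(-theta) with s(x) = P(rho / eps, x)."""
    n = len(theta)
    rho = [float(n - k) for k in range(n)]  # (n, n-1, ..., 1)
    desc = _project_permutahedron([r / eps for r in rho], [-t for t in theta])
    return [-x for x in desc]


def soft_rank_ascending(theta: List[float], eps: float = 1.0) -> List[float]:
    """Ascending quadratic soft rank, taken here as r(-theta) with r(x) = P(-x / eps, rho)."""
    n = len(theta)
    rho = [float(n - k) for k in range(n)]  # (n, n-1, ..., 1)
    return _project_permutahedron([t / eps for t in theta], rho)


print(soft_sort_ascending([5.0, 1.0, 2.0], eps=1.0))  # ≈ [1.6667, 2.6667, 3.6667]
print(soft_sort_ascending([5.0, 1.0, 2.0], eps=0.1))  # ≈ [1.0, 2.0, 5.0]
print(soft_rank_ascending([5.0, 1.0, 2.0], eps=2.0))  # [3.0, 1.25, 1.75]
print(soft_rank_ascending([5.0, 1.0, 2.0], eps=1.0))  # [3.0, 1.0, 2.0]
```

Small `eps` recovers the hard sort and hard ranks, while large `eps` pulls every output toward the mean, which matches the trend across the `regularization_strength` values in the examples.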

PyTorch Example
---------------

```python
>>> import torch
>>> from fast_soft_sort.pytorch_ops import soft_rank, soft_sort
>>> values = torch.tensor([[5., 1., 2.], [2., 1., 5.]], dtype=torch.float64)
>>> soft_sort(values, regularization_strength=1.0)
tensor([[1.6667, 2.6667, 3.6667],
        [1.6667, 2.6667, 3.6667]], dtype=torch.float64)
>>> soft_sort(values, regularization_strength=0.1)
tensor([[1., 2., 5.],
        [1., 2., 5.]], dtype=torch.float64)
>>> soft_rank(values, regularization_strength=2.0)
tensor([[3.0000, 1.2500, 1.7500],
        [1.7500, 1.2500, 3.0000]], dtype=torch.float64)
>>> soft_rank(values, regularization_strength=1.0)
tensor([[3., 1., 2.],
        [2., 1., 3.]], dtype=torch.float64)
```
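Since the projection reduces to a sort plus an isotonic regression, its Jacobian is cheap: away from ties, the isotonic step simply averages within each pooled block. The pure-Python sketch below (hypothetical names, quadratic regularization and ascending ranks only) returns a soft rank together with a vector-Jacobian product, which is what a backward pass needs. It illustrates the idea and is not the code behind the PyTorch ops above.

```python
from typing import Callable, List, Tuple


def soft_rank_with_vjp(
    theta: List[float], eps: float = 1.0
) -> Tuple[List[float], Callable[[List[float]], List[float]]]:
    """Ascending quadratic soft rank and a function computing g^T (d ranks / d theta)."""
    n = len(theta)
    z = [t / eps for t in theta]
    rho = [float(n - k) for k in range(n)]         # (n, n-1, ..., 1)
    order = sorted(range(n), key=lambda i: -z[i])  # sorts z in descending order
    y = [z[i] - r for i, r in zip(order, rho)]

    # Pool adjacent violators for v_1 >= ... >= v_n; each block is [start, end, total].
    blocks = []
    for k, yk in enumerate(y):
        blocks.append([k, k + 1, yk])
        while len(blocks) > 1 and (
            blocks[-2][2] / (blocks[-2][1] - blocks[-2][0])
            < blocks[-1][2] / (blocks[-1][1] - blocks[-1][0])
        ):
            _, end, total = blocks.pop()
            blocks[-1][1] = end
            blocks[-1][2] += total

    ranks = [0.0] * n
    for start, end, total in blocks:
        for k in range(start, end):
            ranks[order[k]] = z[order[k]] - total / (end - start)

    def vjp(g: List[float]) -> List[float]:
        # d(ranks)/dz = I - P^T B P, where P is the sorting permutation and
        # B replaces each entry by the mean over its pooled block.
        g_sorted = [g[i] for i in order]
        grad = [0.0] * n
        for start, end, _ in blocks:
            mean_g = sum(g_sorted[start:end]) / (end - start)
            for k in range(start, end):
                # Divide by eps because z = theta / eps (chain rule).
                grad[order[k]] = (g_sorted[k] - mean_g) / eps
        return grad

    return ranks, vjp


ranks, vjp = soft_rank_with_vjp([5.0, 1.0, 2.0], eps=2.0)
print(ranks)                  # [3.0, 1.25, 1.75]
print(vjp([0.3, -1.2, 0.7]))  # ≈ [0.0, -0.475, 0.475]
```

Both the forward pass and the vector-Jacobian product touch each element a constant number of times after the sort, so the O(n log n) cost of the sort dominates.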

Install
--------

From the repository root, run `pip install .` (or the older `python setup.py install`), or copy the `fast_soft_sort/` folder into your project.

Reference
------------

> Fast Differentiable Sorting and Ranking
> Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
> In Proceedings of ICML 2020
> [arXiv:2002.08871](https://arxiv.org/abs/2002.08871)