https://github.com/google-research/fast-soft-sort
Fast Differentiable Sorting and Ranking
- Host: GitHub
- URL: https://github.com/google-research/fast-soft-sort
- Owner: google-research
- License: apache-2.0
- Created: 2020-06-08T22:28:01.000Z (over 4 years ago)
- Default Branch: master
- Last Pushed: 2024-02-15T03:26:44.000Z (9 months ago)
- Last Synced: 2024-05-09T17:13:01.658Z (6 months ago)
- Topics: differentiable, jax, pytorch, ranking, sorting, tensorflow
- Language: Python
- Homepage:
- Size: 35.2 KB
- Stars: 546
- Watchers: 14
- Forks: 45
- Open Issues: 15
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
Fast Differentiable Sorting and Ranking
=======================================

Differentiable sorting and ranking operations in O(n log n).
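To illustrate where the O(n log n) comes from, here is a pure-Python sketch of the quadratic-regularization construction described in the referenced paper. It is not this package's code. Soft sorting and soft ranking both reduce to a Euclidean projection onto a permutahedron, the convex hull of all permutations of a vector. That projection needs one sort followed by an isotonic regression, which Pool Adjacent Violators solves in linear time. The names `project_permutahedron`, `soft_rank_sketch` and `soft_sort_sketch` are hypothetical, and the sketch covers only ascending order. `strength` is assumed to play the role of `regularization_strength`.

```python
from typing import List, Sequence


def _isotonic_nonincreasing(y: Sequence[float]) -> List[float]:
    """Pool Adjacent Violators: argmin over v_1 >= ... >= v_n of 0.5 * ||v - y||^2."""
    sums: List[float] = []   # running sum of each pooled block
    counts: List[int] = []   # number of entries in each pooled block
    for value in y:
        sums.append(float(value))
        counts.append(1)
        # Merge the last two blocks while their means break the non-increasing order.
        while len(sums) > 1 and sums[-2] / counts[-2] < sums[-1] / counts[-1]:
            s, c = sums.pop(), counts.pop()
            sums[-1] += s
            counts[-1] += c
    fitted: List[float] = []
    for s, c in zip(sums, counts):
        fitted.extend([s / c] * c)
    return fitted


def project_permutahedron(z: Sequence[float], w: Sequence[float]) -> List[float]:
    """Euclidean projection of z onto the convex hull of all permutations of w."""
    n = len(z)
    order = sorted(range(n), key=lambda i: z[i], reverse=True)  # the O(n log n) step
    z_sorted = [z[i] for i in order]
    w_sorted = sorted(w, reverse=True)
    v = _isotonic_nonincreasing([a - b for a, b in zip(z_sorted, w_sorted)])
    out = [0.0] * n
    for pos, i in enumerate(order):  # undo the sort
        out[i] = z_sorted[pos] - v[pos]
    return out


def soft_rank_sketch(theta: Sequence[float], strength: float = 1.0) -> List[float]:
    """Ascending soft ranks: project theta / strength onto the permutahedron of (1, ..., n)."""
    n = len(theta)
    return project_permutahedron([t / strength for t in theta], list(range(1, n + 1)))


def soft_sort_sketch(theta: Sequence[float], strength: float = 1.0) -> List[float]:
    """Ascending soft sort: a descending soft sort of -theta, then negated."""
    n = len(theta)
    rho = [float(n - k) for k in range(n)]  # (n, n-1, ..., 1)
    projected = project_permutahedron([r / strength for r in rho], [-t for t in theta])
    return [-p for p in projected]


if __name__ == "__main__":
    print(soft_sort_sketch([5., 1., 2.], strength=1.0))  # approximately [1.667, 2.667, 3.667]
    print(soft_rank_sketch([5., 1., 2.], strength=2.0))  # [3.0, 1.25, 1.75]
```

On the `[5., 1., 2.]` row used in the examples, this sketch gives the same values as the outputs listed for each framework. Pooling adjacent blocks touches each element a bounded number of times, so the initial sort dominates the cost. As `strength` shrinks, the projection returns the hard sort or hard ranks. As it grows, every output is pulled toward the mean of the inputs (for sorting) or toward the mean rank (for ranking).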
Dependencies
------------

* NumPy
* SciPy
* Numba
* TensorFlow (optional)
* PyTorch (optional)

TensorFlow Example
------------------

```python
>>> import tensorflow as tf
>>> from fast_soft_sort.tf_ops import soft_rank, soft_sort
>>> values = tf.convert_to_tensor([[5., 1., 2.], [2., 1., 5.]], dtype=tf.float64)
>>> soft_sort(values, regularization_strength=1.0)
>>> soft_sort(values, regularization_strength=0.1)
>>> soft_rank(values, regularization_strength=2.0)
>>> soft_rank(values, regularization_strength=1.0)
```
JAX Example
-----------

```python
>>> import jax.numpy as jnp
>>> from fast_soft_sort.jax_ops import soft_rank, soft_sort
>>> values = jnp.array([[5., 1., 2.], [2., 1., 5.]], dtype=jnp.float64)
>>> soft_sort(values, regularization_strength=1.0)
[[1.66666667 2.66666667 3.66666667]
 [1.66666667 2.66666667 3.66666667]]
>>> soft_sort(values, regularization_strength=0.1)
[[1. 2. 5.]
 [1. 2. 5.]]
>>> soft_rank(values, regularization_strength=2.0)
[[3.   1.25 1.75]
 [1.75 1.25 3.  ]]
>>> soft_rank(values, regularization_strength=1.0)
[[3. 1. 2.]
 [2. 1. 3.]]
```

PyTorch Example
---------------

```python
>>> import torch
>>> from fast_soft_sort.pytorch_ops import soft_rank, soft_sort
>>> values = torch.tensor([[5., 1., 2.], [2., 1., 5.]], dtype=torch.float64)
>>> soft_sort(values, regularization_strength=1.0)
tensor([[1.6667, 2.6667, 3.6667],
        [1.6667, 2.6667, 3.6667]], dtype=torch.float64)
>>> soft_sort(values, regularization_strength=0.1)
tensor([[1., 2., 5.],
        [1., 2., 5.]], dtype=torch.float64)
>>> soft_rank(values, regularization_strength=2.0)
tensor([[3.0000, 1.2500, 1.7500],
        [1.7500, 1.2500, 3.0000]], dtype=torch.float64)
>>> soft_rank(values, regularization_strength=1.0)
tensor([[3., 1., 2.],
        [2., 1., 3.]], dtype=torch.float64)
```

Install
-------

Run `python setup.py install` (or `pip install .`) from the repository root, or
copy the `fast_soft_sort/` folder into your project.

Reference
---------

> Fast Differentiable Sorting and Ranking
> Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
> In Proceedings of ICML 2020
> [arXiv:2002.08871](https://arxiv.org/abs/2002.08871)