https://github.com/shchur/survival_distributions
Extended functionality for univariate probability distributions in PyTorch
probabilistic-modeling pytorch survival-analysis temporal-point-processes torch-distributions
- Host: GitHub
- URL: https://github.com/shchur/survival_distributions
- Owner: shchur
- License: mit
- Created: 2021-12-21T14:03:30.000Z (about 3 years ago)
- Default Branch: main
- Last Pushed: 2022-06-28T08:43:38.000Z (over 2 years ago)
- Last Synced: 2024-11-30T10:07:27.076Z (about 1 month ago)
- Topics: probabilistic-modeling, pytorch, survival-analysis, temporal-point-processes, torch-distributions
- Language: Python
- Homepage:
- Size: 52.7 KB
- Stars: 10
- Watchers: 1
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# survival_distributions
This package extends the functionality of univariate distributions in [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html)
by implementing several new methods:
- `sf`: survival function (complementary CDF)
- `logsf`: logarithm of the survival function (negative cumulative hazard function)
- `logcdf`: logarithm of the CDF
- `log_hazard`: logarithm of the hazard function (logarithm of the failure rate)
- `isf`: inverse of the survival function
- `sample_cond`: instead of sampling from the full support of the distribution,
generate samples between `lower_bound` and `upper_bound`
This is especially useful when working with
[temporal point processes](https://shchur.github.io/blog/2020/tpp1-conditional-intensity/)
or [survival analysis](https://en.wikipedia.org/wiki/Survival_analysis).

A naive implementation based on existing PyTorch functionality (e.g.,
`torch.log(1.0 - dist.cdf(x))` for `logsf`) is often less accurate and less numerically
stable than the implementation provided by `survival_distributions`. When `dist.cdf(x)` is
close to 1, the subtraction `1.0 - dist.cdf(x)` loses most of its significant digits, and once
the survival probability drops below roughly machine precision it rounds to exactly zero, so the
naive `logsf` returns `-inf`.

Hopefully, these methods will be implemented in PyTorch [sometime in the future](https://github.com/pytorch/pytorch/issues/52973),
but this package provides an alternative for the time being.

See [`DISTRIBUTIONS.md`](https://github.com/shchur/survival_distributions/blob/main/DISTRIBUTIONS.md) for more details about the implemented functions and supported distributions.
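As a rough illustration of what each of these methods computes, here is a pure-Python sketch for an exponential distribution with rate λ, using only the standard `math` and `random` modules. The helpers mirror the method names listed above, but they are hypothetical standalone functions, not the package's implementation or API. In particular, implementing `sample_cond` via inverse-transform sampling on the survival function is just one possible approach.

```python
import math
import random

# Hypothetical standalone helpers for an Exponential(rate) distribution.
# They mirror the method names above but are NOT the package's implementation.


def sf(x: float, rate: float) -> float:
    """Survival function S(x) = P(X > x) = exp(-rate * x), for x >= 0."""
    return math.exp(-rate * x)


def logsf(x: float, rate: float) -> float:
    """log S(x) = -rate * x, i.e. the negative cumulative hazard."""
    return -rate * x


def logcdf(x: float, rate: float) -> float:
    """log F(x) = log(1 - exp(-rate * x)) for x > 0; expm1 avoids cancellation for small x."""
    return math.log(-math.expm1(-rate * x))


def log_hazard(x: float, rate: float) -> float:
    """log h(x) = log f(x) - log S(x), which is the constant log(rate) here."""
    return math.log(rate)


def isf(p: float, rate: float) -> float:
    """Inverse survival function: the x with S(x) = p, for 0 < p <= 1."""
    return -math.log(p) / rate


def sample_cond(lower: float, upper: float, rate: float, rng: random.Random) -> float:
    """Sample X conditioned on lower <= X < upper by inverse-transform sampling.

    p is drawn uniformly from (S(upper), S(lower)]; since isf is decreasing,
    isf(p) lands in [lower, upper) up to floating-point rounding.
    """
    s_lower, s_upper = sf(lower, rate), sf(upper, rate)
    p = s_lower - (s_lower - s_upper) * rng.random()
    return isf(p, rate)


rng = random.Random(0)
print(logsf(1.5, rate=2.0))  # -3.0
print([round(sample_cond(1.0, 2.0, rate=2.0, rng=rng), 3) for _ in range(3)])
```

Passing `upper=math.inf` works too, because `sf` then evaluates to `0.0` and `p` stays strictly positive.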
## Installation
1. Install the latest version of [PyTorch](https://pytorch.org/get-started/locally/).
2. Install `survival_distributions`:
```bash
pip install survival_distributions
```

## Supported distributions
### Numerically stable implementation
For these distributions we provide a numerically stable implementation of `logsf`.
- `Exponential`
- `Logistic`
- `LogLogistic`
- `MixtureSameFamily`
- `TransformedDistribution`
- `Uniform`
- `Weibull`

### Naive implementation
For these distributions we implement `logsf(x)` as `log(1.0 - dist.cdf(x))`, which is less
numerically stable.
- `LogNormal`
- `Normal`

## Usage
The package provides a drop-in replacement for `torch.distributions`, so you can just modify your code as follows.

**Old code**
```python
import torch
dist = torch.distributions.Exponential(rate=torch.tensor(2.0))
x = torch.tensor(1.5)
# Naive log survival probability computed as log(1 - CDF(x))
log_survival_proba = torch.log(1.0 - dist.cdf(x))
```
**New code**
```python
import torch
import survival_distributions as sd
dist = sd.Exponential(rate=torch.tensor(2.0))
x = torch.tensor(1.5)
# Log survival probability via the dedicated logsf method
log_survival_proba = dist.logsf(x)
```
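To see why switching from the old code to `logsf` matters, here is a small pure-Python sketch that mimics both versions for the same exponential distribution (rate 2.0), using double-precision floats from the standard `math` module instead of PyTorch tensors. The helper names are hypothetical, and the closed form `-rate * x` is an assumption about how a stable `logsf` can be computed for this distribution, not a description of the package's code. Far in the tail, `1.0 - cdf(x)` suffers catastrophic cancellation: at x = 20 the CDF rounds to exactly 1.0, so the naive value is `-inf`, while the closed form gives -40. With PyTorch's default `float32` dtype this happens even sooner.

```python
import math

RATE = 2.0  # same rate as in the usage example above


def cdf(x: float) -> float:
    # F(x) = 1 - exp(-RATE * x); for large x this rounds to exactly 1.0.
    return 1.0 - math.exp(-RATE * x)


def naive_logsf(x: float) -> float:
    # Mirrors torch.log(1.0 - dist.cdf(x)) from the old code. math.log raises
    # on 0.0, so return -inf in that case, as torch.log would.
    s = 1.0 - cdf(x)
    return math.log(s) if s > 0.0 else -math.inf


def stable_logsf(x: float) -> float:
    # Assumed closed form log S(x) = -RATE * x; no subtraction from 1.0 involved.
    return -RATE * x


for x in (1.5, 10.0, 20.0):
    print(f"x={x:5.1f}  naive={naive_logsf(x):.12g}  stable={stable_logsf(x):.12g}")
```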