# Benchmarking predictive survival analysis models
This repository is dedicated to the evaluation of predictive survival models on large-ish datasets.
## Software dependencies
To run the jupytercon-2023 tutorial notebooks, you will need:
```
conda create -n jupytercon-survival -c conda-forge python jupyterlab \
    scikit-learn lifelines scikit-survival matplotlib-base plotly seaborn \
    pandas pyarrow ibis-duckdb polars
conda activate jupytercon-survival
jupyter lab
```

## Notebooks
The `notebooks` folder holds the two main notebooks for the jupytercon-2023 tutorial, namely:

- `tutorial_part_1.ipynb`
- `tutorial_part_2.ipynb`

and the ancillary notebook used to generate the dataset used in "part 1", namely:
- `truck_dataset.ipynb`
Note that running `truck_dataset.ipynb` requires a significant amount of RAM, so you might prefer [downloading the datasets as a zip archive](https://github.com/soda-inria/survival-analysis-benchmark/releases/download/jupytercon-2023-tutorial-data/truck_failure.zip) (500 MB) instead of generating them.
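For instance, here is a minimal Python sketch to fetch and unpack the archive (the `datasets` extraction folder is an assumption, adjust it to your layout):

```python
# Download the pre-generated truck failure datasets instead of running
# truck_dataset.ipynb. The extraction folder name is an assumption.
from urllib.request import urlretrieve
from zipfile import ZipFile

url = (
    "https://github.com/soda-inria/survival-analysis-benchmark/releases/"
    "download/jupytercon-2023-tutorial-data/truck_failure.zip"
)
urlretrieve(url, "truck_failure.zip")
ZipFile("truck_failure.zip").extractall("datasets")
```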
The notebooks display our benchmark results and show how to use our wrappers to cross-validate various models (a schematic example follows the list below):
- `kkbox_cv_benchmark.ipynb`
  Benchmark of the KKBox challenge, inspired by the [pycox paper](https://jmlr.org/papers/volume20/18-424/18-424.pdf).
- `msk_mettropism.ipynb`
  Exploration of the MSK cancer dataset and survival probability predictions using our models.
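As a schematic illustration of such a cross-validation loop, here is a hedged sketch using the `GradientBoostedCIF` estimator introduced in the Models section below (the wrappers actually used in the notebooks may expose a different API, and the `(n_samples, n_times)` output shape of `predict_survival_function` is an assumption):

```python
# Hypothetical sketch: manual K-fold cross-validation of a survival model,
# scored with the integrated Brier score from scikit-survival.
import numpy as np
from sklearn.model_selection import KFold
from sksurv.metrics import integrated_brier_score

from models.gradient_boosted_cif import GradientBoostedCIF


def cross_validate_ibs(X, y, time_grid, n_splits=5):
    """Return one integrated Brier score per fold (lower is better)."""
    scores = []
    for train_idx, test_idx in KFold(n_splits=n_splits).split(X):
        model = GradientBoostedCIF(event_of_interest="any")
        model.fit(X[train_idx], y[train_idx], time_grid)
        # Assumed to return survival curves of shape (n_samples, n_times).
        surv_curves = model.predict_survival_function(X[test_idx], time_grid)
        scores.append(
            integrated_brier_score(y[train_idx], y[test_idx], surv_curves, time_grid)
        )
    return np.asarray(scores)
```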
## Datasets

### WSDM - KKBox's Churn Prediction Challenge (from Kaggle)
The `datasets/kkbox_churn` folder contains Python code to efficiently preprocess the raw transaction logs of the KKBox's Churn Prediction Challenge using [ibis](https://ibis-project.org) and [duckdb](https://duckdb.org).

- https://www.kaggle.com/competitions/kkbox-churn-prediction-challenge/data
The objectives are to:

- make everything reproducible from the event-based logs;
- implement efficient, parallel and out-of-core "sessionization" of the past transactions for all members: here a "session" is an uninterrupted sequence of transactions (see the sketch below);
- implement efficient, parallel and out-of-core tabularization (features and churn target with censoring);
- make it possible to compute the cumulative state of the subscription data and the censored churn events at any point in time.

Alternatively, `kkbox_cv_benchmark.ipynb` directly uses the preprocessing steps from pycox to ensure reproducibility, at the cost of memory usage and speed.
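To make the sessionization idea concrete, here is a minimal sketch of the window-function approach using duckdb from Python (the `transactions.csv` file name and the exact session boundary rule are assumptions; the repository's actual queries are more involved):

```python
# Hypothetical sketch: flag a new "session" whenever a member's transaction
# starts after the previous transaction's membership expiration date,
# i.e. the subscription was interrupted.
import duckdb

con = duckdb.connect()
session_flags = con.sql(
    """
    SELECT
        msno,
        transaction_date,
        membership_expire_date,
        CASE
            WHEN transaction_date > LAG(membership_expire_date) OVER w
            THEN 1 ELSE 0
        END AS new_session
    FROM 'transactions.csv'
    WINDOW w AS (PARTITION BY msno ORDER BY transaction_date)
    """
).df()
```

A per-member cumulative sum of `new_session` then assigns a session identifier to each transaction, which can be computed out-of-core by the same engine.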
## Models
This repository introduces a novel survival estimator named Gradient Boosting CIF. Under the hood, this estimator builds on scikit-learn's `HistGradientBoostingClassifier`.
It is named CIF because it can estimate cause-specific Cumulative Incidence Functions in a competing risks setting, by minimizing a cause-specific Integrated Brier Score (IBS) objective.
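For orientation, a standard IPCW-weighted, cause-specific time-dependent Brier score (in the spirit of Graf et al.; stated here as background, not necessarily the exact objective implemented in this repository) reads as follows, with $t_i$ the observed time, $\delta_i$ the event label, $\hat{F}_k(t \mid x_i)$ the predicted CIF of event $k$ and $\hat{G}$ an estimate of the censoring survival function:

$$
\mathrm{BS}_k(t) = \frac{1}{n} \sum_{i=1}^{n} \left[
\frac{\mathbb{1}(t_i \le t,\, \delta_i = k)}{\hat{G}(t_i)} \big(1 - \hat{F}_k(t \mid x_i)\big)^2
+ \frac{\mathbb{1}(t_i > t)}{\hat{G}(t)} \, \hat{F}_k(t \mid x_i)^2
\right]
$$

The IBS averages this score over a time grid: $\mathrm{IBS}_k = \frac{1}{t_{\max} - t_{\min}} \int_{t_{\min}}^{t_{\max}} \mathrm{BS}_k(t) \, dt$. A typical usage of the estimator looks like this: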
```python
import numpy as np

from models.gradient_boosted_cif import GradientBoostedCIF

X_train = np.array([
[5.0, 0.1, 2.0],
[3.0, 1.1, 2.2],
[2.0, 0.3, 1.1],
[4.0, 1.0, 0.9],
])
# The "event" field encodes 0 for censored and 1, 2, ... for competing event
# types, hence an integer dtype (a boolean could not represent event 2).
y_train_multi_event = np.array([
    (2, 33.2),
    (0, 10.1),
    (0, 50.0),
    (1, 20.0),
], dtype=[("event", np.int32), ("duration", np.float64)])
time_grid = np.linspace(0.0, 30.0, 10)

gb_cif = GradientBoostedCIF(event_of_interest=1)
gb_cif.fit(X_train, y_train_multi_event, time_grid)

X_test = np.array([[3.0, 1.0, 9.0]])
cif_curves = gb_cif.predict_cumulative_incidence(X_test, time_grid)
```

Alternatively, you can estimate the probability that the event has been experienced by a specific time horizon:
```python
gb_cif.predict_proba(X_test, time_horizon=20)
```

Conversely, you can estimate the conditional quantile of the time to event, e.g. to answer the question "at which time horizon does the CIF reach 50%?":
```python
gb_cif.predict_quantile(X_test, quantile=0.5)
```

You can also estimate the survival function in the single event setting (binary event), where the survival function is simply the complement of the CIF, $S(t \mid x) = 1 - F(t \mid x)$. Warning: this only makes sense when `y` is binary or when setting `event_of_interest='any'`.
```python
y_train_single_event = np.array([
(1, 12.0),
(0, 5.1),
(1, 1.1),
(0, 29.0),
], dtype=[("event", np.bool_), ("duration", np.float64)])
gb_cif.fit(X_train, y_train_single_event, time_grid)
survival_curves = gb_cif.predict_survival_function(X_test, time_grid)
```