Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/adibender/pem.xgb
A preliminary prototype of piece-wise exponential models via XGBoost
https://github.com/adibender/pem.xgb
Last synced: 1 day ago
JSON representation
A preliminary prototype of piece-wise exponential models via XGBoost
- Host: GitHub
- URL: https://github.com/adibender/pem.xgb
- Owner: adibender
- License: other
- Created: 2020-04-15T09:49:39.000Z (over 4 years ago)
- Default Branch: master
- Last Pushed: 2021-09-04T08:30:38.000Z (over 3 years ago)
- Last Synced: 2024-12-20T23:51:35.531Z (2 days ago)
- Language: R
- Homepage:
- Size: 43 KB
- Stars: 2
- Watchers: 2
- Forks: 1
- Open Issues: 0
-
Metadata Files:
- Readme: README.Rmd
- License: LICENSE
Awesome Lists containing this project
README
# pam.xgb
Survival analysis with xgboost
(Prototype, will have to think more about design (also **`pammtools`**) +
tuning, etc.)

## Example
```{r}
# setup
# remotes::install_github("adibender/pammtools", ref = "master")
library(pammtools) # requires github version
library(dplyr)
library(mgcv)
# load this package's development version from the current directory
devtools::load_all()

# plotting defaults used by all figures below
library(ggplot2)
theme_set(theme_bw())
```

### Data Simulation
```{r}
# number of simulated observations/subjects
n <- 500
# covariates that will affect the hazard rate:
# x1 is uniform on (-3, 3), x2 is uniform on (0, 6)
df <- as_tibble(
  cbind.data.frame(
    x1 = runif(n, -3, 3),
    x2 = runif(n, 0, 6)
  )
)
# the formula which specifies how covariates affect the hazard rate;
# f0 is its time-dependent part: a Gamma(shape = 8, rate = 2) density
# evaluated at t and scaled by 6
f0 <- function(t) {
  6 * dgamma(t, shape = 8, rate = 2)
}
# log-hazard as a function of time (via f0) and the covariates x1, x2
form <- ~ -3.5 + f0(t) - 0.5 * x1 + sqrt(x2)

# simulate survival data for the subjects in `df` according to `form`;
# the third argument is the grid 0, 0.2, ..., 10 handed to sim_pexp()
# (presumably the interval cut points -- see ?pammtools::sim_pexp)
sim_df <- sim_pexp(form, df, seq(0, 10, by = .2)) %>%
  mutate(time = round(time, 6))
head(sim_df)
# plot(survival::survfit(survival::Surv(time, status)~1, data = sim_df ))
```

### Model Fitting
```{r}
# transform "standard" survival data to PED format
ped <- sim_df %>% as_ped(formula = Surv(time, status) ~ .)

# fit PAM: Poisson GAM on the PED data with a smooth effect of `tend`,
# a linear effect of x1 and a smooth effect of x2
pam <- gam(
  ped_status ~ s(tend, k = 20) + x1 + s(x2),
  data = ped,
  family = "poisson",
  offset = offset, # NOTE(review): presumably the `offset` column of `ped`
  method = "REML"
)
summary(pam)

# fit xgboost model
xgb_params <- list(
  max_depth = 3,
  eta = .01,
  colsample_bytree = .7,
  min_child_weight = 10,
  subsample = .7
)

xgb_pam <- xgboost.ped(
  ped,
  params = xgb_params,
  nrounds = 1500,
  print_every_n = 500
)
```

### Visualization
```{r, fig.width = 6, fig.height = 3}
## prediction grid: every observed `tend`, for x1 = -0.5 and x1 = 1,
## with the hazards of both models added
## (columns `hazard_pam_xgb` and `hazard` are used in the plots below)
pred_df <- ped %>%
  make_newdata(tend = unique(ped$tend), x1 = c(-.5, 1)) %>%
  add_hazard2(xgb_pam) %>%
  add_hazard(pam)

## survival probabilities of both models, computed per value of x1
## (columns `surv_prob_pam_xgb` and `surv_prob`), followed by the
## true hazard and survival probability from the simulation formula
pred_df <- pred_df %>%
  group_by(x1) %>%
  add_surv_prob2(xgb_pam) %>%
  add_surv_prob(pam) %>%
  # NOTE(review): cumsum() is meant to accumulate within each x1 group;
  # this presumes the grouping survives the add_surv_prob*() calls
  mutate(
    truth_hazard = exp(-3.5 + f0(tend) - 0.5 * x1 + sqrt(x2)),
    truth = exp(-cumsum(truth_hazard * intlen))
  )

# Hazard: xgboost (green), PAM (black), truth (blue, dotted)
ggplot(pred_df, aes(x = tend)) +
  geom_stephazard(aes(y = hazard_pam_xgb), col = "green") +
  geom_stephazard(aes(y = hazard), col = "black") +
  geom_line(aes(y = truth_hazard), col = "blue", lty = 3) +
  facet_wrap(~x1)

# Survival probability: same colour coding, y-axis fixed to [0, 1]
ggplot(pred_df, aes(x = tend)) +
  geom_line(aes(y = surv_prob_pam_xgb), col = "green") +
  geom_line(aes(y = surv_prob), col = "black") +
  geom_line(aes(y = truth), col = "blue", lty = 3) +
  scale_y_continuous(limits = c(0, 1)) +
  facet_wrap(~x1)
```

## Benchmarking (illustration)
```{r}
library(survival)
library(obliqueRSF)
data("tumor", package = "pammtools")

## estimate models for benchmarking (on subset of data)
# xgboost
xgb_params <- list(
  max_depth = 3,
  eta = .1,
  colsample_bytree = .7,
  min_child_weight = 10,
  subsample = .7
)

# random split: 600 rows for training, all remaining rows for testing
train_idx <- sample(seq_len(nrow(tumor)), 600)
test_idx <- setdiff(seq_len(nrow(tumor)), train_idx)
## pam xgboost model
pxgb <- xgboost.ped(
  data = as_ped(tumor[train_idx, ], Surv(days, status) ~ .),
  params = xgb_params,
  nrounds = 400,
  print_every_n = 500L
)
## cox ph
# NOTE(review): x = TRUE keeps the design matrix in the fit; presumably
# required so that pec can compute predictions from it -- confirm in ?pec
mcox <- coxph(Surv(days, status) ~ ., data = tumor[train_idx, ], x = TRUE)
## oblique RSF
orsf <- ORSF(tumor[train_idx, ], time = "days", verbose = FALSE)
```

```{r, fig.width = 5, fig.height = 5}
## Evaluation
library(pec)
# prediction error curves (Brier Score) of the three models on the
# held-out test rows, evaluated at days 1, 11, ..., 2491
pec1 <- pec(
  object = list(pam_xgb = pxgb, coxph = mcox, orsf = orsf),
  formula = Surv(days, status) ~ 1,
  data = tumor[test_idx, ],
  times = seq(1, 2500, by = 10),
  start = 1,
  exact = FALSE
)
plot(pec1)
## Integrated Brier Score at 5 time points
crps(pec1, times = c(500, 1000, 1500, 2000, 2500))
```