Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/mayer79/partialplot
Partial dependency plots in R for xgboost, lightGBM and ranger objects
- Host: GitHub
- URL: https://github.com/mayer79/partialplot
- Owner: mayer79
- Created: 2017-08-19T12:16:49.000Z (over 7 years ago)
- Default Branch: master
- Last Pushed: 2023-09-26T16:30:55.000Z (over 1 year ago)
- Last Synced: 2024-10-04T12:56:52.147Z (3 months ago)
- Language: R
- Size: 250 KB
- Stars: 8
- Watchers: 4
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# partialPlot
Partial dependency plots for R objects of type XGBoost, LightGBM and ranger

## Idea
The R function `partialPlot` visualizes the partial dependence of the response on a covariate. It is inspired by the function of the same name in the `randomForest` package and works as long as `predict` returns numeric values (no classes!). The main arguments of `partialPlot` are as follows (see the usage sketch after the list):
1. `obj`: model object of type `lgb.Booster`, `xgb.Booster` or `ranger`
2. `pred.data`: Matrix to be used in prediction (no special objects like `xgb.DMatrix` or `lgb.Dataset`)
3. `xname`: Name of the column in `pred.data` for which the dependency plot is calculated
4. `n.pt`: Evaluation grid size (used only if `x` is not discrete). Quantile cuts are used.
5. `discrete.x`: If TRUE, the evaluation grid is set to the unique values of `x`
6. `subsample`: Fraction of rows in `pred.data` used for prediction
7. `which.class`: Which class to show if the objective is "multi:softprob" (a value from 0 to num_class - 1)
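Here is a minimal usage sketch of these arguments (the data, the model `fit`, and the column names are illustrative only, not part of the package):

```
library(ranger)
source("R/partialPlot.R")

# Toy data: a numeric matrix with named columns
X <- matrix(rnorm(500), ncol = 5,
            dimnames = list(NULL, paste0("x", 1:5)))
y <- X[, "x1"]^2 + rnorm(100)
fit <- ranger(x = X, y = y, num.trees = 100)

# Partial dependence of the prediction on "x1", evaluated on a
# 25-point quantile grid using half of the rows
partialPlot(fit, pred.data = X, xname = "x1", n.pt = 25, subsample = 0.5)
```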
## The function

Check R/partialPlot.R for parameters etc.

## Examples
### Example 1: Regression (realistic example based on the diamonds data set)

```
library(ggplot2) # for data set "diamonds"
library(xgboost)
source("R/partialPlot.R") # or your path#======================================================================
# Data prep
#======================================================================diamonds <- transform(as.data.frame(diamonds),
log_price = log(price),
log_carat = log(carat),
cut = as.numeric(cut),
color = as.numeric(color),
clarity = as.numeric(clarity))# Train/test split
set.seed(3928272)
.in <- sample(c(FALSE, TRUE), nrow(diamonds), replace = TRUE, prob = c(0.15, 0.85))

x <- c("log_carat", "cut", "color", "clarity", "depth", "table")
train <- list(y = diamonds$log_price[.in],
              X = as.matrix(diamonds[.in, x]))
test <- list(y = diamonds$log_price[!.in],
             X = as.matrix(diamonds[!.in, x]))

#======================================================================
# Small functions
#======================================================================

# Calculate R squared
r2 <- function(y, pred) {
  1 - var(y - pred) / var(y)
}

# Show all partial dependency plots
partialDiamondsPlot <- function(fit) {
  par(mfrow = c(3, 2),
      oma = c(0, 0, 0, 0) + 0.3,
      mar = c(4, 2, 0, 0) + 0.1,
      mgp = c(2, 0.5, 0.5))
  partialPlot(fit, train$X, xname = "log_carat")
  partialPlot(fit, train$X, xname = "cut", discrete.x = TRUE)
  partialPlot(fit, train$X, xname = "color", discrete.x = TRUE)
  partialPlot(fit, train$X, xname = "clarity", discrete.x = TRUE)
  partialPlot(fit, train$X, xname = "depth")
  partialPlot(fit, train$X, xname = "table")
}

#======================================================================
# xgboost regression
#======================================================================

dtrain <- xgb.DMatrix(train$X, label = train$y)
dtest <- xgb.DMatrix(test$X, label = test$y)
watchlist <- list(train = dtrain, test = dtest)

param <- list(max_depth = 8,
              learning_rate = 0.01,
              nthread = 2,
              lambda = 0.2,
              objective = "reg:linear",  # renamed "reg:squarederror" in recent xgboost versions
              eval_metric = "rmse",
              subsample = 0.7)

fit_xgb <- xgb.train(param, dtrain, watchlist = watchlist,
                     nrounds = 850, early_stopping_rounds = 5)
r2(train$y, predict(fit_xgb, train$X)) # 0.9927861
r2(test$y, predict(fit_xgb, test$X)) # 0.9912827

partialDiamondsPlot(fit_xgb)
```

![Diamonds plot](/pics/diamonds.jpeg)
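The same call pattern should carry over to the other supported backends. Here is a hedged sketch for LightGBM and ranger on the same training data (parameter values are illustrative, not tuned):

```
library(lightgbm)
library(ranger)

# LightGBM: train on an lgb.Dataset, but pass the plain matrix to partialPlot
dtrain_lgb <- lgb.Dataset(train$X, label = train$y)
fit_lgb <- lgb.train(params = list(objective = "regression", learning_rate = 0.05),
                     data = dtrain_lgb, nrounds = 200)
partialDiamondsPlot(fit_lgb)

# ranger: the x/y interface keeps the column names intact
fit_rf <- ranger(x = train$X, y = train$y, num.trees = 300, seed = 1)
partialDiamondsPlot(fit_rf)
```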
### Example 2: Multiclass prediction (toy example based on the iris data set)
```
train <- list(y = as.numeric(iris[, 5]),
              X = as.matrix(iris[, 1:4]))

# xgboost expects 0-based class labels
dtrain <- xgb.DMatrix(train$X, label = train$y - 1)

param <- list(max_depth = 2, learning_rate = 0.1, objective = "multi:softprob",
              num_class = 3, eval_metric = "merror")

fit_xgb <- xgb.train(param, dtrain, nrounds = 100)
par(mfrow = c(2, 2))
for (nam in colnames(train$X)) {
partialPlot(fit_xgb, train$X, xname = nam, xlab = "", which.class = 0)
}
```

The effects on species "setosa" (the first class, corresponding to level 0) are as follows:
![iris plot](/pics/iris.jpeg)
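To inspect the remaining classes, the same loop can be repeated with a different `which.class` (a sketch; the 0-based indices follow the convention from the argument list above):

```
# Class 1 = versicolor, class 2 = virginica (alphabetical factor levels)
par(mfrow = c(2, 2))
for (nam in colnames(train$X)) {
  partialPlot(fit_xgb, train$X, xname = nam, xlab = "", which.class = 1)
}
```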