Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/mayer79/marginalplot
Beautiful marginal plots for modeling
- Host: GitHub
- URL: https://github.com/mayer79/marginalplot
- Owner: mayer79
- License: gpl-2.0
- Created: 2024-09-21T13:22:04.000Z (4 months ago)
- Default Branch: main
- Last Pushed: 2024-10-21T18:31:49.000Z (3 months ago)
- Last Synced: 2024-10-22T08:44:16.425Z (3 months ago)
- Topics: machine-learning, r, xai
- Language: R
- Homepage: https://mayer79.github.io/marginalplot/
- Size: 1.31 MB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- Changelog: NEWS.md
- License: LICENSE.md
# marginalplot
[![R-CMD-check](https://github.com/mayer79/marginalplot/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/mayer79/marginalplot/actions/workflows/R-CMD-check.yaml)
[![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://www.tidyverse.org/lifecycle/#experimental)
[![Codecov test coverage](https://codecov.io/gh/mayer79/marginalplot/graph/badge.svg)](https://app.codecov.io/gh/mayer79/marginalplot)

**{marginalplot}** provides high-quality plots for modeling.

Per feature and feature value, the main function `marginal()` calculates
- average observed values of the model response,
- average predicted values,
- partial dependence, and
- the exposure.

The workflow is as follows:

1. Crunch values via `marginal()` or the convenience wrappers `average_observed()` and `partial_dependence()`.
2. Post-process the results with `postprocess()`, e.g., to collapse rare levels of a categorical feature.
3. Plot the results with `plot()`.

**Notes**

- You can switch between {ggplot2}/{patchwork} plots and interactive {plotly} plots.
- The implementation is optimized for speed and convenience.
- Most models (including DALEX explainers and meta-learners such as Tidymodels) work out of the box. If not, a tailored prediction function can be specified.
- For multioutput models, the last output is picked.
- Case weights are supported via the argument `w`.
- Numeric features are binned using the same options as `graphics::hist()`. Additionally, very small and very large values are winsorized (clipped) by default.

## Installation

You can install the development version of {marginalplot} from [GitHub](https://github.com/) with:
``` r
# install.packages("pak")
pak::pak("mayer79/marginalplot")
```

## Usage

``` r
library(marginalplot)
library(ranger)

set.seed(1)

fit <- ranger(Sepal.Length ~ ., data = iris)
xvars <- c("Sepal.Width", "Petal.Width", "Petal.Length", "Species")

marginal(fit, v = xvars, data = iris, y = "Sepal.Length", breaks = "Scott") |>
  plot(num_points = TRUE)
```

![](man/figures/marginal1.svg)

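To make the four statistics reported by `marginal()` more concrete, here is a minimal base-R sketch that computes them by hand for one numeric feature. It is not the package's implementation: the linear model is only a stand-in for any fitted model, and the choice of bin midpoints as the partial dependence grid is an assumption made for illustration.

``` r
# Base-R sketch (not the package's code): average observed, average
# predicted, partial dependence, and exposure for one numeric feature.
fit <- lm(Sepal.Length ~ ., data = iris)  # stand-in model
v <- "Petal.Width"

# Bin the feature with hist()-style breaks
breaks <- hist(iris[[v]], breaks = "Scott", plot = FALSE)$breaks
bins <- cut(iris[[v]], breaks = breaks, include.lowest = TRUE)

# Exposure: number of observations per bin (unweighted case)
exposure <- as.vector(table(bins))

# Average observed response and average prediction per bin
# (empty bins yield NA)
avg_obs <- as.vector(tapply(iris$Sepal.Length, bins, mean))
avg_pred <- as.vector(tapply(predict(fit, iris), bins, mean))

# Partial dependence: set the feature to one grid value for all rows,
# predict, and average. Assumption: the grid is the bin midpoints.
mids <- (breaks[-length(breaks)] + breaks[-1]) / 2
pd <- vapply(mids, function(g) {
  X <- iris
  X[[v]] <- g
  mean(predict(fit, X))
}, numeric(1))

result <- data.frame(bin_mid = mids, exposure, avg_obs, avg_pred, pd)
print(result)
```

Comparing `avg_obs` with `avg_pred` per bin shows where the model is biased, while `pd` isolates the effect of the feature with all other features held at their observed values.
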
## More examples
### Partial dependence only
The function `partial_dependence()` produces high-quality plots to study main effects. To visually see how important each feature is (regarding main effect strength), we activate the option `share_y` and sort the plots by decreasing variance of the partial dependence function (exposure weighted).
``` r
library(marginalplot)
library(ranger)

set.seed(1)

fit <- ranger(Sepal.Length ~ ., data = iris)
xvars <- colnames(iris)[-1]

partial_dependence(fit, v = xvars, data = iris, breaks = 17) |>
  plot(sort = TRUE, share_y = TRUE, scale_exposure = 0.2)
```

![](man/figures/pd.svg)

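The sorting criterion mentioned above, the exposure-weighted variance of the partial dependence values, can be sketched in base R as follows. The helper `weighted_var()` and the toy tables are hypothetical and only illustrate the idea; the exact formula the package uses (for example, its normalization) is not confirmed here.

``` r
# Exposure-weighted variance of a vector of values.
# Hypothetical helper, not part of {marginalplot}.
weighted_var <- function(x, w) {
  ok <- !is.na(x) & w > 0
  x <- x[ok]
  w <- w[ok]
  m <- sum(w * x) / sum(w)        # exposure-weighted mean
  sum(w * (x - m)^2) / sum(w)     # exposure-weighted variance
}

# Toy partial dependence tables for two made-up features
pd_list <- list(
  feature_a = data.frame(pd = c(5.0, 5.1, 5.2), exposure = c(50, 50, 50)),
  feature_b = data.frame(pd = c(4.0, 6.0, 8.0), exposure = c(100, 40, 10))
)

importance <- vapply(
  pd_list,
  function(d) weighted_var(d$pd, d$exposure),
  numeric(1)
)

# Features with stronger main effects come first
sorted <- sort(importance, decreasing = TRUE)
print(sorted)
```

Weighting by exposure keeps sparsely populated bins, where the partial dependence curve is least reliable, from dominating the ranking.
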
### Before modeling
Before modeling, you might be interested in
- univariate distributions of potential features, and
- how the average response is associated with their values.

This information is provided by `average_observed()`.

Note: Sorting is done by decreasing variance of average observed values (exposure weighted).
``` r
library(marginalplot)

xvars <- colnames(iris)[-1]

average_observed(xvars, data = iris, y = "Sepal.Length", breaks = 5) |>
  plot(sort = TRUE, share_y = TRUE, rotate_x = 45)
```

![](man/figures/avg_obs.svg)
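
The `breaks` argument used in the examples follows `hist()` conventions, and extreme values are clipped before binning. The base-R sketch below shows how such a step might look for one numeric feature. The clipping rule (1st and 99th percentiles) and the helper `winsorize()` are assumptions chosen for illustration; the package's default rule may differ.

``` r
# Clip values outside the given quantiles.
# Hypothetical helper; the quantile rule is an illustrative assumption.
winsorize <- function(x, probs = c(0.01, 0.99)) {
  q <- quantile(x, probs = probs, na.rm = TRUE, names = FALSE)
  pmin(pmax(x, q[1]), q[2])
}

x <- winsorize(iris$Petal.Length)

# hist() accepts a number of bins, explicit break points, or a rule name
breaks <- hist(x, breaks = 5, plot = FALSE)$breaks
bins <- cut(x, breaks = breaks, include.lowest = TRUE)

# Average observed response and exposure per bin (empty bins yield NA)
binned <- data.frame(
  bin = levels(bins),
  avg_obs = as.vector(tapply(iris$Sepal.Length, bins, mean)),
  exposure = as.vector(table(bins))
)
print(binned)
```

Clipping first keeps a handful of outliers from stretching the break points, so most bins retain a useful amount of exposure.
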