counterfactuals: An R package for Counterfactual Explanation Methods
https://github.com/dandls/counterfactuals


---
output: github_document
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.path = "man/figures/README-",
  fig.height = 3,
  fig.width = 8
)
options(width = 200)
set.seed(12345)
```

# counterfactuals

[![R-CMD-check](https://github.com/dandls/counterfactuals/workflows/R-CMD-check/badge.svg)](https://github.com/dandls/counterfactuals/actions)

The `counterfactuals` package provides various (model-agnostic) counterfactual explanation methods via a unified R6-based interface.

Counterfactual explanation methods address questions of the form:
"For input $\mathbf{x^{\star}}$, the model predicted $y$. What needs to be changed in $\mathbf{x^{\star}}$ for the model
to predict a desired outcome $\tilde{y}$ instead?" \
A denied loan application is a common example; here, a counterfactual explanation (or counterfactual for short) could be:
"The loan was denied because the amount of €30k is too high
given the income. If the amount had been €20k, the loan would have been granted."
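
The loan example can be made concrete with a toy sketch in base R. The decision rule `approve()` and the income of €40k are assumptions invented for illustration (they are not part of the package or of any real credit model); the sketch only shows what "the smallest change that flips the prediction" means.

```r
# Hypothetical decision rule (an assumption for illustration): a loan is
# approved if the requested amount is at most half of the annual income.
approve = function(amount, income) amount <= 0.5 * income

# Hypothetical applicant: requests 30k with an income of 40k
applicant = list(amount = 30000, income = 40000)
approve(applicant$amount, applicant$income)

# Lower the amount in steps of 1000 and keep the first (i.e., largest)
# amount for which the rule flips to "approved".
candidates = seq(applicant$amount, 0, by = -1000)
approved = candidates[approve(candidates, applicant$income)]
counterfactual_amount = approved[1]
counterfactual_amount
```

Under these assumptions, the counterfactual changes only the amount and leaves the income untouched, which mirrors the statement "If the amount had been €20k, the loan would have been granted."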

For an introduction to counterfactual explanation methods, we recommend Chapter 9.3 of the [Interpretable Machine Learning book](https://christophm.github.io/interpretable-ml-book/) by Christoph Molnar. The package is based on the R code underlying the paper [Multi-Objective Counterfactual Explanations (MOC)](https://link.springer.com/chapter/10.1007/978-3-030-58112-1_31).

## Available methods

The following counterfactual explanation methods are currently implemented:

- [Multi-Objective Counterfactual Explanations (MOC)](https://link.springer.com/chapter/10.1007/978-3-030-58112-1_31)
- [Nearest Instance Counterfactual Explanations (NICE)](https://arxiv.org/abs/2104.07411) (an extended version)
- [WhatIf](https://arxiv.org/abs/1907.04135) (an extended version)

## Installation

You can install the development version from [GitHub](https://github.com/) with:

``` r
# install.packages("devtools")
devtools::install_github("dandls/counterfactuals")
```

## Get started

In this example, we train a `randomForest` on the `iris` dataset and examine how a given `virginica` observation
would have to change in order to be classified as `versicolor`.

```{r example, message=FALSE}
library(counterfactuals)
library(randomForest)
library(iml)
```

### Fitting a model

First, we train a `randomForest` model to predict the target variable `Species`. We omit one observation from the training
data; this held-out observation is `x_interest` (the observation $\mathbf{x^{\star}}$ for which we want to find counterfactuals).

```{r}
rf = randomForest(Species ~ ., data = iris[-150L, ])
```

### Setting up an `iml::Predictor` object

We then create an [`iml::Predictor`](https://christophm.github.io/iml/reference/Predictor.html) object, which serves as
a wrapper for different model types; it contains the model and the data for its analysis.

```{r}
predictor = Predictor$new(rf, type = "prob")
```
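
As a variation, the training data and the target column can also be passed to the `Predictor` explicitly through the `data` and `y` arguments of `Predictor$new()`, rather than leaving them to be inferred from the fitted model. The sketch below is self-contained and refits the model; the object name `predictor_explicit` is chosen here for illustration only.

```r
library(randomForest)
library(iml)

set.seed(12345)
rf = randomForest(Species ~ ., data = iris[-150L, ])

# Pass the training data and the name of the target column explicitly
predictor_explicit = Predictor$new(
  rf, data = iris[-150L, ], y = "Species", type = "prob"
)

# Predicted class probabilities for the held-out observation
p = predictor_explicit$predict(iris[150L, ])
p
```

Being explicit about `data` and `y` can be helpful for model types from which the training data cannot be recovered.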

### Finding counterfactuals

For `x_interest`, the model predicts a probability of 8% for class `versicolor`.

```{r}
x_interest = iris[150L, ]
predictor$predict(x_interest)
```

Now, we examine what needs to be changed in `x_interest` so that the model predicts a probability of at least 50% for class `versicolor`.

Here, we apply WhatIf. Since this is a classification task, we create a `WhatIfClassif` object.

```{r}
wi_classif = WhatIfClassif$new(predictor, n_counterfactuals = 5L)
```

Then, we use the `find_counterfactuals()` method to find counterfactuals for `x_interest`.

```{r}
cfactuals = wi_classif$find_counterfactuals(
  x_interest, desired_class = "versicolor", desired_prob = c(0.5, 1)
)
```
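
To build intuition for this step, the following base-R sketch mimics the basic WhatIf idea: among the observed data points whose prediction lies in the desired range, return those closest to `x_interest`. It is a simplified illustration, not the package's implementation. The nearest-centroid classifier `predict_prob()` is a hypothetical stand-in for the random forest, and `gower_dist()` is a hand-written range-scaled distance for numeric features; both names are assumptions introduced here.

```r
train = iris[-150L, ]
x_interest = iris[150L, ]
features = setdiff(names(train), "Species")

# Hypothetical stand-in model: class probabilities derived from squared
# Euclidean distances to the class centroids (not the random forest).
centroids = aggregate(train[features], by = list(Species = train$Species), FUN = mean)
predict_prob = function(newdata) {
  X = as.matrix(newdata[features])
  C = as.matrix(centroids[features])
  d2 = sapply(seq_len(nrow(C)), function(k) rowSums(sweep(X, 2, C[k, ])^2))
  d2 = matrix(d2, nrow = nrow(X))
  p = exp(-d2)
  p = p / rowSums(p)
  colnames(p) = as.character(centroids$Species)
  p
}

# Range-scaled mean absolute difference (Gower distance for numeric features)
ranges = sapply(train[features], function(v) diff(range(v)))
gower_dist = function(data, x) {
  diffs = abs(sweep(as.matrix(data[features]), 2, unlist(x[features])))
  rowMeans(sweep(diffs, 2, ranges, FUN = "/"))
}

# Keep observations predicted as versicolor with probability >= 0.5,
# then return the five that are closest to x_interest.
candidates = train[predict_prob(train)[, "versicolor"] >= 0.5, ]
d = gower_dist(candidates, x_interest)
whatif_sketch = candidates[head(order(d), 5L), ]
whatif_sketch
```

Because the returned rows are actual observations, every candidate in this sketch is a realistic data point by construction; the trade-off is that it may differ from `x_interest` in more features than strictly necessary.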

### The counterfactuals object

`cfactuals` is a `Counterfactuals` object that contains the counterfactuals and has several methods for their
evaluation and visualization.

```{r}
cfactuals
```

The counterfactuals are stored in the `data` field.

```{r}
cfactuals$data
```

With the `evaluate()` method, we can evaluate the counterfactuals using various quality measures.

```{r}
cfactuals$evaluate()
```
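
Two typical quality measures for counterfactuals are sparsity (how many features were changed) and proximity (how far the counterfactual is from `x_interest`). The base-R sketch below computes both by hand. The counterfactuals `cfs` are three arbitrarily chosen `versicolor` rows of `iris` used as stand-ins, and whether `evaluate()` reports exactly these quantities is an assumption not confirmed here.

```r
x_interest = iris[150L, ]
features = setdiff(names(iris), "Species")

# Stand-in counterfactuals (assumption): three arbitrary versicolor rows
cfs = iris[c(71L, 78L, 84L), features]

x_vec = unlist(x_interest[features])
cf_mat = as.matrix(cfs)

# Sparsity: number of features whose value differs from x_interest
no_changed = rowSums(sweep(cf_mat, 2, x_vec, FUN = "!="))

# Proximity: mean absolute difference, each feature scaled by its range
ranges = sapply(iris[features], function(v) diff(range(v)))
abs_diff = abs(sweep(cf_mat, 2, x_vec))
dist_x_interest = rowMeans(sweep(abs_diff, 2, ranges, FUN = "/"))

data.frame(cfs, no_changed, dist_x_interest)
```

Lower values are preferable for both measures: a counterfactual that changes few features by small amounts is usually easier to act on.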

One visualization option is to plot the frequency of feature changes across all counterfactuals using the
`plot_freq_of_feature_changes()` method.

```{r, fig.height=2}
cfactuals$plot_freq_of_feature_changes()
```
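
The quantity behind such a plot can be sketched in base R as the share of counterfactuals in which each feature differs from `x_interest`. As before, `cfs` holds arbitrarily chosen stand-in rows of `iris` (an assumption for illustration), so the numbers do not correspond to the output above.

```r
x_interest = iris[150L, ]
features = setdiff(names(iris), "Species")

# Stand-in counterfactuals (assumption): three arbitrary versicolor rows
cfs = iris[c(71L, 78L, 84L), features]

# Logical matrix: TRUE where a counterfactual's value differs from x_interest
changed = sweep(as.matrix(cfs), 2, unlist(x_interest[features]), FUN = "!=")

# Relative frequency of changes per feature, most frequently changed first
freq_changes = sort(colMeans(changed), decreasing = TRUE)
freq_changes
```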

Another visualization option is a parallel plot, created with the `plot_parallel()` method, which connects the (scaled)
feature values of each counterfactual and highlights `x_interest` in blue.

```{r, fig.height=2.5, message=FALSE}
cfactuals$plot_parallel()
```