https://github.com/massimoaria/e2tree
Explainable Ensemble Trees
- Host: GitHub
- URL: https://github.com/massimoaria/e2tree
- Owner: massimoaria
- License: other
- Created: 2022-10-13T16:14:46.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2026-03-21T14:32:37.000Z (20 days ago)
- Last Synced: 2026-03-22T03:53:22.866Z (19 days ago)
- Topics: explainable-machine-learning
- Language: R
- Homepage:
- Size: 9.52 MB
- Stars: 8
- Watchers: 4
- Forks: 3
- Open Issues: 0
Metadata Files:
- Readme: README.Rmd
- Changelog: NEWS.md
- License: LICENSE
---
output: github_document
---
# Explainable Ensemble Trees (e2tree)
[![R-CMD-check](https://github.com/massimoaria/e2tree/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/massimoaria/e2tree/actions/workflows/R-CMD-check.yaml)
[![CRAN status](https://www.r-pkg.org/badges/version/e2tree)](https://CRAN.R-project.org/package=e2tree) `r badger::badge_cran_download("e2tree", "grand-total")`
The key idea behind **Explainable Ensemble Trees** (**e2tree**) is an algorithm that represents any ensemble method based on decision trees with a single tree-like structure. The goal is to explain the results of the ensemble algorithm while preserving its level of accuracy, which typically outperforms that of a single decision tree. The method identifies a tree-like structure whose classification or regression paths summarize the whole ensemble process. e2tree offers two main advantages:

- it builds an explainable tree that preserves the predictive performance of the Random Forest (RF) model;
- it gives the decision-maker an intuitive tree-like structure to work with.

In the examples below we focus on Random Forest, but the algorithm generalizes to every ensemble approach based on decision trees.
```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "man/figures/README-",
out.width = "100%",
dpi = 300
)
```
## Setup
You can install the **developer version** of e2tree from [GitHub](https://github.com) with:
```{r eval=FALSE}
install.packages("remotes")
remotes::install_github("massimoaria/e2tree")
```
You can install the **released version** of e2tree from [CRAN](https://CRAN.R-project.org) with:
```{r eval=FALSE}
if (!require("e2tree", quietly=TRUE)) install.packages("e2tree")
```
```{r warning=FALSE, message=FALSE}
require(e2tree)
require(randomForest)
require(ranger)
require(dplyr)
require(ggplot2)
if (!(require(rsample, quietly=TRUE))){install.packages("rsample"); require(rsample, quietly=TRUE)}
options(dplyr.summarise.inform = FALSE)
```
```{r set-theme, include=FALSE}
theme_set(
theme_classic() +
theme(
plot.background = element_rect(fill = "transparent", colour = NA),
panel.background = element_rect(fill = "transparent", colour = NA)
)
)
knitr::opts_chunk$set(dev.args = list(bg = "transparent"))
```
## S3 Classes and Methods
The **e2tree** package uses a proper S3 class system. The main classes and their associated methods are:
| Class | Methods |
|-------|---------|
| `e2tree` | `print`, `summary`, `plot`, `predict`, `fitted`, `residuals`, `as.rpart`, `nodes`, `e2splits` |
| `eValidation` | `print`, `summary`, `plot`, `measures`, `proximity` |
| `loi` | `print`, `summary`, `plot` |
| `loi_perm` | `print`, `summary`, `plot` |
E2Tree objects can also be converted to other formats for interoperability:
- `as.rpart()` converts to `rpart` format for use with `rpart.plot`
- `as.party()` converts to `partykit`'s `constparty` format (if partykit is installed)
## Example 1: IRIS dataset (Classification)
Starting from the IRIS dataset, we train a Random Forest using the **ranger** package (the **randomForest** package can be used instead) and then use e2tree to obtain an explainable tree that synthesizes the ensemble classifier.
```{r}
# Set random seed to make results reproducible:
set.seed(0)
# Initialize the split
iris_split <- iris %>% initial_split(prop = 0.6)
iris_split
# Assign the data to the correct sets
training <- iris_split %>% training()
validation <- iris_split %>% testing()
response_training <- training[,5]
response_validation <- validation[,5]
```
Train a Random Forest model with 1000 weak learners
```{r}
# Perform training with "ranger" or "randomForest" package:
## RF with "ranger" package
ensemble <- ranger(Species ~ ., data = training, num.trees = 1000, importance = 'impurity')
## RF with "randomForest" package
#ensemble = randomForest(Species ~ ., data = training, importance = TRUE, proximity = TRUE)
```
Create the dissimilarity matrix between observations:
```{r}
D <- createDisMatrix(ensemble, data = training, label = "Species",
                     parallel = list(active = FALSE, no_cores = NULL))
```
Build an explainable tree for RF:
```{r}
setting <- list(impTotal = 0.1, maxDec = 0.01, n = 2, level = 5)
tree <- e2tree(Species ~ ., data = training, D, ensemble, setting)
```
### S3 methods for e2tree objects
The `e2tree` class supports standard S3 methods for inspecting the fitted model:
**Print** --- compact model overview:
```{r}
print(tree)
```
**Summary** --- full model details including terminal nodes and decision rules:
```{r}
summary(tree)
```
**Plot** --- tree visualization via `rpart.plot`:
```{r}
plot(tree, ensemble)
```
### Accessor functions
Accessor functions provide a clean interface to extract components without exposing the internal structure:
```{r}
# Extract terminal nodes
nodes(tree, terminal = TRUE)
# Extract split information
str(e2splits(tree), max.level = 1)
```
### Coercion to other formats
E2Tree objects can be converted to standard tree formats for use with other packages:
```{r}
# Convert to rpart format
rpart_obj <- as.rpart(tree, ensemble)
# Convert to partykit format (if installed)
if (requireNamespace("partykit", quietly = TRUE)) {
party_obj <- partykit::as.party(tree)
plot(party_obj)
}
```
### Prediction
Use the standard `predict()` method for prediction on new data:
```{r}
# Predict on validation set
pred <- predict(tree, newdata = validation, target = "virginica")
head(pred)
```
Comparison of predictions (training sample) of RF and e2tree
```{r}
# Training predictions
pred_train <- predict(tree, newdata = training, target = "virginica")
# "ranger" package
table(pred_train$fit, ensemble$predictions)
# "randomForest" package
#table(pred_train$fit, ensemble$predicted)
```
Comparison of predictions (training sample) of RF and correct response
```{r}
# "ranger" package
table(ensemble$predictions, response_training)
## "randomForest" package
#table(ensemble$predicted, response_training)
```
Comparison of predictions (training sample) of e2tree and correct response
```{r}
table(pred_train$fit, response_training)
```
Fitted values for the training data:
```{r}
head(fitted(tree))
```
### Variable importance
Variable importance is automatically detected as classification or regression:
```{r}
V <- vimp(tree, training)
V$vimp
V$g_imp
V$g_acc
```
### Prediction on validation sample
```{r}
ensemble.pred <- predict(ensemble, validation[,-5])
pred_val <- predict(tree, newdata = validation, target = "virginica")
```
Comparison of predictions (validation sample) of RF and e2tree
```{r}
## "ranger" package
table(pred_val$fit, ensemble.pred$predictions)
## "randomForest" package
#table(pred_val$fit, ensemble.pred$predicted)
```
Comparison of predictions (validation sample) of e2tree and correct response
```{r}
table(pred_val$fit, response_validation)
roc_res <- roc(response_validation, pred_val$score, target="virginica")
roc_res$auc
```
## Validation of the E2Tree Structure
A critical question when using E2Tree is: *how well does the single tree capture the structure of the original ensemble?*
Assessing the fidelity of this reconstruction requires measuring **agreement** between the ensemble and E2Tree proximity matrices --- a fundamentally different question from measuring their **association**. The distinction parallels the classical one between *correlation* and *concordance* in method comparison studies (Bland & Altman, 1986; Lin, 1989): two proximity matrices can be perfectly correlated yet systematically disagree in their actual values. The Mantel test, being scale-invariant, would declare perfect association in such a case. But for E2Tree validation, we need to know whether the *actual proximity values* are faithfully reproduced.
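The distinction can be seen in a few lines of base R (a standalone illustration, not part of the e2tree API): two proximity vectors that are perfectly correlated can still disagree sharply in their actual values.

```r
# Two hypothetical proximity vectors: perfectly associated, but the second is
# systematically half the magnitude of the first.
set.seed(1)
p_ensemble <- runif(10)
p_tree <- 0.5 * p_ensemble

cor(p_ensemble, p_tree)        # association (the Mantel view): exactly 1
mean(abs(p_ensemble - p_tree)) # agreement: a clearly non-zero discrepancy
```

A scale-invariant statistic cannot distinguish `p_tree` from `p_ensemble`, which is why agreement measures are needed alongside the Mantel test.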
The `eValidation()` function supports two approaches via the `test` argument:
- `test = "mantel"`: The classical Mantel test for *association*
- `test = "measures"`: A family of divergence/similarity measures for *agreement*
- `test = "both"` (default): Both approaches
### Divergence and similarity measures
| Measure | Type | Range | What it measures |
|---------|------|-------|------------------|
| **nLoI** | divergence | [0, 1] | Normalized Loss of Interpretability --- weighted divergence with diagnostic decomposition |
| **Hellinger** | divergence | [0, 1] | Hellinger distance --- robust to sparse matrices |
| **wRMSE** | divergence | [0, 1] | Weighted RMSE --- emphasizes high-proximity regions |
| **RV** | similarity | [0, 1] | RV coefficient --- global structural similarity (scale-invariant) |
| **SSIM** | similarity | [-1, 1] | Structural Similarity Index --- captures local block patterns |
All measures are tested simultaneously using a **unified row/column permutation test**.
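As a rough sketch of how such a test works (base R; `eValidation()` implements its own internal version, and `perm_test` is a hypothetical helper, not an exported function): the rows and columns of one proximity matrix are permuted jointly, the statistic is recomputed, and the observed value is compared against the resulting null distribution.

```r
# Minimal row/column permutation test for a similarity statistic between two
# square proximity matrices. For divergence measures the comparison in the
# p-value would be reversed (null <= observed).
perm_test <- function(O, O_hat, stat = function(a, b) cor(c(a), c(b)),
                      n_perm = 999, seed = 42) {
  set.seed(seed)
  obs <- stat(O, O_hat)
  null <- replicate(n_perm, {
    idx <- sample(nrow(O_hat))      # permute rows and columns jointly,
    stat(O, O_hat[idx, idx])        # preserving the matrix's internal structure
  })
  p <- (sum(null >= obs) + 1) / (n_perm + 1)
  list(observed = obs, p_value = p)
}
```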
### Running the validation
```{r}
val <- eValidation(training, tree, D, test = "both", graph = FALSE, n_perm = 999, seed = 42)
```
**Print** --- compact results with Mantel test and all measures:
```{r}
print(val)
```
**Summary** --- includes the LoI diagnostic decomposition:
```{r}
summary(val)
```
**Plot** --- heatmaps, null distribution, and LoI decomposition:
```{r fig.width=10, fig.height=8}
plot(val)
```
### Extracting results with accessors
Use accessor functions instead of direct `$` access:
```{r}
# Validation measures table
measures(val)
# Proximity matrices
prox <- proximity(val, type = "both")
str(prox, max.level = 1)
```
### The nLoI Decomposition
The nLoI is unique among the measures because it decomposes into two interpretable components:
- **LoI_in** (within-node): measures how well the E2Tree reproduces the ensemble's proximity values for pairs it groups *together*.
- **LoI_out** (between-node): measures the ensemble proximity lost for pairs that E2Tree *separates* into different nodes.
Since the number of within-node and between-node pairs can differ dramatically, the `loi()` function reports **per-pair averages** (`mean_in` and `mean_out`) that enable meaningful comparison:
```{r}
O <- proximity(val, type = "ensemble")
O_hat <- proximity(val, type = "e2tree")
result <- loi(O, O_hat)
summary(result)
```
The per-pair averages provide actionable diagnostics:
- **mean_out > 0.3**: the tree is splitting apart pairs with substantial ensemble proximity --- consider more terminal nodes
- **mean_out < 0.1**: the partition correctly separates low-proximity pairs --- tree structure is well-placed
- **mean_in > 0.1**: within-node calibration error is high --- check proximity estimation
- **mean_in < 0.01**: excellent within-node match between E2Tree and ensemble
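As an illustration only, the per-pair averages could be computed along these lines (the exact definitions used by `loi()` are given in the package documentation; here we assume `mean_in` averages the absolute proximity error over within-node pairs and `mean_out` averages the ensemble proximity over between-node pairs, and both `pair_averages` and the `node` membership vector are hypothetical names):

```r
# O, O_hat: ensemble and E2Tree proximity matrices; node: terminal-node id
# of each observation. Each unordered pair is counted once via upper.tri().
pair_averages <- function(O, O_hat, node) {
  same <- outer(node, node, "==")   # TRUE for pairs grouped in the same node
  up   <- upper.tri(O)
  list(
    mean_in  = mean(abs(O - O_hat)[up & same]),  # within-node calibration error
    mean_out = mean(O[up & !same])               # proximity lost by separation
  )
}
```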
### Standalone LoI permutation test
For a quick significance assessment:
```{r}
perm <- loi_perm(O, O_hat, n_perm = 999, seed = 42)
print(perm)
```
```{r fig.width=10, fig.height=5}
plot(perm)
```