https://github.com/mlverse/tabnet
An R implementation of TabNet
- Host: GitHub
- URL: https://github.com/mlverse/tabnet
- Owner: mlverse
- License: other
- Created: 2020-10-16T19:13:14.000Z (over 5 years ago)
- Default Branch: main
- Last Pushed: 2025-04-12T16:35:51.000Z (about 1 year ago)
- Last Synced: 2025-04-12T16:45:28.177Z (about 1 year ago)
- Topics: tabnet
- Language: R
- Homepage: https://mlverse.github.io/tabnet/
- Size: 36.5 MB
- Stars: 110
- Watchers: 5
- Forks: 14
- Open Issues: 21
Metadata Files:
- Readme: README.Rmd
- Changelog: NEWS.md
- License: LICENSE
---
output: github_document
---
```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.path = "man/figures/README-",
  out.width = "100%"
)
```
# tabnet
[Build status](https://github.com/mlverse/tabnet/actions) [Lifecycle stage](https://lifecycle.r-lib.org/articles/stages.html) [CRAN package](https://CRAN.R-project.org/package=tabnet) [Discord](https://discord.com/invite/s3D5cKhBkx)
An R implementation of [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442) [(Sercan O. Arik, Tomas Pfister)](https://doi.org/10.48550/arXiv.1908.07442).
The code in this repository is an R port, using the [torch](https://github.com/mlverse/torch) package, of the [dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet) PyTorch implementation.
TabNet is augmented with [Coherent Hierarchical Multi-label Classification Networks](https://proceedings.neurips.cc//paper/2020/file/6dd4e10e3296fa63738371ec0d5df818-Paper.pdf) [(Eleonora Giunchiglia et al.)](https://doi.org/10.48550/arXiv.2010.10151) for hierarchical outcomes.
## Installation
You can install the released version from CRAN with:
``` r
install.packages("tabnet")
```
The development version can be installed from [GitHub](https://github.com/mlverse/tabnet) with:
``` r
# install.packages("remotes")
remotes::install_github("mlverse/tabnet")
```
## Basic Binary Classification Example
Here we show a **binary classification** example on the `attrition` dataset, using a **recipe** to specify the dataset input.
```{r model-fit}
library(tabnet)
suppressPackageStartupMessages(library(recipes))
library(yardstick)
library(ggplot2)
set.seed(1)
data("attrition", package = "modeldata")
test_idx <- sample.int(nrow(attrition), size = 0.2 * nrow(attrition))
train <- attrition[-test_idx,]
test <- attrition[test_idx,]
rec <- recipe(Attrition ~ ., data = train) %>%
  step_normalize(all_numeric(), -all_outcomes())
fit <- tabnet_fit(rec, train, epochs = 30, valid_split = 0.1, learn_rate = 5e-3)
autoplot(fit)
```
The plot gives you immediate insight into model over-fitting and, if there is any, shows the model checkpoints available before the over-fitting starts.
Keep in mind that **regression** and **multi-class classification** are also available, and that you can also specify the dataset through a **data.frame** or a **formula**. You will find examples in the package vignettes.
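As a minimal sketch of the two other input interfaces, the chunk below fits the same binary classification task through a formula and through a data.frame of predictors plus an outcome vector. The object names (`fit_formula`, `fit_xy`, `x`, `y`) and the very small `epochs = 3` are illustrative assumptions chosen to keep the run short, not recommended settings.

```r
library(tabnet)

set.seed(1)
data("attrition", package = "modeldata")
test_idx <- sample.int(nrow(attrition), size = 0.2 * nrow(attrition))
train <- attrition[-test_idx, ]
test <- attrition[test_idx, ]

# Formula interface: outcome on the left-hand side, predictors on the right.
fit_formula <- tabnet_fit(Attrition ~ ., data = train, epochs = 3, valid_split = 0.1)

# data.frame interface: predictors in `x`, outcome in `y` (illustrative names).
x <- train[, setdiff(names(train), "Attrition")]
y <- train$Attrition
fit_xy <- tabnet_fit(x, y, epochs = 3, valid_split = 0.1)

# Both fits are used through the standard predict() method.
pred_formula <- predict(fit_formula, test)
pred_xy <- predict(fit_xy, test)
```

With only three epochs the predictions are not meaningful; the point is that all interfaces return the same kind of fitted object.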
## Model performance results
Because the standard `predict()` method is used, you can rely on your usual metric functions to assess model performance. Here we use {yardstick}:
```{r}
metrics <- metric_set(accuracy, precision, recall)
cbind(test, predict(fit, test)) %>%
  metrics(Attrition, estimate = .pred_class)
cbind(test, predict(fit, test, type = "prob")) %>%
  roc_auc(Attrition, .pred_No)
```
## Explain model on test-set with attention map
TabNet is intrinsically explainable through the visualization of its attention map, either **aggregated**:
```{r model-explain}
explain <- tabnet_explain(fit, test)
autoplot(explain)
```
or at **each layer** through the `type = "steps"` option:
```{r step-explain}
autoplot(explain, type = "steps")
```
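Beyond the plots, the attention values can be summarized numerically. The sketch below assumes that the object returned by `tabnet_explain()` exposes an `M_explain` element holding the aggregated mask, with one row per observation and one numeric column per predictor; the name `importance` and the short `epochs = 3` are illustrative.

```r
library(tabnet)

set.seed(1)
data("attrition", package = "modeldata")
test_idx <- sample.int(nrow(attrition), size = 0.2 * nrow(attrition))
train <- attrition[-test_idx, ]
test <- attrition[test_idx, ]

fit <- tabnet_fit(Attrition ~ ., data = train, epochs = 3, valid_split = 0.1)
explain <- tabnet_explain(fit, test)

# Average the aggregated mask over the test observations (assumed layout:
# rows are observations, columns are predictors), then rank the predictors.
importance <- sort(colMeans(explain$M_explain), decreasing = TRUE)
head(importance, 5)
```

Averaging over rows is only one possible summary; per-observation values remain available in the same element.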
## Self-supervised pretraining
When a substantial part of your dataset has no outcome, TabNet offers a self-supervised training step that lets the model capture the intrinsic features of the predictors and their interactions ahead of the supervised task.
```{r step-pretrain}
pretrain <- tabnet_pretrain(rec, train, epochs = 50, valid_split = 0.1, learn_rate = 1e-2)
autoplot(pretrain)
```
The example here is a toy example, as the `train` dataset actually contains outcomes. The vignette on [Self-supervised training and fine-tuning](https://mlverse.github.io/tabnet/articles/selfsupervised_training.html) walks you through the complete, correct workflow step by step.
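As a rough sketch of how pretraining and the supervised task can be chained, the chunk below pretrains on one part of the data and then passes the pretrained object to `tabnet_fit()` through its `tabnet_model` argument. The split into `unlabelled` and `labelled` is an artificial assumption made for illustration (both parts carry outcomes here), and the epoch counts are kept very low to shorten the run.

```r
library(tabnet)
suppressPackageStartupMessages(library(recipes))

set.seed(1)
data("attrition", package = "modeldata")

# Artificial split: pretend only 300 rows are labelled (illustrative assumption).
labelled_idx <- sample.int(nrow(attrition), size = 300)
labelled <- attrition[labelled_idx, ]
unlabelled <- attrition[-labelled_idx, ]

rec <- recipe(Attrition ~ ., data = attrition) %>%
  step_normalize(all_numeric(), -all_outcomes())

# Self-supervised step on the rows treated as unlabelled.
pretrain <- tabnet_pretrain(rec, unlabelled, epochs = 3, valid_split = 0.1)

# Supervised fine-tuning that starts from the pretrained model.
fit_tuned <- tabnet_fit(rec, labelled, tabnet_model = pretrain, epochs = 3, valid_split = 0.1)

pred <- predict(fit_tuned, labelled)
```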
## Missing data in predictors
{tabnet} leverages the masking mechanism to deal with missing data, so you don't have to remove dataset entries that have missing values in the predictor variables.
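A small sketch of what this looks like in practice, assuming a {tabnet} version that includes missing-value support: some values of two numeric predictors are set to `NA` on purpose, and the data is passed to `tabnet_fit()` without imputation or row removal. The choice of columns, the number of blanked values and the name `train_na` are illustrative.

```r
library(tabnet)

set.seed(1)
data("attrition", package = "modeldata")

# Copy the data and blank out 150 values in each of two numeric predictors.
train_na <- attrition
train_na$Age[sample.int(nrow(train_na), 150)] <- NA
train_na$MonthlyIncome[sample.int(nrow(train_na), 150)] <- NA

# Rows with missing predictors are kept as they are; no imputation step is added.
fit_na <- tabnet_fit(Attrition ~ ., data = train_na, epochs = 3, valid_split = 0.1)
pred_na <- predict(fit_na, train_na)
```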
# Comparison with other implementations
| Group | Feature | {tabnet} | dreamquark-ai | fast-tabnet |
|-------------|---------------------|:--------:|:-------------:|:----------:|
| Input format | data-frame | ✅ | ✅ | ✅ |
| | formula | ✅ | | |
| | recipe | ✅ | | |
| | Node | ✅ | | |
| | missing values in predictors | ✅ | | |
| Output format | data-frame | ✅ | ✅ | ✅ |
| | workflow | ✅ | | |
| ML Tasks | self-supervised learning | ✅ | ✅ | |
| | classification (binary, multi-class) | ✅ | ✅ | ✅ |
| | regression | ✅ | ✅ | ✅ |
| | multi-outcome | ✅ | ✅ | |
| | hierarchical multi-label classif. | ✅ | | |
| Model management | from / to file | ✅ | ✅ | ✅ |
| | resume from snapshot | ✅ | | |
| | training diagnostic | ✅ | | |
| Interpretability | | ✅ | ✅ | ✅ |
| Performance | | 1 x | 2 - 4 x | |
| Code quality | test coverage | 85% | | |
| | continuous integration | 4 OS including GPU | | |
: Alternative TabNet implementation features
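To illustrate the "workflow" row of the table, here is a minimal sketch that wraps the model in a {parsnip} specification and a {workflows} workflow. It assumes the {parsnip} and {workflows} packages are installed; the object names (`mod`, `wf`, `wf_fit`) and `epochs = 3` are illustrative.

```r
library(tabnet)
library(parsnip)
library(workflows)
suppressPackageStartupMessages(library(recipes))

set.seed(1)
data("attrition", package = "modeldata")
test_idx <- sample.int(nrow(attrition), size = 0.2 * nrow(attrition))
train <- attrition[-test_idx, ]
test <- attrition[test_idx, ]

rec <- recipe(Attrition ~ ., data = train) %>%
  step_normalize(all_numeric(), -all_outcomes())

# Model specification with the torch engine in classification mode.
mod <- tabnet(epochs = 3) %>%
  set_engine("torch") %>%
  set_mode("classification")

# Bundle preprocessing and model, then fit and predict as with any workflow.
wf <- workflow() %>%
  add_model(mod) %>%
  add_recipe(rec)

wf_fit <- fit(wf, data = train)
pred_wf <- predict(wf_fit, test)
```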