{"id":21428300,"url":"https://github.com/mlverse/tabnet","last_synced_at":"2025-04-12T22:36:28.576Z","repository":{"id":37028029,"uuid":"304717861","full_name":"mlverse/tabnet","owner":"mlverse","description":"An R implementation of TabNet","archived":false,"fork":false,"pushed_at":"2025-04-12T16:35:51.000Z","size":38248,"stargazers_count":110,"open_issues_count":21,"forks_count":14,"subscribers_count":5,"default_branch":"main","last_synced_at":"2025-04-12T16:45:28.177Z","etag":null,"topics":["tabnet"],"latest_commit_sha":null,"homepage":"https://mlverse.github.io/tabnet/","language":"R","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"other","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/mlverse.png","metadata":{"files":{"readme":"README.Rmd","changelog":"NEWS.md","contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2020-10-16T19:13:14.000Z","updated_at":"2025-04-09T18:43:16.000Z","dependencies_parsed_at":"2023-12-02T16:24:37.263Z","dependency_job_id":"bea3e42a-f008-4a36-ade2-2a2527897a0b","html_url":"https://github.com/mlverse/tabnet","commit_stats":{"total_commits":378,"total_committers":11,"mean_commits":34.36363636363637,"dds":0.6296296296296297,"last_synced_commit":"8beb6909e237776cb036431041c53e4fea164d2f"},"previous_names":[],"tags_count":7,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mlverse%2Ftabnet","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mlverse%2Ftabnet/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mlverse%2Ftabnet/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub
/repositories/mlverse%2Ftabnet/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/mlverse","download_url":"https://codeload.github.com/mlverse/tabnet/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248642844,"owners_count":21138352,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["tabnet"],"created_at":"2024-11-22T22:12:33.850Z","updated_at":"2025-04-12T22:36:28.556Z","avatar_url":"https://github.com/mlverse.png","language":"R","funding_links":[],"categories":[],"sub_categories":[],"readme":"---\noutput: github_document\n---\n\n\u003c!-- README.md is generated from README.Rmd. 
Please edit that file --\u003e\n\n```{r, include = FALSE}\nknitr::opts_chunk$set(\n  collapse = TRUE,\n  comment = \"#\u003e\",\n  fig.path = \"man/figures/README-\",\n  out.width = \"100%\"\n)\n```\n\n# tabnet\n\n\u003c!-- badges: start --\u003e\n\n[![R build status](https://github.com/mlverse/tabnet/workflows/R-CMD-check/badge.svg)](https://github.com/mlverse/tabnet/actions) [![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html) [![CRAN status](https://www.r-pkg.org/badges/version/tabnet)](https://CRAN.R-project.org/package=tabnet) [![CRAN downloads](https://cranlogs.r-pkg.org/badges/tabnet)](https://cran.r-project.org/package=tabnet) [![Discord](https://img.shields.io/discord/837019024499277855?logo=discord)](https://discord.com/invite/s3D5cKhBkx)\n\n\u003c!-- badges: end --\u003e\n\nAn R implementation of [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442) [(Sercan O. Arik, Tomas Pfister)](https://doi.org/10.48550/arXiv.1908.07442).   \nThe code in this repository is an R port, built on the [torch](https://github.com/mlverse/torch) package, of the [dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet) PyTorch implementation.  
\nTabNet is augmented with [Coherent Hierarchical Multi-label Classification Networks](https://proceedings.neurips.cc//paper/2020/file/6dd4e10e3296fa63738371ec0d5df818-Paper.pdf) [(Eleonora Giunchiglia et al.)](https://doi.org/10.48550/arXiv.2010.10151) for hierarchical outcomes.\n\n## Installation\n\nYou can install the released version from CRAN with:\n\n``` r\ninstall.packages(\"tabnet\")\n```\n\nThe development version can be installed from [GitHub](https://github.com/mlverse/tabnet) with:\n\n``` r\n# install.packages(\"remotes\")\nremotes::install_github(\"mlverse/tabnet\")\n```\n\n## Basic Binary Classification Example\n\nHere we show a **binary classification** example on the `attrition` dataset, using a **recipe** to specify the dataset input.\n\n```{r model-fit}\nlibrary(tabnet)\nsuppressPackageStartupMessages(library(recipes))\nlibrary(yardstick)\nlibrary(ggplot2)\nset.seed(1)\n\ndata(\"attrition\", package = \"modeldata\")\n# hold out 20% of the rows as a test set\ntest_idx \u003c- sample.int(nrow(attrition), size = 0.2 * nrow(attrition))\n\ntrain \u003c- attrition[-test_idx, ]\ntest \u003c- attrition[test_idx, ]\n\n# normalize the numeric predictors\nrec \u003c- recipe(Attrition ~ ., data = train) %\u003e%\n  step_normalize(all_numeric_predictors())\n\n# keep 10% of `train` aside as a validation set during training\nfit \u003c- tabnet_fit(rec, train, epochs = 30, valid_split = 0.1, learn_rate = 5e-3)\nautoplot(fit)\n```\n\nThe plot gives you immediate insight into model over-fitting and, if over-fitting occurs, into the model checkpoints available before it sets in.\n\nKeep in mind that **regression** and **multi-class classification** are also available, and that you can specify the dataset through a **data.frame** or a **formula** as well. You will find examples of these in the package vignettes.\n\n## Model performance results\n\nAs {tabnet} implements the standard `predict()` method, you can rely on your usual metric functions to assess model performance. 
Here we use {yardstick}:\n\n```{r}\nmetrics \u003c- metric_set(accuracy, precision, recall)\n# class metrics computed from the predicted labels\ncbind(test, predict(fit, test)) %\u003e%\n  metrics(Attrition, estimate = .pred_class)\n\n# ROC AUC; yardstick treats the first factor level (\"No\") as the event by default\ncbind(test, predict(fit, test, type = \"prob\")) %\u003e%\n  roc_auc(Attrition, .pred_No)\n```\n\n## Explain the model on the test set with attention maps\n\nTabNet offers intrinsic explainability through the visualization of its attention map, either **aggregated**:\n\n```{r model-explain}\nexplain \u003c- tabnet_explain(fit, test)\nautoplot(explain)\n```\n\nor at **each step** through the `type = \"steps\"` option:\n\n```{r step-explain}\nautoplot(explain, type = \"steps\")\n```\n\n## Self-supervised pretraining\n\nWhen a substantial part of your dataset has no outcome, TabNet offers a self-supervised pretraining step that lets the model capture the predictors' intrinsic features and interactions ahead of the supervised task.\n\n```{r step-pretrain}\npretrain \u003c- tabnet_pretrain(rec, train, epochs = 50, valid_split = 0.1, learn_rate = 1e-2)\nautoplot(pretrain)\n```\n\nThis is only a toy example, as the `train` dataset actually does contain outcomes. 
The vignette on [Self-supervised training and fine-tuning](https://mlverse.github.io/tabnet/articles/selfsupervised_training.html) walks you through the complete workflow step by step.\n\n## Missing data in predictors\n\n{tabnet} leverages the TabNet masking mechanism to handle missing data, so you don't have to remove the rows of your dataset that have missing values in the predictor variables.\n\n## Comparison with other implementations\n\n| Group            | Feature                              |      {tabnet}      | dreamquark-ai | fast-tabnet |\n|-------------|---------------------|:--------:|:-------------:|:----------:|\n| Input format     | data-frame                            |         ✅         |      ✅       |     ✅      |\n|                  | formula                              |         ✅         |               |             |\n|                  | recipe                               |         ✅         |               |             |\n|                  | Node                                 |         ✅         |               |             |\n|                  | missing values in predictors         |         ✅         |               |             |\n| Output format    | data-frame                            |         ✅         |      ✅       |     ✅      |\n|                  | workflow                             |         ✅         |               |             |\n| ML Tasks         | self-supervised learning             |         ✅         |      ✅       |             |\n|                  | classification (binary, multi-class) |         ✅         |      ✅       |     ✅      |\n|                  | regression                           |         ✅         |      ✅       |     ✅      |\n|                  | multi-outcome                        |         ✅         |      ✅       |             |\n|                  | hierarchical multi-label classif.    
|         ✅         |               |             |\n| Model management | from / to file                       |         ✅         |      ✅       |      v      |\n|                  | resume from snapshot                 |         ✅         |               |             |\n|                  | training diagnostic                  |         ✅         |               |             |\n| Interpretability |                                      |         ✅         |      ✅       |     ✅      |\n| Performance      |                                      |        1 x         |    2 - 4 x    |             |\n| Code quality     | test coverage                        |        85%         |               |             |\n|                  | continuous integration               | 4 OS including GPU |               |             |\n\n: Alternative TabNet implementation features\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmlverse%2Ftabnet","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fmlverse%2Ftabnet","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmlverse%2Ftabnet/lists"}