https://github.com/radi-cho/gatedtabtransformer
# The GatedTabTransformer

A deep learning architecture for tabular classification, inspired by [TabTransformer](https://arxiv.org/abs/2012.06678), with an integrated [gated](https://arxiv.org/abs/2105.08050) multilayer perceptron (gMLP). Check out our paper on [arXiv](https://arxiv.org/abs/2201.00199). Applications and usage demonstrations are available in the [GatedTabTransformer-Applications](https://github.com/radi-cho/GatedTabTransformer-Applications) repository.

*Figure: GatedTabTransformer architecture.*
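The gated MLP comes from the gMLP architecture of Liu et al. ([arXiv:2105.08050](https://arxiv.org/abs/2105.08050)). As a rough illustration of what one gMLP block computes, here is a minimal PyTorch sketch following that paper's description: a channel projection, a Spatial Gating Unit that splits channels into `u` and `v` and gates `u` with a projection of `v` across the sequence axis, and a residual connection. It assumes a `(batch, seq_len, d_model)` input; the class names `SpatialGatingUnit` and `GMLPBlock` are hypothetical, this is not code from this repository, and how GatedTabTransformer arranges tabular features along the sequence axis is not shown here.

```python
import torch
import torch.nn as nn


class SpatialGatingUnit(nn.Module):
    """Sketch of gMLP's Spatial Gating Unit (not this repo's implementation)."""

    def __init__(self, d_ffn: int, seq_len: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_ffn // 2)
        # Conv1d with kernel_size=1 and seq_len channels is a learned
        # (seq_len x seq_len) linear map that mixes sequence positions.
        self.spatial_proj = nn.Conv1d(seq_len, seq_len, kernel_size=1)
        # Zero weights and unit bias: the gate starts at 1, so u passes through.
        nn.init.zeros_(self.spatial_proj.weight)
        nn.init.ones_(self.spatial_proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u, v = x.chunk(2, dim=-1)    # each (batch, seq_len, d_ffn // 2)
        v = self.norm(v)
        v = self.spatial_proj(v)     # projection across the seq_len axis
        return u * v                 # element-wise gating


class GMLPBlock(nn.Module):
    """Sketch of one gMLP block with a residual connection."""

    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.proj_in = nn.Linear(d_model, d_ffn)
        self.act = nn.GELU()
        self.sgu = SpatialGatingUnit(d_ffn, seq_len)
        self.proj_out = nn.Linear(d_ffn // 2, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.act(self.proj_in(self.norm(x)))
        x = self.sgu(x)
        x = self.proj_out(x)
        return x + shortcut


torch.manual_seed(0)
block = GMLPBlock(d_model=32, d_ffn=64, seq_len=5)
x = torch.randn(2, 5, 32)
y = block(x)
print(y.shape)  # torch.Size([2, 5, 32])
```

The near-identity initialization of the spatial projection follows the gMLP paper's recommendation, so training starts close to a plain feed-forward block and the gating is learned gradually.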

## Usage

```python
import torch
import torch.nn as nn
from gated_tab_transformer import GatedTabTransformer

model = GatedTabTransformer(
    categories=(10, 5, 6, 5, 8),  # number of unique values in each categorical column
    num_continuous=10,            # number of continuous columns
    transformer_dim=32,           # transformer dimension; the paper uses 32
    dim_out=1,                    # output dimension: 1 for binary prediction, but could be anything
    transformer_depth=6,          # number of transformer layers; the paper recommends 6
    transformer_heads=8,          # attention heads; the paper recommends 8
    attn_dropout=0.1,             # post-attention dropout
    ff_dropout=0.1,               # feed-forward dropout
    mlp_act=nn.LeakyReLU(0),      # activation for the final MLP; defaults to ReLU (LeakyReLU(0) is equivalent), could also be SELU, etc.
    mlp_depth=4,                  # depth of the MLP (hidden layers)
    mlp_dimension=32,             # width of the MLP hidden layers
    gmlp_enabled=True,            # True: gated MLP (gMLP); False: standard MLP
)

# Categorical values: column i must hold integers in [0, categories[i]),
# with columns in the same order as passed to the constructor above.
# randint(0, 5) is valid here because the smallest cardinality is 5.
x_categ = torch.randint(0, 5, (1, 5))
x_cont = torch.randn(1, 10)  # assumes each continuous column is already normalized

pred = model(x_categ, x_cont)
print(pred)
```
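The constructor needs the per-column cardinalities in `categories`, `x_categ` must contain integer codes within those ranges, and `x_cont` is assumed to be normalized column by column. The sketch below shows one way to prepare such inputs from raw data. `encode_categorical` and `standardize` are hypothetical helpers, not part of `gated_tab_transformer`, and z-scoring is only one common normalization choice; the exact preprocessing used in the paper is not specified here.

```python
import torch


def encode_categorical(rows):
    """Map each column's raw values to integer codes 0..k-1 (hypothetical helper).

    Returns the code tensor and a tuple of per-column cardinalities,
    which has the form expected by the `categories` argument.
    """
    n_cols = len(rows[0])
    vocabs = [{} for _ in range(n_cols)]
    codes = []
    for row in rows:
        # setdefault assigns the next free code to a value seen for the first time.
        codes.append([vocabs[i].setdefault(v, len(vocabs[i])) for i, v in enumerate(row)])
    categories = tuple(len(vocab) for vocab in vocabs)
    return torch.tensor(codes, dtype=torch.long), categories


def standardize(x_cont, eps=1e-6):
    """Z-score each continuous column with its own mean and std (hypothetical helper)."""
    mean = x_cont.mean(dim=0, keepdim=True)
    std = x_cont.std(dim=0, keepdim=True)
    return (x_cont - mean) / (std + eps)


rows = [("red", "S"), ("blue", "M"), ("red", "L"), ("green", "M")]
x_categ, categories = encode_categorical(rows)
print(categories)  # (3, 3)

x_cont = standardize(torch.tensor([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0], [7.0, 70.0]]))
```

In practice, the vocabularies and the mean/std would be computed on the training split and reused unchanged for validation and test data, with a policy for category values never seen during training.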

## Citation

```bibtex
@misc{cholakov2022gatedtabtransformer,
title={The GatedTabTransformer. An enhanced deep learning architecture for tabular modeling},
author={Radostin Cholakov and Todor Kolev},
year={2022},
eprint={2201.00199},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```