https://github.com/radi-cho/gatedtabtransformer
A deep learning tabular classification architecture inspired by TabTransformer with an integrated gated multilayer perceptron.
- Host: GitHub
- URL: https://github.com/radi-cho/gatedtabtransformer
- Owner: radi-cho
- License: MIT
- Created: 2021-12-14T19:16:28.000Z (about 3 years ago)
- Default Branch: master
- Last Pushed: 2023-02-04T09:59:50.000Z (almost 2 years ago)
- Last Synced: 2024-05-01T17:52:40.425Z (8 months ago)
- Topics: classification, machine-learning, tabular-data
- Language: Jupyter Notebook
- Homepage:
- Size: 60.4 MB
- Stars: 91
- Watchers: 3
- Forks: 5
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# The GatedTabTransformer.
A deep learning tabular classification architecture inspired by [TabTransformer](https://arxiv.org/abs/2012.06678) with an integrated [gated](https://arxiv.org/abs/2105.08050) multilayer perceptron. Check out our paper on [arXiv](https://arxiv.org/abs/2201.00199). Applications and usage demonstrations are available [here](https://github.com/radi-cho/GatedTabTransformer-Applications).
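The "gated" component replaces the plain MLP head of TabTransformer with gMLP-style blocks. For intuition, here is a minimal sketch of the spatial gating unit described in the gMLP paper (an illustrative reimplementation, not the code shipped in this package):

```python
import torch
import torch.nn as nn

# Illustrative sketch of gMLP's spatial gating unit: one half of the
# channels is modulated by a learned projection of the other half.
class SpatialGatingUnit(nn.Module):
    def __init__(self, dim, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(dim // 2)
        self.proj = nn.Linear(seq_len, seq_len)  # mixes across tokens/rows

    def forward(self, x):                  # x: (batch, seq_len, dim)
        res, gate = x.chunk(2, dim=-1)     # split channels in half
        gate = self.proj(self.norm(gate).transpose(1, 2)).transpose(1, 2)
        return res * gate                  # (batch, seq_len, dim // 2)

sgu = SpatialGatingUnit(dim=64, seq_len=5)
print(sgu(torch.randn(1, 5, 64)).shape)   # torch.Size([1, 5, 32])
```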
## Usage
```python
import torch
import torch.nn as nn
from gated_tab_transformer import GatedTabTransformer

model = GatedTabTransformer(
categories = (10, 5, 6, 5, 8), # tuple containing the number of unique values within each category
num_continuous = 10, # number of continuous values
transformer_dim = 32, # dimension, paper set at 32
dim_out = 1, # binary prediction, but could be anything
transformer_depth = 6, # depth, paper recommended 6
transformer_heads = 8, # heads, paper recommends 8
attn_dropout = 0.1, # post-attention dropout
ff_dropout = 0.1, # feed forward dropout
mlp_act = nn.LeakyReLU(0), # activation for final mlp, defaults to relu, but could be anything else (selu, etc.)
mlp_depth=4, # mlp hidden layers depth
mlp_dimension=32, # dimension of mlp layers
gmlp_enabled=True # gmlp or standard mlp
)

x_categ = torch.randint(0, 5, (1, 5)) # category values, from 0 - max number of categories, in the order as passed into the constructor above
x_cont = torch.randn(1, 10) # assume continuous values are already normalized individually

pred = model(x_categ, x_cont)
print(pred)
```
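With `dim_out = 1` the model produces a single raw score per row, so binary training pairs naturally with `BCEWithLogitsLoss`. Below is a minimal training-step sketch, reusing the `model` from the snippet above with synthetic data; the optimizer and learning rate are illustrative assumptions, not values from the paper:

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()  # assumes the model outputs raw logits
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # illustrative settings

x_categ = torch.randint(0, 5, (64, 5))    # batch of 64 rows, 5 categorical columns
x_cont = torch.randn(64, 10)              # 10 normalized continuous columns
y = torch.randint(0, 2, (64, 1)).float()  # binary targets

optimizer.zero_grad()
loss = criterion(model(x_categ, x_cont), y)
loss.backward()
optimizer.step()
```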
## Citation
```bibtex
@misc{cholakov2022gatedtabtransformer,
title={The GatedTabTransformer. An enhanced deep learning architecture for tabular modeling},
author={Radostin Cholakov and Todor Kolev},
year={2022},
eprint={2201.00199},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```