https://github.com/skblaz/san
Attention-based feature ranking for propositional data.
- Host: GitHub
- URL: https://github.com/skblaz/san
- Owner: SkBlaz
- License: gpl-3.0
- Created: 2020-04-14T14:50:23.000Z (over 4 years ago)
- Default Branch: master
- Last Pushed: 2023-10-15T09:43:20.000Z (about 1 year ago)
- Last Synced: 2023-10-16T11:57:59.933Z (about 1 year ago)
- Topics: attention-mechanism, deep-neural-networks, feature-ranking, machine-learning
- Language: Python
- Size: 708 KB
- Stars: 23
- Watchers: 3
- Forks: 9
- Open Issues: 2
Metadata Files:
- Readme: README.md
- License: LICENSE
# Feature ranking with self-attention networks
This is the repository of the SAN paper, found here: http://ecai2020.eu/papers/1721_paper.pdf
Please cite it if you use this code:
```
@inproceedings{DBLP:conf/ecai/SkrljDLP20,
  author    = {Blaz Skrlj and
               Saso Dzeroski and
               Nada Lavrac and
               Matej Petkovic},
  editor    = {Giuseppe De Giacomo and
               Alejandro Catal{\'{a}} and
               Bistra Dilkina and
               Michela Milano and
               Sen{\'{e}}n Barro and
               Alberto Bugar{\'{\i}}n and
               J{\'{e}}r{\^{o}}me Lang},
  title     = {Feature Importance Estimation with Self-Attention Networks},
  booktitle = {{ECAI} 2020 - 24th European Conference on Artificial Intelligence,
               29 August-8 September 2020, Santiago de Compostela, Spain, August
               29 - September 8, 2020 - Including 10th Conference on Prestigious
               Applications of Artificial Intelligence {(PAIS} 2020)},
  series    = {Frontiers in Artificial Intelligence and Applications},
  volume    = {325},
  pages     = {1491--1498},
  publisher = {{IOS} Press},
  year      = {2020},
  url       = {https://doi.org/10.3233/FAIA200256},
  doi       = {10.3233/FAIA200256},
  timestamp = {Tue, 15 Sep 2020 15:08:42 +0200},
  biburl    = {https://dblp.org/rec/conf/ecai/SkrljDLP20.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
The full code and datasets needed to reproduce the paper are available at https://gitlab.com/skblaz/attentionrank (that code is in benchmark-ready form). The purpose of this repository is to provide all functionality in a user-friendly way. Disclaimer: this code has not been extensively benchmarked and may contain bugs. If you find one, please open an issue.
# Installing SANs
```
python setup.py install
```
or
```
pip install git+https://github.com/SkBlaz/san
```
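After either command, a quick way to confirm that the package can be imported is sketched below. It assumes the module is importable as `san` (the name used by `import san` in the usage example); the helper name `is_installed` is purely illustrative and not part of the library.

```python
import importlib.util


def is_installed(module_name):
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


# Assumption: the package installs a top-level module called "san".
print("san available:", is_installed("san"))
```

If this prints `False`, the install most likely went into a different Python environment than the one running the check.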
# Using SANs
A simple use case is given next:
```
from scipy import sparse
import numpy as np
from sklearn.datasets import load_breast_cancer
import san
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.feature_selection import chi2,f_classif,mutual_info_classif
from sklearn.ensemble import RandomForestClassifier

sns.set_style("whitegrid")

# load a small tabular dataset
dataobj = load_breast_cancer()
X = dataobj['data']
Y = dataobj['target']
names = dataobj['feature_names']

# let's overfit, just for demo purposes
clf = san.SAN(num_epochs = 32, num_heads = 2, batch_size = 8, dropout = 0.2, hidden_layer_size = 32)
X = sparse.csr_matrix(X)  # convert to a sparse CSR matrix before fitting
clf.fit(X, Y)
preds = clf.predict(X)

# attention-based feature importances
global_attention_weights = clf.get_mean_attention_weights()
local_attention_matrix = clf.get_instance_attention(X)

# baseline importances for comparison
mutual_information = mutual_info_classif(X,Y)
rf = RandomForestClassifier().fit(X,Y).feature_importances_

# plot all rankings on a shared feature axis
plt.plot(names, global_attention_weights, label = "Global attention", marker = "x")
plt.plot(names, np.mean(local_attention_matrix, axis = 0), label = "Local attention - mean", marker = "x")
plt.plot(names, np.max(local_attention_matrix, axis = 0), label = "Local attention - max", marker = "x")
plt.plot(names, mutual_information, label = "Mutual information", marker = ".")
plt.plot(names, rf, label = "RandomForest", marker = ".")
plt.legend(loc = 1)
plt.xticks(rotation = 90)
plt.tight_layout()
plt.show()
```
![Example](images/example.png)
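The plot compares rankings visually. A minimal sketch of how the same comparison could be made numerically is given next: it extracts the top-k features from a vector of importance scores and measures how much two top-k sets overlap. The helper names `top_k_features` and `jaccard_overlap` and the toy scores are hypothetical; they are not part of the `san` package.

```python
def top_k_features(names, weights, k=5):
    """Return the k (name, weight) pairs with the largest weights, best first."""
    pairs = [(str(n), float(w)) for n, w in zip(names, weights)]
    pairs.sort(key=lambda p: p[1], reverse=True)
    return pairs[:k]


def jaccard_overlap(ranking_a, ranking_b):
    """Jaccard similarity between the feature names of two top-k lists."""
    a = {name for name, _ in ranking_a}
    b = {name for name, _ in ranking_b}
    union = a | b
    return len(a & b) / len(union) if union else 1.0


# Hypothetical importance scores for four features (illustration only).
names = ["f0", "f1", "f2", "f3"]
attention = [0.10, 0.40, 0.30, 0.20]
mutual_info = [0.05, 0.50, 0.15, 0.30]

top_att = top_k_features(names, attention, k=2)
top_mi = top_k_features(names, mutual_info, k=2)
print(top_att)
print(jaccard_overlap(top_att, top_mi))
```

With the objects from the use case above, `top_k_features(names, global_attention_weights)` should work the same way, assuming `get_mean_attention_weights()` returns one weight per feature (as the plotting code implies).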
An example mock evaluation (examples/example_benchmark.py) is shown below:
![Example](images/example1.png)
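The contents of `examples/example_benchmark.py` are not reproduced here. As a rough, independent sketch of how a feature ranking can be evaluated, one common approach is to keep only the top-k ranked features and measure cross-validated accuracy as k grows. The code below is an assumption-laden illustration, not the repository's benchmark: the helper `evaluate_top_k` is hypothetical, the classifier is a plain scikit-learn logistic regression, and mutual information stands in for the importance vector (SAN attention weights could be passed instead).

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.feature_selection import mutual_info_classif
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def evaluate_top_k(X, y, importances, ks, cv=5):
    """Mean cross-validated accuracy using only the k highest-ranked features."""
    order = np.argsort(importances)[::-1]  # feature indices, most important first
    scores = {}
    for k in ks:
        model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
        scores[k] = cross_val_score(model, X[:, order[:k]], y, cv=cv).mean()
    return scores


data = load_breast_cancer()
X, y = data["data"], data["target"]

# Stand-in ranking (assumption): any per-feature importance vector works here.
importances = mutual_info_classif(X, y, random_state=0)

scores = evaluate_top_k(X, y, importances, ks=[1, 5, 10, 30])
for k, acc in scores.items():
    print(f"top-{k}: {acc:.3f}")
```

Note that this sketch computes the ranking on the full dataset before cross-validating, which is optimistic; a stricter setup would recompute the ranking inside each training fold.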