Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/dvilaverde/sklearn2java
A parser for scikit-learn exported text models to execute in the Java runtime.
decision-tree decisiontreeclassifier java machine-learning ml random-forest randomforestclassifier sci-kit scikit scikit-learn scikitlearn-machine-learning sklearn
- Host: GitHub
- URL: https://github.com/dvilaverde/sklearn2java
- Owner: dvilaverde
- License: apache-2.0
- Created: 2022-12-11T18:52:02.000Z (almost 2 years ago)
- Default Branch: main
- Last Pushed: 2024-04-28T18:13:58.000Z (7 months ago)
- Last Synced: 2024-04-28T19:25:35.195Z (7 months ago)
- Topics: decision-tree, decisiontreeclassifier, java, machine-learning, ml, random-forest, randomforestclassifier, sci-kit, scikit, scikit-learn, scikitlearn-machine-learning, sklearn
- Language: Java
- Homepage:
- Size: 98.6 KB
- Stars: 3
- Watchers: 2
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# SkLearn2Java
This project aims to take machine-learning models exported as text by scikit-learn and make them usable in Java.
[![javadoc](https://javadoc.io/badge2/rocks.vilaverde/scikit-learn-2-java/javadoc.svg)](https://javadoc.io/doc/rocks.vilaverde/scikit-learn-2-java)
## Support
* The `tree.DecisionTreeClassifier` is supported
  * Supports `predict()`
  * Supports `predict_proba()` when `export_text()` is called with `show_weights=True`
* The `ensemble.RandomForestClassifier` is supported
  * Supports `predict()`
  * Supports `predict_proba()` when `export_text()` is called with `show_weights=True`

## Installing
### Importing Maven Dependency
```xml
<dependency>
  <groupId>rocks.vilaverde</groupId>
  <artifactId>scikit-learn-2-java</artifactId>
  <version>1.1.0</version>
</dependency>
```
## DecisionTreeClassifier
As an example, consider a DecisionTreeClassifier trained on the Iris dataset and exported with `sklearn.tree.export_text()`, as shown below:

```python
>>> import sys
>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> from sklearn.tree import export_text
>>> iris = load_iris()
>>> X = iris['data']
>>> y = iris['target']
>>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
>>> decision_tree = decision_tree.fit(X, y)
>>> r = export_text(decision_tree, feature_names=iris['feature_names'], show_weights=True, max_depth=sys.maxsize)
>>> print(r)
|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) >  1.75
|   |   |--- class: 2
```

The exported text can then be parsed and evaluated in Java. When calling `export_text`, set `max_depth` to `sys.maxsize` so that the exported tree isn't truncated. By default, `export_text` stops at a depth of 10 and replaces deeper branches with a "truncated branch" marker.

### Java Example
In this example, the iris model exported with `export_text` is parsed, a feature vector is built from the four Iris feature names, and the decision tree is asked to predict the class.

```java
// getTrainedModel() is an assumed helper (not shown) that returns a
// java.io.Reader over the text produced by export_text().
Reader tree = getTrainedModel("iris.model");
final Classifier decisionTree = DecisionTreeClassifier.parse(tree,
        PredictionFactory.INTEGER);

Features features = Features.of("sepal length (cm)",
        "sepal width (cm)",
        "petal length (cm)",
        "petal width (cm)");

FeatureVector fv = features.newSample();
fv.add(0, 3.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);

Integer prediction = decisionTree.predict(fv);
System.out.println(prediction.toString());
```

## RandomForestClassifier
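scikit-learn's `RandomForestClassifier` does not combine its trees by a plain majority vote. `predict_proba()` returns the mean of the trees' class probabilities, and `predict()` returns the class with the highest mean probability. A single tree's probabilities for a sample are its leaf's class weights normalized to sum to one. Those weights are what `show_weights=True` adds to the exported text. The Python sketch below illustrates only that aggregation rule. The function names `leaf_proba`, `forest_predict_proba` and `forest_predict` are hypothetical, and the sketch makes no claim about how this project implements it in Java.

```python
from typing import List, Sequence


def leaf_proba(weights: Sequence[float]) -> List[float]:
    """Class probabilities of one tree's leaf: its weights normalized to sum to 1."""
    total = sum(weights)
    return [w / total for w in weights]


def forest_predict_proba(tree_probas: Sequence[Sequence[float]]) -> List[float]:
    """Average the per-tree class probabilities across all trees of the forest."""
    n_trees = len(tree_probas)
    n_classes = len(tree_probas[0])
    return [sum(p[c] for p in tree_probas) / n_trees for c in range(n_classes)]


def forest_predict(tree_probas: Sequence[Sequence[float]]) -> int:
    """Index of the class with the highest mean probability (the first one wins on ties)."""
    mean = forest_predict_proba(tree_probas)
    return max(range(len(mean)), key=mean.__getitem__)
```

As an illustration with made-up leaf weights, take three trees whose leaves hold `[9, 1]`, `[4, 6]` and `[4, 6]`. Two of the three trees favour class 1, yet the averaged probabilities (about `[0.57, 0.43]`) make `forest_predict` return class 0.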
To use a RandomForestClassifier trained on the Iris dataset, export each tree in the fitted classifier's `estimators_` list with `sklearn.tree.export_text()`, as shown below:

```python
>>> import os
>>> import sys
>>> from sklearn import datasets
>>> from sklearn import tree
>>> from sklearn.ensemble import RandomForestClassifier
>>>
>>> iris = datasets.load_iris()
>>> X = iris.data
>>> y = iris.target
>>>
>>> clf = RandomForestClassifier(n_estimators=50, n_jobs=8)
>>> model = clf.fit(X, y)
>>>
>>> os.makedirs('/tmp/estimators', exist_ok=True)
>>> for i, t in enumerate(clf.estimators_):
...     with open(os.path.join('/tmp/estimators', "iris-" + str(i) + ".txt"), "w") as file1:
...         text_representation = tree.export_text(t, feature_names=iris.feature_names, show_weights=True, decimals=4, max_depth=sys.maxsize)
...         file1.write(text_representation)
...
```

Once all the estimators are exported into `/tmp/estimators`, create a gzip-compressed TAR archive of them, for example:
```bash
cd /tmp/estimators
tar -czvf /tmp/iris.tgz .
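```

Then use the `RandomForestClassifier` class to parse the TAR archive:

```java
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
...
// getArchive() is an assumed helper (not shown) that opens iris.tgz as a
// TarArchiveInputStream. Because the archive is gzip-compressed, the file
// stream has to be decompressed (e.g. with commons-compress's
// GzipCompressorInputStream) before it is wrapped in the TarArchiveInputStream.
TarArchiveInputStream archive = getArchive("iris.tgz");
final Classifier randomForest = RandomForestClassifier.parse(archive,
        PredictionFactory.DOUBLE);
```

## Testing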
Testing was done with models exported by scikit-learn 1.1.3, but models exported by newer versions of scikit-learn should also work.
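When testing, it can help to compute reference predictions from the same exported text outside the JVM and compare them with the Java results. The following is a minimal, stdlib-only Python sketch of an `export_text` evaluator. It is not part of this project, and `parse_export_text`, `predict` and `predict_proba` are hypothetical names. The sketch assumes the default `export_text` layout: `|---` markers, with the `<=` branch listed before the matching `>` branch. For probabilities it also assumes `show_weights=True`, and it normalizes each leaf's weights to sum to one. Because `export_text` rounds thresholds to `decimals` places (2 by default), a sample lying between the rounded and the exact threshold can be routed differently than by scikit-learn itself.

```python
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

# Split rows look like "petal width (cm) <= 0.80" or "petal width (cm) >  0.80".
SPLIT_RE = re.compile(r"^(?P<feature>.+?) (?P<op><=|>) +(?P<threshold>\S+)$")
# Leaf rows look like "weights: [0.00, 1.00, 45.00] class: 2" or just "class: 2".
LEAF_RE = re.compile(r"^(?:weights: \[(?P<weights>[^\]]*)\] )?class: (?P<label>.+)$")


@dataclass
class Leaf:
    label: str
    weights: Optional[List[float]]  # None when exported without show_weights=True


@dataclass
class Split:
    feature: str
    threshold: float
    left: "Node"   # followed when the feature value is <= threshold
    right: "Node"  # followed when the feature value is > threshold


Node = Union[Leaf, Split]


def _rows(text: str) -> List[Tuple[int, str]]:
    """Turn each non-blank line into (depth, body); depth is the number of '|' before '---'."""
    rows = []
    for line in text.splitlines():
        if line.strip():
            marker = line.index("---")
            rows.append((line[:marker].count("|"), line[marker + 3:].strip()))
    return rows


def _parse(rows: List[Tuple[int, str]], pos: int, depth: int) -> Tuple[Node, int]:
    """Parse the subtree starting at rows[pos]; return it and the index of the next row."""
    if pos >= len(rows) or rows[pos][0] != depth:
        raise ValueError(f"malformed tree near row {pos}")
    body = rows[pos][1]
    leaf = LEAF_RE.match(body)
    if leaf:
        weights = leaf.group("weights")
        parsed = [float(w) for w in weights.split(",")] if weights else None
        return Leaf(leaf.group("label"), parsed), pos + 1
    split = SPLIT_RE.match(body)
    if not split or split.group("op") != "<=":
        # Also reached for "truncated branch of depth N" rows.
        raise ValueError(f"cannot parse row {pos}: {body!r}")
    left, pos = _parse(rows, pos + 1, depth + 1)
    # The matching "> threshold" row follows the complete left subtree.
    other = SPLIT_RE.match(rows[pos][1]) if pos < len(rows) else None
    if (not other or other.group("op") != ">" or rows[pos][0] != depth
            or other.group("feature") != split.group("feature")):
        raise ValueError(f"missing '>' branch for {body!r}")
    right, pos = _parse(rows, pos + 1, depth + 1)
    return Split(split.group("feature"), float(split.group("threshold")), left, right), pos


def parse_export_text(text: str) -> Node:
    """Parse the whole output of export_text() into a tree of Split/Leaf nodes."""
    rows = _rows(text)
    tree, end = _parse(rows, 0, 1)
    if end != len(rows):
        raise ValueError("unexpected rows after the end of the tree")
    return tree


def _leaf_for(tree: Node, sample: Dict[str, float]) -> Leaf:
    node = tree
    while isinstance(node, Split):
        node = node.left if sample[node.feature] <= node.threshold else node.right
    return node


def predict(tree: Node, sample: Dict[str, float]) -> str:
    """Class label stored in the leaf that the sample reaches."""
    return _leaf_for(tree, sample).label


def predict_proba(tree: Node, sample: Dict[str, float]) -> List[float]:
    """Leaf weights normalized to sum to 1 (needs show_weights=True)."""
    leaf = _leaf_for(tree, sample)
    if leaf.weights is None:
        raise ValueError("export the tree with show_weights=True to get probabilities")
    total = sum(leaf.weights)
    return [w / total for w in leaf.weights]
```

With the Iris tree shown in the DecisionTreeClassifier section, a sample whose `petal width (cm)` is 2.0 takes the `> 0.80` and then the `> 1.75` branch, so `predict` returns `"2"`. That is the same leaf the Java example's sample reaches.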