https://github.com/maciejkula/rustlearn
Machine learning crate for Rust
- Host: GitHub
- URL: https://github.com/maciejkula/rustlearn
- Owner: maciejkula
- License: apache-2.0
- Created: 2015-12-03T21:48:17.000Z (about 9 years ago)
- Default Branch: master
- Last Pushed: 2021-06-07T09:09:59.000Z (over 3 years ago)
- Last Synced: 2025-01-04T19:04:38.209Z (9 days ago)
- Language: Rust
- Size: 10.3 MB
- Stars: 629
- Watchers: 23
- Forks: 55
- Open Issues: 13
Metadata Files:
- Readme: readme.md
- Changelog: changelog.md
- License: LICENSE
Awesome Lists containing this project
- awesome-rust-cn - maciejkula/rustlearn
- awesome-rust-zh - maciejkula/rustlearn - Machine learning crate for Rust. [![Circle CI](https://circleci.com/gh/maciejkula/rustlearn.svg?style=svg)](https://circleci.com/gh/maciejkula/rustlearn) (Libraries / Artificial Intelligence)
- awesome-rust - maciejkula/rustlearn - Machine learning library. [![Circle CI](https://circleci.com/gh/maciejkula/rustlearn.svg?style=svg)](https://app.circleci.com/pipelines/github/maciejkula/rustlearn) (Libraries / Artificial Intelligence)
- awesome-rust-zh-cn - maciejkula/rustlearn
- fucking-awesome-rust - maciejkula/rustlearn - Machine learning library. [![Circle CI](https://circleci.com/gh/maciejkula/rustlearn.svg?style=svg)](https://app.circleci.com/pipelines/github/maciejkula/rustlearn) (Libraries / Artificial Intelligence)
README
# rustlearn
[![Circle CI](https://circleci.com/gh/maciejkula/rustlearn.svg?style=svg)](https://circleci.com/gh/maciejkula/rustlearn)
[![Crates.io](https://img.shields.io/crates/v/rustlearn.svg)](https://crates.io/crates/rustlearn)

A machine learning package for Rust.
For full usage details, see the [API documentation](https://maciejkula.github.io/rustlearn/doc/rustlearn/).
## Introduction
This crate contains reasonably effective
implementations of a number of common machine learning algorithms.

At the moment, `rustlearn` uses its own basic dense and sparse array types, but I will be happy
to use something more robust once a clear winner in that space emerges.

## Features
### Matrix primitives
- [dense matrices](https://maciejkula.github.io/rustlearn/doc/rustlearn/array/dense/index.html)
- [sparse matrices](https://maciejkula.github.io/rustlearn/doc/rustlearn/array/sparse/index.html)

### Models
- [logistic regression](https://maciejkula.github.io/rustlearn/doc/rustlearn/linear_models/sgdclassifier/index.html) using stochastic gradient descent,
- [support vector machines](https://maciejkula.github.io/rustlearn/doc/rustlearn/svm/libsvm/svc/index.html) using the `libsvm` library,
- [decision trees](https://maciejkula.github.io/rustlearn/doc/rustlearn/trees/decision_tree/index.html) using the CART algorithm,
- [random forests](https://maciejkula.github.io/rustlearn/doc/rustlearn/ensemble/random_forest/index.html) using CART decision trees, and
- [factorization machines](https://maciejkula.github.io/rustlearn/doc/rustlearn/factorization/factorization_machines/index.html).

All the models support fitting and prediction on both dense and sparse data, and the implementations
should be roughly competitive with Python `sklearn` implementations, both in accuracy and performance.

### Cross-validation
- [k-fold cross-validation](https://maciejkula.github.io/rustlearn/doc/rustlearn/cross_validation/cross_validation/index.html)
- [shuffle split](https://maciejkula.github.io/rustlearn/doc/rustlearn/cross_validation/shuffle_split/index.html)

### Metrics
- [accuracy](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/fn.accuracy_score.html)
- [ROC AUC score](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.roc_auc_score.html)
- [DCG score](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.dcg_score.html)
- [NDCG score](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.ndcg_score.html)
- [mean absolute error](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.mean_absolute_error.html)
- [mean squared error](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.mean_squared_error.html)

### Parallelization
A number of models support both parallel model fitting and prediction.
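Prediction on one row does not depend on any other row, so it is naturally data-parallel. The sketch below illustrates one common way to exploit that using only the Rust standard library: split the rows into contiguous chunks and score each chunk on its own scoped thread. `LinearScorer` and `predict_parallel` are hypothetical names invented for this illustration; they are not rustlearn types, and this is not necessarily how rustlearn implements its parallel models (see the API documentation for which models support parallelism and how to call it).

```rust
use std::thread;

/// Hypothetical stand-in for a fitted model (NOT a rustlearn type).
struct LinearScorer {
    weights: Vec<f32>,
    bias: f32,
}

impl LinearScorer {
    /// Score one row as `bias + dot(weights, row)`.
    /// Rows are assumed to have the same length as `weights`.
    fn score_row(&self, row: &[f32]) -> f32 {
        self.bias + self.weights.iter().zip(row).map(|(w, x)| w * x).sum::<f32>()
    }

    /// Hypothetical parallel prediction: split the rows into at most
    /// `num_threads` contiguous chunks and score each chunk on a scoped thread.
    fn predict_parallel(&self, rows: &[Vec<f32>], num_threads: usize) -> Vec<f32> {
        let num_threads = num_threads.max(1);
        // Ceiling division so every row lands in some chunk; never zero.
        let chunk_size = ((rows.len() + num_threads - 1) / num_threads).max(1);
        thread::scope(|s| {
            let handles: Vec<_> = rows
                .chunks(chunk_size)
                .map(|chunk| {
                    s.spawn(move || chunk.iter().map(|row| self.score_row(row)).collect::<Vec<f32>>())
                })
                .collect();
            // Joining in spawn order keeps predictions aligned with the input rows.
            handles.into_iter().flat_map(|h| h.join().unwrap()).collect()
        })
    }
}

fn main() {
    let model = LinearScorer { weights: vec![1.0, -2.0], bias: 0.5 };
    let rows = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![2.0, 2.0], vec![3.0, 1.0]];
    let predictions = model.predict_parallel(&rows, 2);
    println!("{:?}", predictions); // prints [1.5, -1.5, -1.5, 1.5]
}
```

Scoped threads (`std::thread::scope`, stable since Rust 1.63) let the workers borrow the model and the rows directly, so no `Arc` or cloning is needed.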
### Model serialization
Model serialization is supported via `serde`.
## Using `rustlearn`
Usage should be straightforward.

- import the prelude for all the linear algebra primitives and common traits:
```rust
use rustlearn::prelude::*;
```

- import individual models and utilities from submodules:
```rust
use rustlearn::prelude::*;

use rustlearn::linear_models::sgdclassifier::Hyperparameters;
// more imports
```

## Examples
### Logistic regression
```rust
use rustlearn::prelude::*;

use rustlearn::cross_validation::CrossValidation;
use rustlearn::datasets::iris;
use rustlearn::linear_models::sgdclassifier::Hyperparameters;
use rustlearn::metrics::accuracy_score;

#[allow(non_snake_case)]
fn main() {
    let (X, y) = iris::load_data();

    let num_splits = 10;
    let num_epochs = 5;

    let mut accuracy = 0.0;

    // Each iteration yields the row indices of one train/test split.
    for (train_idx, test_idx) in CrossValidation::new(X.rows(), num_splits) {
        let X_train = X.get_rows(&train_idx);
        let y_train = y.get_rows(&train_idx);
        let X_test = X.get_rows(&test_idx);
        let y_test = y.get_rows(&test_idx);

        // Logistic regression trained with SGD, wrapped for
        // one-vs-rest multiclass classification.
        let mut model = Hyperparameters::new(X.cols())
            .learning_rate(0.5)
            .l2_penalty(0.0)
            .l1_penalty(0.0)
            .one_vs_rest();

        // Calling `fit` repeatedly performs additional passes (epochs) of SGD.
        for _ in 0..num_epochs {
            model.fit(&X_train, &y_train).unwrap();
        }

        let prediction = model.predict(&X_test).unwrap();
        accuracy += accuracy_score(&y_test, &prediction);
    }

    // Average the accuracy over all folds.
    accuracy /= num_splits as f32;
    println!("Mean accuracy: {}", accuracy);
}
```
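The example above averages `accuracy_score` over the train/test index splits yielded by `CrossValidation`. To make those two steps concrete, here is a standard-library-only sketch of k-fold index splitting and classification accuracy. `k_fold_indices` and `accuracy` are hypothetical helpers written for illustration, not rustlearn functions; the sketch assigns contiguous, unshuffled folds, which may differ from how rustlearn's `CrossValidation` orders samples. Labels are stored as `f32` in this sketch.

```rust
/// Hypothetical helper: split `n_samples` indices into `n_folds` contiguous
/// folds and return one (train, test) index pair per fold. No shuffling.
fn k_fold_indices(n_samples: usize, n_folds: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    assert!(n_folds >= 2 && n_folds <= n_samples, "need 2 <= n_folds <= n_samples");
    let base = n_samples / n_folds;
    let remainder = n_samples % n_folds;
    let mut splits = Vec::with_capacity(n_folds);
    let mut start = 0;
    for fold in 0..n_folds {
        // The first `remainder` folds each take one extra sample.
        let size = base + if fold < remainder { 1 } else { 0 };
        let end = start + size;
        let test: Vec<usize> = (start..end).collect();
        // Training indices are everything outside the test fold.
        let train: Vec<usize> = (0..start).chain(end..n_samples).collect();
        splits.push((train, test));
        start = end;
    }
    splits
}

/// Hypothetical helper: fraction of positions where the predicted label
/// equals the true label (labels are integral values stored as f32).
fn accuracy(y_true: &[f32], y_pred: &[f32]) -> f32 {
    assert_eq!(y_true.len(), y_pred.len());
    assert!(!y_true.is_empty(), "accuracy of an empty set is undefined");
    let correct = y_true.iter().zip(y_pred).filter(|&(t, p)| t == p).count();
    correct as f32 / y_true.len() as f32
}

fn main() {
    for (train, test) in k_fold_indices(10, 3) {
        println!("train={:?} test={:?}", train, test);
    }
    let y_true = [0.0, 1.0, 2.0, 1.0];
    let y_pred = [0.0, 1.0, 1.0, 1.0];
    println!("accuracy = {}", accuracy(&y_true, &y_pred)); // prints "accuracy = 0.75"
}
```

With 10 samples and 3 folds, the first fold receives the one leftover sample, so the test folds are `[0, 1, 2, 3]`, `[4, 5, 6]`, and `[7, 8, 9]`; every sample is used for testing exactly once.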
### Random forest
```rust
use rustlearn::prelude::*;

use rustlearn::datasets::iris;
use rustlearn::ensemble::random_forest::Hyperparameters;
use rustlearn::trees::decision_tree;

fn main() {
    let (data, target) = iris::load_data();

    // Hyperparameters shared by every tree in the forest.
    let mut tree_params = decision_tree::Hyperparameters::new(data.cols());
    tree_params.min_samples_split(10).max_features(4);

    // A forest of 10 trees, wrapped for one-vs-rest multiclass classification.
    let mut model = Hyperparameters::new(tree_params, 10).one_vs_rest();
    model.fit(&data, &target).unwrap();

    // Optionally serialize and deserialize the model (needs the `bincode` crate,
    // plus `OneVsRestWrapper` from `rustlearn::multiclass` and `RandomForest`
    // from `rustlearn::ensemble::random_forest` in scope):
    // let encoded = bincode::serialize(&model).unwrap();
    // let decoded: OneVsRestWrapper<RandomForest> = bincode::deserialize(&encoded).unwrap();

    let prediction = model.predict(&data).unwrap();
}
```

## Contributing
Pull requests are welcome.

To run basic tests, run `cargo test`.
Running `cargo test --features "all_tests" --release` runs all tests, including generated and slow tests.
Running `cargo bench --features bench` runs benchmarks (this requires a nightly Rust toolchain).