{"id":13438330,"url":"https://github.com/maciejkula/rustlearn","last_synced_at":"2025-04-08T10:20:01.112Z","repository":{"id":49974192,"uuid":"47362929","full_name":"maciejkula/rustlearn","owner":"maciejkula","description":"Machine learning crate for Rust","archived":false,"fork":false,"pushed_at":"2021-06-07T09:09:59.000Z","size":10825,"stargazers_count":632,"open_issues_count":13,"forks_count":54,"subscribers_count":22,"default_branch":"master","last_synced_at":"2025-04-01T08:42:55.183Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Rust","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/maciejkula.png","metadata":{"files":{"readme":"readme.md","changelog":"changelog.md","contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2015-12-03T21:48:17.000Z","updated_at":"2025-03-20T13:23:10.000Z","dependencies_parsed_at":"2022-07-29T23:39:29.062Z","dependency_job_id":null,"html_url":"https://github.com/maciejkula/rustlearn","commit_stats":null,"previous_names":[],"tags_count":6,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/maciejkula%2Frustlearn","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/maciejkula%2Frustlearn/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/maciejkula%2Frustlearn/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/maciejkula%2Frustlearn/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/maciejkula","download_url":"https://codeload.github.com/maciejkula/rustlearn/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github"
,"repositories_count":247819940,"owners_count":21001394,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-07-31T03:01:04.608Z","updated_at":"2025-04-08T10:20:01.088Z","avatar_url":"https://github.com/maciejkula.png","language":"Rust","funding_links":[],"categories":["Libraries","库 Libraries","库","Rust","人工智能（Artificial Intelligence）","Frameworks"],"sub_categories":["Artificial Intelligence","Machine learning","人工智能 Artificial Intelligence","人工智能","General-Purpose Machine Learning","机器学习（Machine Learning）"],"readme":"# rustlearn\n\n[![Circle CI](https://circleci.com/gh/maciejkula/rustlearn.svg?style=svg)](https://circleci.com/gh/maciejkula/rustlearn)\n[![Crates.io](https://img.shields.io/crates/v/rustlearn.svg)](https://crates.io/crates/rustlearn)\n\nA machine learning package for Rust.\n\nFor full usage details, see the [API documentation](https://maciejkula.github.io/rustlearn/doc/rustlearn/).\n\n## Introduction\n\nThis crate contains reasonably effective\nimplementations of a number of common machine learning algorithms.\n\nAt the moment, `rustlearn` uses its own basic dense and sparse array types, but I will be happy\nto use something more robust once a clear winner in that space emerges.\n\n## Features\n\n### Matrix primitives\n\n- [dense matrices](https://maciejkula.github.io/rustlearn/doc/rustlearn/array/dense/index.html)\n- [sparse matrices](https://maciejkula.github.io/rustlearn/doc/rustlearn/array/sparse/index.html)\n\n### Models\n\n- [logistic 
regression](https://maciejkula.github.io/rustlearn/doc/rustlearn/linear_models/sgdclassifier/index.html) using stochastic gradient descent,\n- [support vector machines](https://maciejkula.github.io/rustlearn/doc/rustlearn/svm/libsvm/svc/index.html) using the `libsvm` library,\n- [decision trees](https://maciejkula.github.io/rustlearn/doc/rustlearn/trees/decision_tree/index.html) using the CART algorithm,\n- [random forests](https://maciejkula.github.io/rustlearn/doc/rustlearn/ensemble/random_forest/index.html) using CART decision trees, and\n- [factorization machines](https://maciejkula.github.io/rustlearn/doc/rustlearn/factorization/factorization_machines/index.html).\n\nAll the models support fitting and prediction on both dense and sparse data, and the implementations\nshould be roughly competitive with Python `sklearn` implementations, both in accuracy and performance.\n\n## Cross-validation\n\n- [k-fold cross-validation](https://maciejkula.github.io/rustlearn/doc/rustlearn/cross_validation/cross_validation/index.html)\n- [shuffle split](https://maciejkula.github.io/rustlearn/doc/rustlearn/cross_validation/shuffle_split/index.html)\n\n## Metrics\n\n- [accuracy](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/fn.accuracy_score.html)\n- [ROC AUC score](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.roc_auc_score.html)\n- [dcg_score](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.dcg_score.html)\n- [ndcg_score](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.ndcg_score.html)\n- [mean absolute error](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.mean_absolute_error.html)\n- [mean squared error](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.mean_squared_error.html)\n\n## Parallelization\n\nA number of models support both parallel model fitting and prediction.\n\n### Model serialization\n\nModel serialization is supported via 
`serde`.\n\n## Using `rustlearn`\nUsage should be straightforward.\n\n- import the prelude for all the linear algebra primitives and common traits:\n\n```rust\nuse rustlearn::prelude::*;\n```\n\n- import individual models and utilities from submodules:\n\n```rust\nuse rustlearn::prelude::*;\n\nuse rustlearn::linear_models::sgdclassifier::Hyperparameters;\n// more imports\n```\n\n## Examples\n\n### Logistic regression\n\n```rust\nuse rustlearn::prelude::*;\nuse rustlearn::datasets::iris;\nuse rustlearn::cross_validation::CrossValidation;\nuse rustlearn::linear_models::sgdclassifier::Hyperparameters;\nuse rustlearn::metrics::accuracy_score;\n\n\nlet (X, y) = iris::load_data();\n\nlet num_splits = 10;\nlet num_epochs = 5;\n\nlet mut accuracy = 0.0;\n\nfor (train_idx, test_idx) in CrossValidation::new(X.rows(), num_splits) {\n\n    let X_train = X.get_rows(\u0026train_idx);\n    let y_train = y.get_rows(\u0026train_idx);\n    let X_test = X.get_rows(\u0026test_idx);\n    let y_test = y.get_rows(\u0026test_idx);\n\n    let mut model = Hyperparameters::new(X.cols())\n                                    .learning_rate(0.5)\n                                    .l2_penalty(0.0)\n                                    .l1_penalty(0.0)\n                                    .one_vs_rest();\n\n    for _ in 0..num_epochs {\n        model.fit(\u0026X_train, \u0026y_train).unwrap();\n    }\n\n    let prediction = model.predict(\u0026X_test).unwrap();\n    accuracy += accuracy_score(\u0026y_test, \u0026prediction);\n}\n\naccuracy /= num_splits as f32;\n\n```\n\n### Random forest\n\n```rust\nuse rustlearn::prelude::*;\n\nuse rustlearn::ensemble::random_forest::Hyperparameters;\nuse rustlearn::datasets::iris;\nuse rustlearn::trees::decision_tree;\n\nlet (data, target) = iris::load_data();\n\nlet mut tree_params = decision_tree::Hyperparameters::new(data.cols());\ntree_params.min_samples_split(10)\n    .max_features(4);\n\nlet mut model = Hyperparameters::new(tree_params, 10)\n    
.one_vs_rest();\n\nmodel.fit(\u0026data, \u0026target).unwrap();\n\n// Optionally serialize and deserialize the model\n\n// let encoded = bincode::serialize(\u0026model).unwrap();\n// let decoded: OneVsRestWrapper\u003cRandomForest\u003e = bincode::deserialize(\u0026encoded).unwrap();\n\nlet prediction = model.predict(\u0026data).unwrap();\n```\n\n## Contributing\nPull requests are welcome.\n\nTo run basic tests, run `cargo test`.\n\nRunning `cargo test --features \"all_tests\" --release` runs all tests, including generated and slow tests.\nRunning `cargo bench --features bench` (requires the nightly Rust toolchain) runs benchmarks.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmaciejkula%2Frustlearn","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fmaciejkula%2Frustlearn","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmaciejkula%2Frustlearn/lists"}