{"id":19530186,"url":"https://github.com/lazarust/sklx","last_synced_at":"2025-05-07T16:25:27.585Z","repository":{"id":260201048,"uuid":"820969972","full_name":"lazarust/sklx","owner":"lazarust","description":"A scikit-learn compatible neural network library that wraps MLX. ","archived":false,"fork":false,"pushed_at":"2025-04-07T00:07:38.000Z","size":306,"stargazers_count":9,"open_issues_count":6,"forks_count":0,"subscribers_count":3,"default_branch":"main","last_synced_at":"2025-04-07T00:24:30.949Z","etag":null,"topics":["mlx","scikit-learn"],"latest_commit_sha":null,"homepage":"https://sklx.readthedocs.io/","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"bsd-3-clause","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/lazarust.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":"docs/roadmap.md","authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-06-27T14:41:00.000Z","updated_at":"2025-04-07T00:07:41.000Z","dependencies_parsed_at":null,"dependency_job_id":"aec0bad9-03b5-4560-9502-82efa0ad18ca","html_url":"https://github.com/lazarust/sklx","commit_stats":null,"previous_names":["lazarust/sklx"],"tags_count":2,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lazarust%2Fsklx","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lazarust%2Fsklx/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lazarust%2Fsklx/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lazarust%2Fsklx/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/lazarust","download_url":"https://codeload.github.com/lazarust/sklx/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":252914102,"owners_count":21824312,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["mlx","scikit-learn"],"created_at":"2024-11-11T01:29:32.371Z","updated_at":"2025-05-07T16:25:27.576Z","avatar_url":"https://github.com/lazarust.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# SKLX\n[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)\n[![Pytest](https://github.com/lazarust/sklx/actions/workflows/pytest.yml/badge.svg)](https://github.com/lazarust/sklx/actions/workflows/pytest.yml)\n[![image](https://img.shields.io/pypi/v/sklx.svg)](https://pypi.org/project/sklx/)\n\nA scikit-learn compatible neural network library that wraps MLX.\nHeavily inspired by [skorch](https://github.com/skorch-dev/skorch).\n\n## Examples\n\n```python\nimport numpy as np\nfrom sklearn.datasets import make_classification\nfrom mlx import nn\nfrom sklx import NeuralNetClassifier\n\nX, y = make_classification(1000, 20, n_informative=10, random_state=0)\nX = X.astype(np.float32)\ny = y.astype(np.int64)\n\nclass MyModule(nn.Module):\n    def __init__(self, num_units=10, nonlin=nn.ReLU()):\n        super().__init__()\n\n        self.dense0 = nn.Linear(20, num_units)\n        self.nonlin = nonlin\n        self.dropout = nn.Dropout(0.5)\n        self.dense1 = nn.Linear(num_units, num_units)\n        self.output = nn.Linear(num_units, 2)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, X, **kwargs):\n        X = self.nonlin(self.dense0(X))\n        X = self.dropout(X)\n        X = self.nonlin(self.dense1(X))\n        X = self.softmax(self.output(X))\n        return X\n\nnet = NeuralNetClassifier(\n    MyModule,\n    max_epochs=10,\n    lr=0.1,\n)\n\nnet.fit(X, y)\ny_proba = net.predict_proba(X)\n```\n\nIn an sklearn Pipeline:\n\n```python\nfrom sklearn.pipeline import Pipeline\nfrom sklearn.preprocessing import StandardScaler\n\npipe = Pipeline([\n    ('scale', StandardScaler()),\n    ('net', net),\n])\n\npipe.fit(X, y)\ny_proba = pipe.predict_proba(X)\n```\n\nWith grid search:\n\n```python\nfrom sklearn.model_selection import GridSearchCV\n\nparams = {\n    'lr': [0.01, 0.02],\n    'max_epochs': [10, 20],\n    'module__num_units': [10, 20],\n}\ngs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)\n\ngs.fit(X, y)\nprint(\"best score: {:.3f}, best params: {}\".format(gs.best_score_, gs.best_params_))\n```\n\n## Future Roadmap\n\n1. Complete feature parity with [skorch](https://github.com/skorch-dev/skorch)\n   1. ~Pipeline Support~\n   2. ~Grid Search Support~\n   3. Learning Rate Scheduler https://github.com/lazarust/sklx/issues/6\n   4. Scoring https://github.com/lazarust/sklx/issues/7\n   5. Early Stopping https://github.com/lazarust/sklx/issues/8\n   6. Checkpointing https://github.com/lazarust/sklx/issues/9\n   7. Parameter Freezing https://github.com/lazarust/sklx/issues/10\n   8. Progress Bar https://github.com/lazarust/sklx/issues/11\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flazarust%2Fsklx","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Flazarust%2Fsklx","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flazarust%2Fsklx/lists"}