https://github.com/lazarust/sklx
A scikit-learn compatible neural network library that wraps MLX.
- Host: GitHub
- URL: https://github.com/lazarust/sklx
- Owner: lazarust
- License: bsd-3-clause
- Created: 2024-06-27T14:41:00.000Z (10 months ago)
- Default Branch: main
- Last Pushed: 2025-04-07T00:07:38.000Z (about 1 month ago)
- Last Synced: 2025-04-07T00:24:30.949Z (about 1 month ago)
- Topics: mlx, scikit-learn
- Language: Python
- Homepage: https://sklx.readthedocs.io/
- Size: 299 KB
- Stars: 9
- Watchers: 3
- Forks: 0
- Open Issues: 6
Metadata Files:
- Readme: README.md
- License: LICENSE
- Roadmap: docs/roadmap.md
README
# SKLX
[Ruff](https://github.com/astral-sh/ruff)
[Tests](https://github.com/lazarust/sklx/actions/workflows/pytest.yml)
[PyPI](https://pypi.org/project/sklx/)

A scikit-learn compatible neural network library that wraps MLX.
Highly inspired by [skorch](https://github.com/skorch-dev/skorch).

## Examples
```python
import numpy as np
from sklearn.datasets import make_classification
from mlx import nn

from sklx import NeuralNetClassifier

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)


class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super().__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.softmax(self.output(X))
        return X


net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
)

net.fit(X, y)
y_proba = net.predict_proba(X)
```
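Because the wrapper follows the scikit-learn estimator API, the usual evaluation helpers work on top of it. A minimal sketch (not part of the original README) that assumes `net.predict` returns hard class labels, as it does in skorch:

```python
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Hold out a test split and score the classifier on unseen data.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

net.fit(X_train, y_train)
y_pred = net.predict(X_test)  # assumed: hard class labels, as in skorch
print("test accuracy: {:.3f}".format(accuracy_score(y_test, y_pred)))
```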
In an sklearn Pipeline:

```python
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipe = Pipeline([
    ('scale', StandardScaler()),
    ('net', net),
])

pipe.fit(X, y)
y_proba = pipe.predict_proba(X)
```

With grid search:

```python
from sklearn.model_selection import GridSearchCV

params = {
    'lr': [0.01, 0.02],
    'max_epochs': [10, 20],
    'module__num_units': [10, 20],
}
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)

gs.fit(X, y)
print("best score: {:.3f}, best params: {}".format(gs.best_score_, gs.best_params_))
```
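The two patterns compose as well: scikit-learn routes grid-search parameters into pipeline steps by prefixing them with the step name. A minimal sketch (not from the original README) reusing `pipe` from above and assuming the same skorch-style `module__` routing:

```python
from sklearn.model_selection import GridSearchCV

# Parameters are addressed through the pipeline step name ('net'),
# then through the wrapper down to the module, skorch-style.
pipe_params = {
    'net__lr': [0.01, 0.02],
    'net__max_epochs': [10, 20],
    'net__module__num_units': [10, 20],
}
gs_pipe = GridSearchCV(pipe, pipe_params, refit=False, cv=3, scoring='accuracy')
gs_pipe.fit(X, y)
print("best score: {:.3f}, best params: {}".format(gs_pipe.best_score_, gs_pipe.best_params_))
```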
## Future Roadmap

1. Completing Feature Parity with [Skorch](https://github.com/skorch-dev/skorch)
   1. ~~Pipeline Support~~
   2. ~~Grid Search Support~~
   3. Learning Rate Scheduler https://github.com/lazarust/sklx/issues/6
   4. Scoring https://github.com/lazarust/sklx/issues/7 (a plain scikit-learn stop-gap is sketched below)
   5. Early Stopping https://github.com/lazarust/sklx/issues/8
   6. Checkpointing https://github.com/lazarust/sklx/issues/9
   7. Parameter Freezing https://github.com/lazarust/sklx/issues/10
   8. Progress Bar https://github.com/lazarust/sklx/issues/11
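Until the built-in scoring callback (issue #7) lands, plain scikit-learn scoring already works against the wrapper, since it only needs the estimator to be clone-able and sklearn-compatible, which is the same requirement `GridSearchCV` above already imposes. A small sketch using `cross_val_score` with the `net` defined earlier:

```python
from sklearn.model_selection import cross_val_score

# Cross-validated accuracy via scikit-learn's own scoring machinery.
scores = cross_val_score(net, X, y, cv=3, scoring='accuracy')
print("mean CV accuracy: {:.3f}".format(scores.mean()))
```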