https://github.com/stephantul/torchic
Simple linear thing in Torch, with a scikit-learn compatible API.
- Host: GitHub
- URL: https://github.com/stephantul/torchic
- Owner: stephantul
- License: mit
- Created: 2023-08-16T13:12:12.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2023-08-24T10:47:10.000Z (over 1 year ago)
- Last Synced: 2023-08-24T12:32:28.467Z (over 1 year ago)
- Language: Python
- Size: 102 KB
- Stars: 1
- Watchers: 3
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# torchic
Simple model training in PyTorch, with a scikit-learn compatible API.
It has the following features:
* scikit-learn-like API (i.e., using `fit` and `predict`)
* Supports NumPy arrays and torch tensors out of the box
* Automatically converts your tensors between devices

## Example
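To make the listed features concrete, here is a minimal sketch of the general pattern: a linear layer wrapped in `fit` and `predict`, accepting NumPy arrays or torch tensors and moving them to the model's device. This is not torchic's implementation. `TinyLinearClassifier`, its parameters, and the toy data are hypothetical names introduced only for illustration.

```python
import numpy as np
import torch
from torch import nn


class TinyLinearClassifier(nn.Module):
    """Hypothetical illustration of a fit/predict wrapper; not torchic's code."""

    def __init__(self, n_features: int, n_labels: int, learning_rate: float = 1e-2) -> None:
        super().__init__()
        self.linear = nn.Linear(n_features, n_labels)
        self.learning_rate = learning_rate

    def _to_tensor(self, data, dtype: torch.dtype) -> torch.Tensor:
        # Accept NumPy arrays or torch tensors and move them to the model's device.
        device = self.linear.weight.device
        return torch.as_tensor(data, dtype=dtype).to(device)

    def fit(self, X, y, batch_size: int = 32, epochs: int = 20) -> "TinyLinearClassifier":
        X_t = self._to_tensor(X, torch.float32)
        y_t = self._to_tensor(y, torch.long)
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        loss_fn = nn.CrossEntropyLoss()
        self.train()
        for _ in range(epochs):
            # Shuffle once per epoch, then iterate over mini-batches.
            permutation = torch.randperm(len(X_t), device=X_t.device)
            for start in range(0, len(X_t), batch_size):
                idx = permutation[start:start + batch_size]
                optimizer.zero_grad()
                loss = loss_fn(self.linear(X_t[idx]), y_t[idx])
                loss.backward()
                optimizer.step()
        return self

    @torch.no_grad()
    def predict(self, X) -> np.ndarray:
        self.eval()
        logits = self.linear(self._to_tensor(X, torch.float32))
        # Return a NumPy array of class indices, as a scikit-learn estimator would.
        return logits.argmax(dim=1).cpu().numpy()


# Toy data (an assumption for this sketch): the label depends only on the first feature.
torch.manual_seed(0)
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(200, 5)).astype(np.float32)
y_toy = (X_toy[:, 0] > 0).astype(np.int64)

model = TinyLinearClassifier(n_features=5, n_labels=2).to("cpu")
pred_from_numpy = model.fit(X_toy, y_toy).predict(X_toy)
pred_from_tensor = model.predict(torch.from_numpy(X_toy))
print(pred_from_numpy.shape, (pred_from_numpy == y_toy).mean())
```

Converting inputs in one helper keeps `fit` and `predict` agnostic about whether the caller passed an array or a tensor, and about which device the model currently lives on.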
The example below classifies documents from the 20 newsgroups dataset, which scikit-learn provides pre-vectorized with a CountVectorizer. Only the first 10 of the 20 classes are kept. This example requires scikit-learn to be installed.
```python
import numpy as np
from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support

from torchic import Torchic

# NOTE: change this to 'cuda' or 'mps' if you want acceleration.
DEVICE = "cpu"

X, y = fetch_20newsgroups_vectorized(return_X_y=True, remove=("headers", "footers"), subset="train")
X = X[y < 10]
y = y[y < 10]
X = np.asarray(X.todense())

X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, random_state=44, test_size=.1)
n_features, n_labels = X_train.shape[1], len(set(y))

# Torchic stuff begins here.
t = Torchic(n_features, n_labels, learning_rate=1e-4).to(DEVICE)
t.fit(X_train, y_train, batch_size=128)

pred = t.predict(X_test)
print(precision_recall_fscore_support(y_test, pred, average="macro"))
```

## TODO:
* Add docstrings
* Add additional unit tests