Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/tcapelle/pytorch_mnl
A multinomial logistic regression package based on PyTorch.
- Host: GitHub
- URL: https://github.com/tcapelle/pytorch_mnl
- Owner: tcapelle
- License: apache-2.0
- Created: 2021-09-16T13:50:21.000Z (about 3 years ago)
- Default Branch: master
- Last Pushed: 2021-09-20T19:52:28.000Z (about 3 years ago)
- Last Synced: 2024-10-06T09:21:23.240Z (about 1 month ago)
- Language: Jupyter Notebook
- Size: 229 KB
- Stars: 0
- Watchers: 2
- Forks: 0
- Open Issues: 1
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
Awesome Lists containing this project
README
# A simple PyTorch-based MNL lib
> Fit your Multinomial Logistic Regression with PyTorch
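For background (general statistics, not taken from this repository's code): multinomial logistic regression scores each of the K classes with a linear function of the inputs and turns those scores into probabilities with a softmax.

```latex
% Standard multinomial logistic regression (softmax) model, background only,
% not copied from the pytorch_mnl source. For input x, weight vector w_k and bias b_k:
P(y = k \mid x) = \frac{\exp(w_k^\top x + b_k)}{\sum_{j=1}^{K} \exp(w_j^\top x + b_j)},
\qquad k = 1, \dots, K
```

Fitting the model amounts to minimising the cross-entropy of these probabilities over the training data.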
## Install

`pip install pytorch_mnl`
## How to use
Import the library:
```python
import pandas as pd
from pytorch_mnl.core import *
```

Load the data:
```python
data = pd.read_csv("./data/Iris.csv").drop("Id", axis=1)
```
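The snippet above assumes a local `./data/Iris.csv` in the Kaggle export format (with an `Id` column). If that file is not at hand, a roughly equivalent `DataFrame` can be built from scikit-learn; the column renaming below is only an assumption made to match the names used in this walkthrough, and the integer `Species` codes stand in for the original string labels:

```python
# Optional alternative data source (assumes scikit-learn is installed).
from sklearn.datasets import load_iris

iris = load_iris(as_frame=True)
data = iris.frame.rename(columns={
    "sepal length (cm)": "SepalLengthCm",
    "sepal width (cm)": "SepalWidthCm",
    "petal length (cm)": "PetalLengthCm",
    "petal width (cm)": "PetalWidthCm",
    "target": "Species",  # integer class codes instead of species names
})
```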
Choose the feature columns and the target column:

```python
x_cols=['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
target_col = 'Species'
```

The number of classes to predict:
```python
n_targets = len(data[target_col].unique())
n_targets
```

```
3
```
```python
X, y = prepare_data(data, x_cols=x_cols, target_col=target_col)
```

We get PyTorch tensors, ready to use!
```python
type(X), type(y)
```

```
(torch.Tensor, torch.Tensor)
```
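The internals of `prepare_data` are not shown in this README. As a rough mental model only (an illustrative sketch, not the pytorch_mnl implementation), it presumably selects the feature columns as a float tensor and encodes the target labels as integer class indices:

```python
# Illustrative sketch; NOT the actual pytorch_mnl code.
import torch

def prepare_data_sketch(df, x_cols, target_col):
    # Features as a float32 tensor of shape (n_samples, n_features).
    X = torch.tensor(df[x_cols].values, dtype=torch.float32)
    # String labels (e.g. species names) encoded as integer class indices.
    y = torch.tensor(df[target_col].astype("category").cat.codes.values,
                     dtype=torch.long)
    return X, y
```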
Let's split into train/validation sets, choosing a percentage as holdout, and pick a batch size to fit our model:
```python
dls = DataLoaders.from_Xy(X, y, pct=0.2, batch_size=8)
```
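`DataLoaders.from_Xy` is the package's own convenience wrapper. For orientation, a roughly equivalent split and batching with plain PyTorch utilities would look like this (illustrative only, not the wrapper's actual implementation):

```python
# Rough equivalent of a 20% holdout split with batch size 8 (assumption).
from torch.utils.data import DataLoader, TensorDataset, random_split

ds = TensorDataset(X, y)
n_valid = int(0.2 * len(ds))
train_ds, valid_ds = random_split(ds, [len(ds) - n_valid, n_valid])
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=8)
```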
As our data has 4 features and 3 classes, we will fit an MNL with 4 inputs and 3 targets:

```python
model = LinearMNL(len(x_cols), n_targets)
```
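`LinearMNL`'s definition is not shown here either. A plausible reading (an assumption, not the actual class) is a single linear layer mapping the 4 features to 3 class scores, with the softmax folded into the loss:

```python
# Guess at a LinearMNL-style model; NOT the actual pytorch_mnl class.
import torch.nn as nn

class LinearMNLSketch(nn.Module):
    def __init__(self, n_features, n_targets):
        super().__init__()
        self.lin = nn.Linear(n_features, n_targets)

    def forward(self, x):
        # Raw class scores (logits); softmax/cross-entropy handles the rest.
        return self.lin(x)
```

Continuing the walkthrough, wrap the data loaders and model in a `Learner`: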
```python
learn = Learner(dls, model)
```
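`Learner.fit` hides the optimisation loop. Conceptually, each epoch runs a loop along the following lines (a sketch under assumptions about the defaults, not the library's code):

```python
# Conceptual training loop behind a Learner-style fit; illustrative only.
# The optimiser choice and learning rate are assumptions, not pytorch_mnl defaults.
import torch
import torch.nn.functional as F

def fit_sketch(model, train_dl, valid_dl, epochs, lr=0.1):
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            loss = F.cross_entropy(model(xb), yb)
            loss.backward()
            opt.step()
            opt.zero_grad()
        model.eval()
        with torch.no_grad():
            n = len(valid_dl.dataset)
            val_loss = sum(F.cross_entropy(model(xb), yb, reduction="sum")
                           for xb, yb in valid_dl) / n
            acc = sum((model(xb).argmax(dim=-1) == yb).sum()
                      for xb, yb in valid_dl) / n
        print(f"epoch = {epoch}, val_loss = {val_loss:.3f}, accuracy = {acc:.2f}")
```

The actual call in the README is simply: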
```python
learn.fit(25)
```

```
epoch = 0, val_loss = 2.072, accuracy = 0.53
epoch = 1, val_loss = 1.908, accuracy = 0.53
epoch = 2, val_loss = 1.770, accuracy = 0.80
epoch = 3, val_loss = 1.657, accuracy = 0.80
epoch = 4, val_loss = 1.564, accuracy = 0.80
epoch = 5, val_loss = 1.487, accuracy = 0.80
epoch = 6, val_loss = 1.422, accuracy = 0.80
epoch = 7, val_loss = 1.368, accuracy = 0.80
epoch = 8, val_loss = 1.321, accuracy = 0.80
epoch = 9, val_loss = 1.282, accuracy = 0.83
epoch = 10, val_loss = 1.247, accuracy = 0.83
epoch = 11, val_loss = 1.217, accuracy = 0.83
epoch = 12, val_loss = 1.190, accuracy = 0.83
epoch = 13, val_loss = 1.166, accuracy = 0.83
epoch = 14, val_loss = 1.144, accuracy = 0.87
epoch = 15, val_loss = 1.125, accuracy = 0.87
epoch = 16, val_loss = 1.107, accuracy = 0.90
epoch = 17, val_loss = 1.091, accuracy = 0.90
epoch = 18, val_loss = 1.076, accuracy = 0.90
epoch = 19, val_loss = 1.063, accuracy = 0.90
epoch = 20, val_loss = 1.050, accuracy = 0.90
epoch = 21, val_loss = 1.038, accuracy = 0.90
epoch = 22, val_loss = 1.027, accuracy = 0.90
epoch = 23, val_loss = 1.016, accuracy = 0.90
epoch = 24, val_loss = 1.007, accuracy = 0.90
```
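The README stops at training. Assuming the fitted `model` behaves like a standard `torch.nn.Module` (an assumption; the README does not document a prediction API), class predictions can be read off the argmax of its outputs:

```python
# Illustrative only: assumes model is callable like a torch.nn.Module.
import torch

with torch.no_grad():
    preds = model(X).argmax(dim=-1)

accuracy = (preds == y).float().mean()
print(accuracy)
```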