An open API service indexing awesome lists of open source software.

https://github.com/thesoenke/pytorch-trainer

Lightweight PyTorch trainer
https://github.com/thesoenke/pytorch-trainer

pytorch trainer

Last synced: about 2 months ago
JSON representation

Lightweight PyTorch trainer

Awesome Lists containing this project

README

          

# PyTorch Trainer

A lightweight wrapper around PyTorch that removes boilerplate code so you can focus on the important parts.

## Example
```python
import os

import torch
import torchvision.transforms as transforms
from module import Module
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pytorch_trainer import EarlyStopping, ModelCheckpoint, Module, Trainer

class MNISTModel(Module):
    """Single-layer classifier used as the MNIST example for the trainer.

    The class bundles the network, the per-batch training/validation steps,
    the optimizer, and the data loaders in one place.
    """

    def __init__(self):
        """Create the single linear layer (28 * 28 inputs -> 10 outputs)."""
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Flatten each sample in the batch, apply the linear layer, then ReLU."""
        # view(x.size(0), -1) keeps the batch dimension and flattens the rest.
        # NOTE(review): ReLU is applied to the values that are later fed to
        # F.cross_entropy -- confirm that clamping the outputs is intended.
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_num):
        """Compute the cross-entropy loss for one training batch.

        Args:
            batch: An ``(inputs, targets)`` pair.
            batch_num: Index of the batch; unused here.

        Returns:
            A dict with the loss tensor under the key ``'loss'``.
        """
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        return {'loss': loss}

    def validation_step(self, batch, batch_num):
        """Compute the cross-entropy loss for one validation batch.

        Args:
            batch: An ``(inputs, targets)`` pair.
            batch_num: Index of the batch; unused here.

        Returns:
            A dict with the loss tensor under the key ``'val_loss'``.
        """
        x, y = batch
        output = self.forward(x)
        return {'val_loss': F.cross_entropy(output, y)}

    def validation_end(self, outputs):
        """Average the per-batch validation losses.

        Args:
            outputs: Iterable of dicts, each holding a ``'val_loss'`` tensor
                (the shape returned by ``validation_step``).

        Returns:
            A dict with the mean loss under the key ``'val_loss'``.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        """Return an Adam optimizer over all model parameters (lr=0.02)."""
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        """Return a loader over the MNIST training split, batch size 32.

        The dataset is stored in (and downloaded to) the current working
        directory.
        """
        return DataLoader(
            MNIST(
                os.getcwd(),
                train=True,
                download=True,
                transform=transforms.ToTensor(),
            ),
            batch_size=32,
        )

    def val_dataloader(self):
        """Return a loader over the MNIST test split, batch size 32.

        The dataset is stored in (and downloaded to) the current working
        directory.
        """
        return DataLoader(
            MNIST(
                os.getcwd(),
                train=False,
                download=True,
                transform=transforms.ToTensor(),
            ),
            batch_size=32,
        )

# Checkpoint callback: writes to ./checkpoints, monitors 'val_loss' in 'min'
# mode with save_best_only enabled.
checkpoint_callback = ModelCheckpoint(
    directory='./checkpoints',
    monitor='val_loss',
    save_best_only=True,
    mode='min',
)

# Early-stopping callback: monitors 'val_loss' in 'min' mode with a patience
# of 5 and no minimum delta.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.0,
    patience=5,
    mode='min',
)

# Build the model, hand both callbacks to the trainer, and start fitting.
model = MNISTModel()
trainer = Trainer(
    checkpoint_callback=checkpoint_callback,
    early_stop_callback=early_stop_callback,
)
trainer.fit(model)
```

Inspired by [PyTorch Lightning](https://github.com/williamFalcon/pytorch-lightning)