https://github.com/thesoenke/pytorch-trainer
Lightweight PyTorch trainer
- Host: GitHub
- URL: https://github.com/thesoenke/pytorch-trainer
- Owner: theSoenke
- License: mit
- Created: 2019-10-16T18:24:59.000Z (over 6 years ago)
- Default Branch: master
- Last Pushed: 2019-11-25T22:23:14.000Z (over 6 years ago)
- Last Synced: 2025-03-30T07:15:03.664Z (over 1 year ago)
- Topics: pytorch, trainer
- Language: Python
- Homepage:
- Size: 47.9 KB
- Stars: 2
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# PyTorch Trainer
A lightweight wrapper around PyTorch that removes training-loop boilerplate so you can focus on the important parts.
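Most of that boilerplate is the loop over epochs and batches that drives the optimizer. The block below is a simplified sketch of such a loop, written against the hook names used in the example (`configure_optimizers`, `train_dataloader`, `training_step`, `val_dataloader`, `validation_step`, `validation_end`). It is an illustration under assumptions, not `pytorch_trainer`'s actual `Trainer.fit`. The function `fit_sketch`, its `max_epochs` parameter and the stub classes are hypothetical.

```python
def fit_sketch(model, max_epochs=1):
    """Hypothetical sketch: call a model's hooks in a typical train/validate order.

    Works with real PyTorch objects (a loss tensor with .backward(), an optimizer
    with .zero_grad()/.step()) or with duck-typed stand-ins. A real trainer would
    also call model.train()/model.eval() and disable gradients during validation.
    Returns the dict produced by validation_end for each epoch.
    """
    optimizer = model.configure_optimizers()
    history = []
    for _epoch in range(max_epochs):
        # Training phase: one optimizer update per batch.
        for batch_num, batch in enumerate(model.train_dataloader()):
            loss = model.training_step(batch, batch_num)["loss"]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Validation phase: collect per-batch outputs, then aggregate them once.
        outputs = [
            model.validation_step(batch, batch_num)
            for batch_num, batch in enumerate(model.val_dataloader())
        ]
        history.append(model.validation_end(outputs))
    return history


# Hypothetical stand-ins that record the order in which the hooks are called.
class _StubLoss:
    def __init__(self, log):
        self.log = log

    def backward(self):
        self.log.append("backward")


class _StubOptimizer:
    def __init__(self, log):
        self.log = log

    def zero_grad(self):
        self.log.append("zero_grad")

    def step(self):
        self.log.append("step")


class StubModel:
    def __init__(self):
        self.log = []

    def configure_optimizers(self):
        return _StubOptimizer(self.log)

    def train_dataloader(self):
        return [1, 2]  # two "batches"

    def val_dataloader(self):
        return [3, 5]

    def training_step(self, batch, batch_num):
        self.log.append(f"train_step {batch_num}")
        return {"loss": _StubLoss(self.log)}

    def validation_step(self, batch, batch_num):
        return {"val_loss": float(batch)}

    def validation_end(self, outputs):
        losses = [o["val_loss"] for o in outputs]
        return {"val_loss": sum(losses) / len(losses)}


stub = StubModel()
history = fit_sketch(stub, max_epochs=2)
print(history)  # [{'val_loss': 4.0}, {'val_loss': 4.0}]
```

With a model like the one in the example, you write only the hooks and a trainer runs this kind of loop for you.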
## Example
```python
import os

import torch
import torchvision.transforms as transforms
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pytorch_trainer import EarlyStopping, ModelCheckpoint, Module, Trainer


class MNISTModel(Module):
    def __init__(self):
        super().__init__()
        # Single linear layer: flattened 28x28 image -> 10 class scores
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_num):
        # Called for every training batch; the loss to optimize goes under 'loss'
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        return {'loss': loss}

    def validation_step(self, batch, batch_num):
        # Called for every validation batch
        x, y = batch
        output = self.forward(x)
        return {'val_loss': F.cross_entropy(output, y)}

    def validation_end(self, outputs):
        # Average the per-batch validation losses; 'val_loss' is the metric
        # monitored by the callbacks below
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)


# Save a checkpoint only when 'val_loss' reaches a new minimum
checkpoint_callback = ModelCheckpoint(
    directory='./checkpoints',
    monitor='val_loss',
    save_best_only=True,
    mode='min'
)

# Stop training when 'val_loss' stops decreasing (patience: 5 epochs)
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=5,
    mode='min'
)

model = MNISTModel()
trainer = Trainer(
    checkpoint_callback=checkpoint_callback,
    early_stop_callback=early_stop_callback,
)
trainer.fit(model)
```
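Both callbacks in the example share the `monitor`/`mode` convention: with `mode='min'`, a lower `val_loss` counts as better. The sketch below shows one common way such callbacks decide what counts as an improvement, modeled on the Keras-style semantics that these parameter names (`monitor`, `min_delta`, `patience`, `mode`, `save_best_only`) suggest. It is an assumption, not `pytorch_trainer`'s code. `is_improvement`, `EarlyStoppingSketch` and `BestOnlyCheckpointSketch` are hypothetical names, and the library's exact patience counting may differ.

```python
import math


def is_improvement(value, best, mode="min", min_delta=0.0):
    """Return True if `value` beats `best` by more than `min_delta`."""
    if mode == "min":
        return value < best - min_delta
    if mode == "max":
        return value > best + min_delta
    raise ValueError("mode must be 'min' or 'max'")


def initial_best(mode):
    # Start from the worst possible value so the first metric always improves.
    return math.inf if mode == "min" else -math.inf


class EarlyStoppingSketch:
    """Hypothetical: stop once `monitor` has not improved for `patience` epochs."""

    def __init__(self, monitor, min_delta=0.0, patience=5, mode="min"):
        self.monitor = monitor
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.mode = mode
        self.best = initial_best(mode)
        self.wait = 0  # epochs since the last improvement

    def on_validation_end(self, metrics):
        """Update state from the validation metrics; return True to stop."""
        value = metrics[self.monitor]
        if is_improvement(value, self.best, self.mode, self.min_delta):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


class BestOnlyCheckpointSketch:
    """Hypothetical: approve a save only when `monitor` reaches a new best."""

    def __init__(self, monitor, mode="min"):
        self.monitor = monitor
        self.mode = mode
        self.best = initial_best(mode)

    def on_validation_end(self, metrics):
        """Return True if a checkpoint would be written for these metrics."""
        value = metrics[self.monitor]
        if is_improvement(value, self.best, self.mode):
            self.best = value
            return True
        return False


# Usage with made-up validation losses (patience shortened to 3 for the demo).
val_losses = [0.90, 0.70, 0.72, 0.71, 0.69, 0.70, 0.75, 0.80, 0.81, 0.82]
stopper = EarlyStoppingSketch("val_loss", min_delta=0.0, patience=3, mode="min")
checkpointer = BestOnlyCheckpointSketch("val_loss", mode="min")

saved_epochs, stopped_at = [], None
for epoch, loss in enumerate(val_losses):
    metrics = {"val_loss": loss}
    if checkpointer.on_validation_end(metrics):
        saved_epochs.append(epoch)  # a real callback would write the weights here
    if stopper.on_validation_end(metrics):
        stopped_at = epoch
        break

print(saved_epochs, stopped_at)  # [0, 1, 4] 7
```

With `min_delta=0.00`, as in the example, any strict decrease resets the patience counter; a positive `min_delta` would ignore decreases smaller than that threshold.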
Inspired by [PyTorch Lightning](https://github.com/williamFalcon/pytorch-lightning)