https://github.com/paraglondhe098/torchtrainer
PyTorch Training Framework: A customizable PyTorch training loop class with support for metrics tracking, early stopping, and callbacks. Includes methods for multi-class and binary accuracy, precision, recall, and R² score calculations.
- Host: GitHub
- URL: https://github.com/paraglondhe098/torchtrainer
- Owner: paraglondhe098
- Created: 2024-08-12T18:16:18.000Z (almost 2 years ago)
- Default Branch: master
- Last Pushed: 2024-12-10T19:10:24.000Z (over 1 year ago)
- Last Synced: 2024-12-29T10:19:34.937Z (over 1 year ago)
- Language: Python
- Size: 63.5 KB
- Stars: 1
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: readme.md
README
# PyTorch Training Framework
This repository provides a `Trainer` class that wraps the PyTorch training loop, with built-in metrics tracking, early stopping, and customizable callbacks.
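The callback mechanism can be pictured as a loop that calls named hooks on every registered callback object. The framework-free sketch below is hypothetical: `Callback`, its hook names (`on_epoch_start`, `on_batch_end`, `on_epoch_end`), `LossHistory`, `BatchCounter`, `fire`, and `run_loop` are illustrative assumptions, not torchtrainer's actual API, and the loop replays precomputed losses instead of training a model.

```python
class Callback:
    """Hypothetical base class: a subclass overrides only the hooks it needs."""

    def on_epoch_start(self, epoch):
        pass

    def on_batch_end(self, epoch, batch, loss):
        pass

    def on_epoch_end(self, epoch, logs):
        pass


class LossHistory(Callback):
    """Example callback that records the mean training loss of every epoch."""

    def __init__(self):
        self.epoch_losses = []

    def on_epoch_end(self, epoch, logs):
        self.epoch_losses.append(logs["loss"])


class BatchCounter(Callback):
    """Example callback that counts how many batches have been processed."""

    def __init__(self):
        self.count = 0

    def on_batch_end(self, epoch, batch, loss):
        self.count += 1


def fire(callbacks, hook, *args):
    """Call the named hook on every registered callback, in registration order."""
    for cb in callbacks:
        getattr(cb, hook)(*args)


def run_loop(batch_losses_per_epoch, callbacks):
    """Toy loop (hypothetical): replays precomputed batch losses, no real training."""
    for epoch, batch_losses in enumerate(batch_losses_per_epoch):
        fire(callbacks, "on_epoch_start", epoch)
        for batch, loss in enumerate(batch_losses):
            fire(callbacks, "on_batch_end", epoch, batch, loss)
        logs = {"loss": sum(batch_losses) / len(batch_losses)}
        fire(callbacks, "on_epoch_end", epoch, logs)


history, counter = LossHistory(), BatchCounter()
run_loop([[1.0, 0.5], [0.5, 0.25]], [history, counter])
print(history.epoch_losses)  # [0.75, 0.375]
print(counter.count)  # 4
```

Keeping the hooks as no-op methods on the base class means each callback implements only the events it cares about, and the loop never needs to know which callbacks are registered.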
## Features
- **Metrics Tracking:** Calculate and monitor multi-class and binary accuracy, precision, recall, and R² score.
- **Custom Callbacks:** Implement and use custom callbacks for various training events.
- **Early Stopping:** Automatically halt training when the validation loss stops improving, to avoid overfitting.
- **Mixed Precision Training:** Use mixed precision on CUDA-enabled GPUs to speed up training and reduce memory use.
- **Detailed Reporting:** Report training and validation metrics during training, including progress reports within each epoch.
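The metrics in the feature list follow standard definitions. The plain-Python sketch below shows one way to compute them, assuming lists of labels and predictions rather than tensors; the function names and the 0.5 threshold are hypothetical choices, not the functions torchtrainer exposes.

```python
def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of samples whose thresholded probability matches the 0/1 label."""
    preds = [1 if p >= threshold else 0 for p in y_prob]
    return sum(p == t for p, t in zip(preds, y_true)) / len(y_true)


def multiclass_accuracy(y_true, scores):
    """Fraction of samples whose highest-scoring class (first on ties) is the true class."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in scores]
    return sum(p == t for p, t in zip(preds, y_true)) / len(y_true)


def precision_recall(y_true, y_pred, positive=1):
    """Precision = TP / (TP + FP), recall = TP / (TP + FN); 0.0 if a denominator is 0."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if p == positive and t == positive)
    fp = sum(1 for t, p in pairs if p == positive and t != positive)
    fn = sum(1 for t, p in pairs if p != positive and t == positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


def r2_score(y_true, y_pred):
    """Coefficient of determination 1 - SS_res / SS_tot (undefined for constant targets)."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


print(binary_accuracy([1, 0, 1, 0], [0.9, 0.2, 0.4, 0.6]))  # 0.5
print(precision_recall([1, 0, 1, 1, 0], [1, 1, 0, 1, 0]))  # (0.6666666666666666, 0.6666666666666666)
print(r2_score([3.0, -0.5, 2.0, 7.0], [2.5, 0.0, 2.0, 8.0]))  # about 0.9486
```

Accuracy, precision, and recall apply to classification outputs, while R² measures how much of the target variance a regression model explains.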
## Installation
1. Clone the repository:
```bash
git clone https://github.com/paraglondhe098/torchtrainer.git
cd torchtrainer
```
2. Install the required dependencies:
```bash
pip install -r requirements.txt
```
## Usage
Here is a basic example of how to use the `Trainer` class:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchtrainer import Trainer, IntraEpochReport, EarlyStopping
# Define model, criterion, and optimizer
model = nn.Sequential(nn.Linear(10, 1))
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Create the Trainer instance
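trainer = Trainer(
    model=model,
    epochs=10,
    criterion=criterion,
    optimizer=optimizer,
    metrics=['accuracy'],  # metric names to track; accuracy assumes classification targets
    mixed_precision_training=True,  # mixed precision applies on CUDA-enabled GPUs
)
# Add callbacks
trainer.add_callback(IntraEpochReport(reports_per_epoch=10))  # report progress 10 times per epoch
trainer.add_callback(EarlyStopping(basis='vloss', patience=3))  # stop when validation loss stalls for 3 epochs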
# Prepare your data loaders
train_loader = ...
val_loader = ...
# Train the model
trainer.fit(train_loader, val_loader)
```
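The `EarlyStopping(basis='vloss', patience=3)` callback in the example watches the validation loss. A common patience rule is sketched below under stated assumptions: `PatienceTracker` and `first_stop_epoch` are hypothetical names, and counting an epoch as "bad" when the loss fails to beat the best value by `min_delta` is a typical convention that may differ from torchtrainer's exact behavior.

```python
class PatienceTracker:
    """Hypothetical sketch: signal a stop once the monitored loss has not
    improved on its best value for `patience` consecutive epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # new best: reset the patience counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1  # no improvement this epoch
        return self.bad_epochs >= self.patience


def first_stop_epoch(val_losses, patience=3):
    """Return the epoch index at which training would stop, or None if it never stops."""
    tracker = PatienceTracker(patience=patience)
    for epoch, vloss in enumerate(val_losses):
        if tracker.step(vloss):
            return epoch
    return None


print(first_stop_epoch([0.9, 0.7, 0.72, 0.71, 0.75, 0.6], patience=3))  # 4
```

With `patience=3`, this run stops at epoch 4, three epochs after the best loss of 0.7, so the later improvement to 0.6 is never reached; a larger patience tolerates longer plateaus at the cost of extra epochs.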