An open API service indexing awesome lists of open source software.

https://github.com/eguidotti/torchabc

A simple abstract class for training and inference in PyTorch
https://github.com/eguidotti/torchabc

pytorch torch

Last synced: 3 months ago
JSON representation

A simple abstract class for training and inference in PyTorch

Awesome Lists containing this project

README

          

# TorchABC

`torchabc` is a lightweight package that provides an Abstract Base Class (ABC) to structure PyTorch projects and keep code well organized.

The core of the package is the `TorchABC` class. This class defines the abstract training and inference workflows and must be subclassed to implement a concrete logic.

This package has no extra dependencies beyond PyTorch and it consists of a simple self-contained [file](https://github.com/eguidotti/torchabc/blob/main/torchabc/__init__.py). It is ideal for research, prototyping, and teaching.

## Structure

The `TorchABC` class structures a project into the following main steps:

![diagram](https://github.com/user-attachments/assets/dd5abbb4-c28b-4477-a196-6eef5ad2ec2e)

1. **Dataloaders** - load raw data samples.
2. **Preprocess** – transform raw samples.
3. **Collate** - batch preprocessed samples.
4. **Network** - compute model outputs.
5. **Loss** - compute error against targets.
6. **Optimizer** - update model parameters.
7. **Postprocess** - transform outputs into predictions.

Each step corresponds to an abstract method in `TorchABC`. To use `TorchABC`, create a concrete subclass and implement these methods.

## Quick start

Install the package.

```bash
pip install torchabc
```

Generate a template using the command line interface.

```bash
torchabc --create template.py --min
```

Fill out the template by implementing the methods below. The documentation of each method is available [here](https://github.com/eguidotti/torchabc/blob/main/torchabc/__init__.py).

```py
import torch
from torchabc import TorchABC
from functools import cached_property

class MyModel(TorchABC):

@cached_property
def dataloaders(self):
raise NotImplementedError

@staticmethod
def preprocess(sample, hparams, flag=''):
return sample

@staticmethod
def collate(samples):
return torch.utils.data.default_collate(samples)

@cached_property
def network(self):
raise NotImplementedError

@staticmethod
def loss(outputs, targets, hparams):
raise NotImplementedError

@cached_property
def optimizer(self):
raise NotImplementedError

@staticmethod
def postprocess(outputs, hparams):
return outputs

```

## Usage

Once a subclass of `TorchABC` is implemented, it can be used for training, evaluation, checkpointing, and inference.

### Initialization

```python
model = MyModel()
```

Initialize the model.

### Training

```python
model.train(epochs=5, on="train", val="val")
```

Train the model for 5 epochs using the `train` and `val` dataloaders.

### Evaluation

```python
metrics = model.eval(on="test")
```

Evaluate on the `test` dataloader and return metrics.

### Checkpoints

```python
model.save("checkpoint.pth")
model.load("checkpoint.pth")
```

Save and restore the model state.

### Inference

```python
preds = model(samples)
```

Run predictions on raw input samples.

# API Reference

The `TorchABC` class defines a standard workflow for PyTorch projects. Some methods are [abstract](https://github.com/eguidotti/torchabc/tree/main?tab=readme-ov-file#abstract-methods) (must be implemented in subclasses), others are [optional](https://github.com/eguidotti/torchabc/tree/main?tab=readme-ov-file#default-methods) (can be overridden but have defaults), and a few are [concrete](https://github.com/eguidotti/torchabc/tree/main?tab=readme-ov-file#concrete-methods) (should not be overridden).

---

## Abstract Methods

| Method | Description |
| -------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `dataloaders` | Must return `dict[str, torch.utils.data.DataLoader]`. Example keys: `"train"`, `"val"`, `"test"`. |
| `preprocess(sample, hparams, flag='')` | Transform a raw dataset sample.
**Parameters:**
- `sample` (`Any`): raw sample.
- `hparams` (`dict`): hyperparameters.
- `flag` (`str`, optional): mode flag.
**Returns:** `Tensor` or iterable of tensors. |
| `collate(samples)` | Collate a batch of preprocessed samples.
**Parameters:**
- `samples` (`Iterable[Tensor]`)
**Returns:** `Tensor` or iterable of tensors. |
| `network` | Must return a `torch.nn.Module`. Inputs and outputs must use `(batch_size, ...)` format. |
| `optimizer` | Must return a `torch.optim.Optimizer` for `self.network.parameters()`. |
| `loss(outputs, targets, hparams)` | Compute loss for a batch.
**Parameters:**
- `outputs` (`Tensor` or iterable)
- `targets` (`Tensor` or iterable)
- `hparams` (`dict`)
**Returns:** `dict[str, Any]` containing key `"loss"`. |
| `postprocess(outputs, hparams)` | Convert network outputs into predictions.
**Parameters:**
- `outputs` (`Tensor` or iterable)
- `hparams` (`dict`)
**Returns:** predictions (`Any`). |

---

## Default Methods

| Method | Description |
| ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `scheduler` | Learning rate scheduler. May return `None`, `torch.optim.lr_scheduler.LRScheduler`, or `ReduceLROnPlateau`. Default is `None`. |
| `backward(batch, gas)` | Backpropagation step.
**Parameters:**
- `batch` (`dict[str, Any]`): must contain key `"loss"`.
- `gas` (`int`): gradient accumulation steps. |
| `metrics(batches, hparams)` | Compute evaluation metrics.
**Parameters:**
- `batches` (`deque[dict[str, Any]]`): batch results.
- `hparams` (`dict`)
**Returns:** `dict[str, Any]`. Default computes average loss. |
| `checkpoint(epoch, metrics, out)` | Checkpoint step. Saves model if loss improves.
**Parameters:**
- `epoch` (`int`): epoch number.
- `metrics` (`dict[str, float]`): validation metrics.
- `out` (`str` or `None`): output path to save checkpoints.
**Returns:** `bool` indicating early stopping.|
| `move(data)` | Move data to current device. Supports `Tensor`, list, tuple, dict. |
| `detach(data)` | Detach data from computation graph. Supports `Tensor`, list, tuple, dict. |

---

## Concrete Methods

| Method | Description |
| ------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `TorchABC(device=None, logger=print, hparams=None, **kwargs)` | Initialize the model.
**Parameters:**
- `device` (`str` or `torch.device`, optional): computation device. Defaults to CUDA if available, otherwise MPS or CPU.
- `logger` (`Callable[[dict], None]`, optional): logging function. Defaults to `print`.
- `hparams` (`dict`, optional): dictionary of hyperparameters.
- `kwargs`: additional attributes stored in the instance. |
| `train(epochs, gas=1, mas=None, on='train', val='val', out=None)` | Train the model.
**Parameters:**
- `epochs` (`int`): number of training epochs.
- `gas` (`int`, optional): gradient accumulation steps. Defaults to 1.
- `mas` (`int`, optional): metrics accumulation steps. Defaults to `gas`.
- `on` (`str`, optional): training dataloader name. Default `"train"`.
- `val` (`str`, optional): validation dataloader name. Default `"val"`. If `None`, validation is skipped.
- `out` (`str`, optional): output path to save checkpoints. |
| `eval(on)` | Evaluate the model.
**Parameters:**
- `on` (`str`): dataloader name.
**Returns:** `dict[str, float]` of evaluation metrics. |
| `__call__(samples)` | Run inference on raw samples.
**Parameters:**
- `samples` (`Iterable[Any]`): raw samples.
**Returns:** postprocessed predictions. |
| `save(path)` | Save a checkpoint.
**Parameters:**
- `path` (`str`): file path. |
| `load(path)` | Load a checkpoint.
**Parameters:**
- `path` (`str`): file path. |

---

## Examples

Get started with simple self-contained examples:

- [MNIST classification](https://github.com/eguidotti/torchabc/blob/main/examples/mnist.py)

### Run the examples

Install the dependencies

```
poetry install --with examples
```

Run the examples by replacing `` with one of the filenames in the [examples](https://github.com/eguidotti/torchabc/tree/main/examples) folder

```
poetry run python examples/.py
```

## Contribute

Contributions are welcome! Submit pull requests with new [examples](https://github.com/eguidotti/torchabc/tree/main/examples) or improvements to the core [`TorchABC`](https://github.com/eguidotti/torchabc/blob/main/torchabc/__init__.py) class itself.