{"id":27373387,"url":"https://github.com/eguidotti/torchabc","last_synced_at":"2026-03-08T23:31:12.253Z","repository":{"id":287156872,"uuid":"962632325","full_name":"eguidotti/torchabc","owner":"eguidotti","description":"A simple abstract class for training and inference in PyTorch","archived":false,"fork":false,"pushed_at":"2025-09-28T17:56:05.000Z","size":103,"stargazers_count":4,"open_issues_count":0,"forks_count":0,"subscribers_count":1,"default_branch":"main","last_synced_at":"2026-03-05T05:29:21.830Z","etag":null,"topics":["pytorch","torch"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/eguidotti.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null,"notice":null,"maintainers":null,"copyright":null,"agents":null,"dco":null,"cla":null}},"created_at":"2025-04-08T12:52:03.000Z","updated_at":"2025-09-28T17:56:08.000Z","dependencies_parsed_at":"2025-09-14T19:22:47.235Z","dependency_job_id":"a10a4611-a5bd-489b-8eae-067ae8ff7534","html_url":"https://github.com/eguidotti/torchabc","commit_stats":null,"previous_names":["eguidotti/torchabc"],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/eguidotti/torchabc","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/eguidotti%2Ftorchabc","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/eguidotti%2Ftorchabc/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/eguidotti%2Ftorchabc/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/
hosts/GitHub/repositories/eguidotti%2Ftorchabc/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/eguidotti","download_url":"https://codeload.github.com/eguidotti/torchabc/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/eguidotti%2Ftorchabc/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":286080680,"owners_count":30276904,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2026-03-08T20:45:49.896Z","status":"ssl_error","status_checked_at":"2026-03-08T20:45:49.525Z","response_time":56,"last_error":"SSL_connect returned=1 errno=0 peeraddr=140.82.121.6:443 state=error: unexpected eof while reading","robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":false,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["pytorch","torch"],"created_at":"2025-04-13T11:14:34.158Z","updated_at":"2026-03-08T23:31:12.244Z","avatar_url":"https://github.com/eguidotti.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# TorchABC\n\n`torchabc` is a lightweight package that provides an Abstract Base Class (ABC) to structure PyTorch projects and keep code well organized. \n\nThe core of the package is the `TorchABC` class. 
This class defines the abstract training and inference workflows and must be subclassed to implement concrete logic.\n\nThe package has no dependencies beyond PyTorch and consists of a single self-contained [file](https://github.com/eguidotti/torchabc/blob/main/torchabc/__init__.py). It is well suited to research, prototyping, and teaching.\n\n## Structure\n\nThe `TorchABC` class structures a project into the following main steps:\n\n![diagram](https://github.com/user-attachments/assets/dd5abbb4-c28b-4477-a196-6eef5ad2ec2e)\n\n1. **Dataloaders** - load raw data samples.\n2. **Preprocess** - transform raw samples.\n3. **Collate** - batch preprocessed samples.\n4. **Network** - compute model outputs.\n5. **Loss** - compute the error against targets.\n6. **Optimizer** - update model parameters.\n7. **Postprocess** - transform outputs into predictions.\n\nEach step corresponds to an abstract method in `TorchABC`. To use `TorchABC`, create a concrete subclass and implement these methods.\n\n## Quick start\n\nInstall the package.\n\n```bash\npip install torchabc\n```\n\nGenerate a template using the command-line interface.\n\n```bash\ntorchabc --create template.py --min\n```\n\nFill out the template by implementing the methods below. 
Each method is documented in the [source file](https://github.com/eguidotti/torchabc/blob/main/torchabc/__init__.py).\n\n```python\nimport torch\nfrom torchabc import TorchABC\nfrom functools import cached_property\n\n\nclass MyModel(TorchABC):\n\n    @cached_property\n    def dataloaders(self):\n        # Return a dict of DataLoaders, e.g. with keys 'train', 'val', 'test'.\n        raise NotImplementedError\n\n    @staticmethod\n    def preprocess(sample, hparams, flag=''):\n        # Transform a raw sample into a tensor or an iterable of tensors.\n        return sample\n\n    @staticmethod\n    def collate(samples):\n        # Batch preprocessed samples.\n        return torch.utils.data.default_collate(samples)\n\n    @cached_property\n    def network(self):\n        # Return a torch.nn.Module whose inputs and outputs have shape (batch_size, ...).\n        raise NotImplementedError\n\n    @staticmethod\n    def loss(outputs, targets, hparams):\n        # Return a dict that contains at least the key 'loss'.\n        raise NotImplementedError\n\n    @cached_property\n    def optimizer(self):\n        # Return a torch.optim.Optimizer over self.network.parameters().\n        raise NotImplementedError\n\n    @staticmethod\n    def postprocess(outputs, hparams):\n        # Convert network outputs into predictions.\n        return outputs\n```\n\n## Usage\n\nOnce a subclass of `TorchABC` is implemented, it can be used for training, evaluation, checkpointing, and inference.\n\n### Initialization\n\n```python\nmodel = MyModel()\n```\n\nInstantiate the subclass. The constructor also accepts optional `device`, `logger`, and `hparams` arguments.\n\n### Training\n\n```python\nmodel.train(epochs=5, on=\"train\", val=\"val\")\n```\n\nTrain the model for 5 epochs on the `train` dataloader, validating on the `val` dataloader.\n\n### Evaluation\n\n```python\nmetrics = model.eval(on=\"test\")\n```\n\nEvaluate the model on the `test` dataloader and return a dictionary of metrics.\n\n### Checkpoints\n\n```python\nmodel.save(\"checkpoint.pth\")\nmodel.load(\"checkpoint.pth\")\n```\n\nSave the model state to a checkpoint file and restore it.\n\n### Inference\n\n```python\npreds = model(samples)\n```\n\nRun inference on an iterable of raw input samples and return postprocessed predictions.\n\n# API Reference\n\nThe `TorchABC` class defines a standard workflow for PyTorch projects. 
Some methods are [abstract](https://github.com/eguidotti/torchabc/tree/main?tab=readme-ov-file#abstract-methods) (they must be implemented in subclasses), others have [default](https://github.com/eguidotti/torchabc/tree/main?tab=readme-ov-file#default-methods) implementations (they can be overridden), and a few are [concrete](https://github.com/eguidotti/torchabc/tree/main?tab=readme-ov-file#concrete-methods) (they should not be overridden).\n\n---\n\n## Abstract Methods\n\n| Method | Description |\n| --- | --- |\n| `dataloaders` | Must return a `dict[str, torch.utils.data.DataLoader]`. Example keys: `\"train\"`, `\"val\"`, `\"test\"`. |\n| `preprocess(sample, hparams, flag='')` | Transform a raw dataset sample.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `sample` (`Any`): raw sample.\u003cbr\u003e - `hparams` (`dict`): hyperparameters.\u003cbr\u003e - `flag` (`str`, optional): mode flag.\u003cbr\u003e **Returns:** `Tensor` or iterable of tensors. |\n| `collate(samples)` | Collate a batch of preprocessed samples.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `samples` (`Iterable[Tensor]`): preprocessed samples.\u003cbr\u003e **Returns:** `Tensor` or iterable of tensors. |\n| `network` | Must return a `torch.nn.Module`. 
Inputs and outputs must have the batch dimension first, i.e. shape `(batch_size, ...)`. |\n| `optimizer` | Must return a `torch.optim.Optimizer` over `self.network.parameters()`. |\n| `loss(outputs, targets, hparams)` | Compute the loss for a batch.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `outputs` (`Tensor` or iterable): network outputs.\u003cbr\u003e - `targets` (`Tensor` or iterable): batch targets.\u003cbr\u003e - `hparams` (`dict`): hyperparameters.\u003cbr\u003e **Returns:** `dict[str, Any]` containing the key `\"loss\"`. |\n| `postprocess(outputs, hparams)` | Convert network outputs into predictions.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `outputs` (`Tensor` or iterable): network outputs.\u003cbr\u003e - `hparams` (`dict`): hyperparameters.\u003cbr\u003e **Returns:** predictions (`Any`). |\n\n---\n\n## Default Methods\n\n| Method | Description |\n| --- | --- |\n| `scheduler` | Learning rate scheduler. May return `None`, a `torch.optim.lr_scheduler.LRScheduler`, or a `torch.optim.lr_scheduler.ReduceLROnPlateau`. Default is `None`. 
|\n| `backward(batch, gas)` | Backpropagation step.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `batch` (`dict[str, Any]`): must contain the key `\"loss\"`.\u003cbr\u003e - `gas` (`int`): gradient accumulation steps. |\n| `metrics(batches, hparams)` | Compute evaluation metrics.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `batches` (`deque[dict[str, Any]]`): batch results.\u003cbr\u003e - `hparams` (`dict`): hyperparameters.\u003cbr\u003e **Returns:** `dict[str, Any]`. The default implementation computes the average loss. |\n| `checkpoint(epoch, metrics, out)` | Checkpoint step. Saves the model if the loss improves.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `epoch` (`int`): epoch number.\u003cbr\u003e - `metrics` (`dict[str, float]`): validation metrics.\u003cbr\u003e - `out` (`str` or `None`): output path to save checkpoints.\u003cbr\u003e **Returns:** `bool` indicating whether to stop training early. |\n| `move(data)` | Move data to the current device. Supports `Tensor`, list, tuple, and dict. |\n| `detach(data)` | Detach data from the computation graph. Supports `Tensor`, list, tuple, and dict. 
|\n\n---\n\n## Concrete Methods\n\n| Method | Description |\n| --- | --- |\n| `TorchABC(device=None, logger=print, hparams=None, **kwargs)` | Initialize the model.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `device` (`str` or `torch.device`, optional): computation device. Defaults to CUDA if available, otherwise MPS if available, otherwise CPU.\u003cbr\u003e - `logger` (`Callable[[dict], None]`, optional): logging function. Defaults to `print`.\u003cbr\u003e - `hparams` (`dict`, optional): dictionary of hyperparameters.\u003cbr\u003e - `kwargs`: additional attributes stored in the instance. |\n| `train(epochs, gas=1, mas=None, on='train', val='val', out=None)` | Train the model.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `epochs` (`int`): number of training epochs.\u003cbr\u003e - `gas` (`int`, optional): gradient accumulation steps. 
Defaults to 1.\u003cbr\u003e - `mas` (`int`, optional): metrics accumulation steps. Defaults to `gas`.\u003cbr\u003e - `on` (`str`, optional): training dataloader name. Defaults to `\"train\"`.\u003cbr\u003e - `val` (`str`, optional): validation dataloader name. Defaults to `\"val\"`. If `None`, validation is skipped.\u003cbr\u003e - `out` (`str`, optional): output path to save checkpoints. |\n| `eval(on)` | Evaluate the model.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `on` (`str`): dataloader name.\u003cbr\u003e **Returns:** `dict[str, float]` of evaluation metrics. |\n| `__call__(samples)` | Run inference on raw samples.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `samples` (`Iterable[Any]`): raw samples.\u003cbr\u003e **Returns:** postprocessed predictions. |\n| `save(path)` | Save a checkpoint.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `path` (`str`): file path. 
|\n| `load(path)` | Load a checkpoint.\u003cbr\u003e **Parameters:**\u003cbr\u003e - `path` (`str`): file path. |\n\n---\n\n## Examples\n\nGet started with simple self-contained examples:\n\n- [MNIST classification](https://github.com/eguidotti/torchabc/blob/main/examples/mnist.py)\n\n### Run the examples\n\nInstall the dependencies.\n\n```bash\npoetry install --with examples\n```\n\nRun an example by replacing `\u003cname\u003e` with the name of one of the scripts in the [examples](https://github.com/eguidotti/torchabc/tree/main/examples) folder (without the `.py` extension).\n\n```bash\npoetry run python examples/\u003cname\u003e.py\n```\n\n## Contribute\n\nContributions are welcome! Submit pull requests with new [examples](https://github.com/eguidotti/torchabc/tree/main/examples) or improvements to the core [`TorchABC`](https://github.com/eguidotti/torchabc/blob/main/torchabc/__init__.py) class itself.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Feguidotti%2Ftorchabc","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Feguidotti%2Ftorchabc","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Feguidotti%2Ftorchabc/lists"}