# torchtools: A High-Level training API on top of PyTorch

[![Build Status](https://travis-ci.org/Time1ess/torchtools.svg?branch=master)](https://travis-ci.org/Time1ess/torchtools)
[![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://github.com/Time1ess/torchtools/blob/master/LICENSE)
[![Docs](https://img.shields.io/badge/docs-link-green.svg)](https://Time1ess.github.io/torchtools)

---

torchtools is a high-level training API on top of [PyTorch](http://pytorch.org) with many useful features that simplify the training process.

It was developed based on ideas from [tnt](https://github.com/pytorch/tnt) and [Keras](https://github.com/fchollet/keras). I wrote this tool to free myself from repetitive work, since many different training tasks share the same training routine (define a dataset, retrieve a batch of samples, forward propagation, backward propagation, ...).

This API provides the following:

* A high-level training class named `ModelTrainer`. No need to repeat yourself.
* A bunch of useful `callbacks` to inject your code into any stage of training.
* A set of `meters` to measure the performance of your model.
* Visualization support in TensorBoard (TensorBoard required).

## Requirements

* tqdm
* NumPy
* [PyTorch v0.4.0+](http://pytorch.org)
* [tensorboardX](https://github.com/lanpa/tensorboard-pytorch)
* [Standalone TensorBoard](https://github.com/dmlc/tensorboard) (optional)

## Installation

torchtools has been tested on **Python 2.7+** and **Python 3.5+**.

`pip install torchtools`

## Screenshots

Training process:

![](training_process.gif)

Visualization in TensorBoard:

![](visualization_in_tensorboard.png)


## 1 Minute torchtools MNIST example

```Python
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

from torch.utils.data import DataLoader
from torch.nn.init import xavier_uniform_ as xavier
from torchvision.datasets import MNIST

from torchtools.trainer import Trainer
from torchtools.meters import LossMeter, AccuracyMeter
from torchtools.callbacks import (
    StepLR, ReduceLROnPlateau, TensorBoardLogger, CSVLogger)


EPOCHS = 10
BATCH_SIZE = 32
DATASET_DIRECTORY = 'dataset'

# download=True fetches MNIST into DATASET_DIRECTORY if it is not there yet
trainset = MNIST(root=DATASET_DIRECTORY, download=True, transform=T.ToTensor())
testset = MNIST(root=DATASET_DIRECTORY, train=False, download=True,
                transform=T.ToTensor())

train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE)


class Net(nn.Module):
    def __init__(self):
        # Explicit super() call keeps the example working on Python 2.7 too
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 100)
        self.fc2 = nn.Linear(100, 10)

        # Xavier-initialize the weights of every linear layer
        for m in self.modules():
            if isinstance(m, nn.Linear):
                xavier(m.weight.data)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # flatten 28x28 images
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = Net()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

trainer = Trainer(model, train_loader, criterion, optimizer, test_loader)

# Callbacks
loss = LossMeter('loss')
val_loss = LossMeter('val_loss')
acc = AccuracyMeter('acc')
val_acc = AccuracyMeter('val_acc')
scheduler = StepLR(optimizer, 1, gamma=0.95)
reduce_lr = ReduceLROnPlateau(optimizer, 'val_loss', factor=0.3, patience=3)
logger = TensorBoardLogger()
csv_logger = CSVLogger(keys=['epochs', 'loss', 'acc', 'val_loss', 'val_acc'])

trainer.register_hooks([
    loss, val_loss, acc, val_acc, scheduler, reduce_lr, logger, csv_logger])

trainer.train(EPOCHS)
```

### Callbacks

`callbacks` provides an API similar to that of [Keras](https://github.com/fchollet/keras), giving us more control over the training process.

```Python
from torchtools.callbacks import StepLR, ReduceLROnPlateau, TensorBoardLogger

scheduler = StepLR(optimizer, 1, gamma=0.95)
reduce_lr = ReduceLROnPlateau(optimizer, 'val_loss', factor=0.3, patience=3)
logger = TensorBoardLogger(comment=name)

...

trainer.register_hooks([scheduler, reduce_lr, logger])
```

### Meters

`meters` are provided to measure `loss`, `accuracy`, and `time` in different ways.

```Python
from torchtools.meters import LossMeter, AccuracyMeter

loss_meter = LossMeter('loss')
val_loss_meter = LossMeter('val_loss')
acc_meter = AccuracyMeter('acc')
```

### Put together

Now we can put it all together.

1. Instantiate a `Trainer` object with the model, a `DataLoader` for the training set, a criterion, an optimizer, and other optional arguments.
2. All `callbacks` and `meters` are actually `Hook` objects, so we can use `register_hooks` to register them with the trainer.
3. Call `.train(epochs)` on the `Trainer` with the number of training epochs.
4. Done!

## Contributing

Please feel free to add more features!

If there are any bugs or feature requests, please [submit an issue](https://github.com/Time1ess/torchtools/issues/new) and I'll see what I can do.

For new features or bug fixes, please submit a PR in [Pull requests](https://github.com/Time1ess/torchtools/pulls).

If there are any other problems, please email: <a href="mailto:youche.du@gmail.com">youchen.du@gmail.com</a>

## Acknowledgement

Thanks to these people and groups:

* All PyTorch developers
* All PyTorchNet developers
* All Keras developers