{"id":17842874,"url":"https://github.com/michalwols/yann","last_synced_at":"2025-03-20T04:31:46.816Z","repository":{"id":41465460,"uuid":"121509738","full_name":"michalwols/yann","owner":"michalwols","description":"Yet Another Neural Network Library 🤔","archived":false,"fork":false,"pushed_at":"2024-03-23T02:47:47.000Z","size":1358,"stargazers_count":24,"open_issues_count":14,"forks_count":5,"subscribers_count":3,"default_branch":"master","last_synced_at":"2024-04-23T11:53:55.302Z","etag":null,"topics":["deep-learning","neural-network","nn","python","pytorch","torch"],"latest_commit_sha":null,"homepage":"https://michalwols.github.io/yann/","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/michalwols.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2018-02-14T12:44:25.000Z","updated_at":"2024-04-23T11:53:55.302Z","dependencies_parsed_at":"2024-01-11T23:37:25.484Z","dependency_job_id":"b2df20a9-48ed-4dc0-bfac-c062f310cbb8","html_url":"https://github.com/michalwols/yann","commit_stats":{"total_commits":246,"total_committers":3,"mean_commits":82.0,"dds":0.06504065040650409,"last_synced_commit":"3a3da48e3628c34d50c753350371b05581bdaf1f"},"previous_names":[],"tags_count":1,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/michalwols%2Fyann","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/michalwols%2Fyann/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/michalwols%2Fyann/releases","manifests_url":"
https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/michalwols%2Fyann/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/michalwols","download_url":"https://codeload.github.com/michalwols/yann/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":244058147,"owners_count":20391046,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","neural-network","nn","python","pytorch","torch"],"created_at":"2024-10-27T21:18:39.272Z","updated_at":"2025-03-20T04:31:46.339Z","avatar_url":"https://github.com/michalwols.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"\n# yann (Yet Another Neural Network Library)\n\n Yann is an extended version of torch.nn, adding a ton of sugar to make training models as fast and easy as possible.\n\n## Getting Started\n\n### Install \n\n```shell script\npip install yann\n```\n\n\n### Train LeNet on MNIST\n\n```python\nimport torch\nfrom torch import nn\nfrom torchvision import transforms\n\nimport yann\nfrom yann.train import Trainer\nfrom yann.modules import Stack, Flatten, Infer\nfrom yann.params import HyperParams, Choice, Range\n\n\nclass Params(HyperParams):\n  dataset = 'MNIST'\n  batch_size = 32\n  epochs = 10\n  optimizer: Choice(('SGD', 'Adam')) = 'SGD'\n  learning_rate: Range(.01, .0001) = .01\n  momentum = 0\n\n  seed = 1\n\n# parse command line arguments\nparams = Params.from_command()\n\n# set random, numpy and pytorch seeds in one call\nyann.seed(params.seed)\n\nlenet = Stack(\n  Infer(nn.Conv2d, 10, 
kernel_size=5),\n  nn.MaxPool2d(2),\n  nn.ReLU(inplace=True),\n  Infer(nn.Conv2d, 20, kernel_size=5),\n  nn.MaxPool2d(2),\n  nn.ReLU(inplace=True),\n  Flatten(),\n  Infer(nn.Linear, 50),\n  nn.ReLU(inplace=True),\n  Infer(nn.Linear, 10),\n  activation=nn.LogSoftmax(dim=1)\n)\n\n# run a forward pass to infer input shapes using `Infer` modules\nlenet(torch.rand(1, 1, 28, 28))\n\n# momentum is an SGD-only argument; torch.optim.Adam does not accept it\noptimizer_kwargs = {'lr': params.learning_rate}\nif params.optimizer == 'SGD':\n  optimizer_kwargs['momentum'] = params.momentum\n\n# use the registry to resolve the optimizer name to an optimizer class\noptimizer = yann.resolve.optimizer(\n  params.optimizer,\n  yann.trainable(lenet.parameters()),\n  **optimizer_kwargs\n)\n\ntrain = Trainer(\n  model=lenet,\n  optimizer=optimizer,\n  dataset=params.dataset,\n  batch_size=params.batch_size,\n  transform=transforms.Compose([\n    transforms.ToTensor(),\n    transforms.Normalize((0.1307,), (0.3081,))\n  ]),\n  loss='nll_loss',\n  metrics=('accuracy',)\n)\n\ntrain(params.epochs)\n\n# save checkpoint\ntrain.checkpoint()\n\n# plot the loss curve\ntrain.history.plot()\n```\n\nSave the script above as `train_mnist.py` and view the generated CLI help:\n```bash\npython train_mnist.py -h\n```\n\n```shell script\nusage: train_mnist.py [-h] [-o {SGD,Adam}] [-lr LEARNING_RATE] [-d DATASET]\n                      [-bs BATCH_SIZE] [-e EPOCHS] [-m MOMENTUM] [-s SEED]\n\noptional arguments:\n  -h, --help            show this help message and exit\n  -o {SGD,Adam}, --optimizer {SGD,Adam}\n                        optimizer (default: SGD)\n  -lr LEARNING_RATE, --learning_rate LEARNING_RATE\n                        learning_rate (default: 0.01)\n  -d DATASET, --dataset DATASET\n                        dataset (default: MNIST)\n  -bs BATCH_SIZE, --batch_size BATCH_SIZE\n                        batch_size (default: 32)\n  -e EPOCHS, --epochs EPOCHS\n                        epochs (default: 10)\n  -m MOMENTUM, --momentum MOMENTUM\n                        momentum (default: 0)\n  -s SEED, --seed SEED  seed (default: 1)\n```\n\nThen start a training run\n\n```shell script\npython train_mnist.py -bs=16\n```\n\nwhich should 
print the following to stdout\n\n```less\nParams(\n  optimizer=SGD,\n  learning_rate=0.01,\n  dataset=MNIST,\n  batch_size=16,\n  epochs=10,\n  momentum=0,\n  seed=1\n)\nStarting training\n\nname: MNIST-Stack\nroot: train-runs/MNIST-Stack/19-09-25T18:02:52\nbatch_size: 16\ndevice: cpu\n\nMODEL\n=====\n\nStack(\n  (infer0): Infer(\n    (module): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n  )\n  (max_pool2d0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  (re_lu0): ReLU(inplace=True)\n  (infer1): Infer(\n    (module): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))\n  )\n  (max_pool2d1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n  (re_lu1): ReLU(inplace=True)\n  (flatten0): Flatten()\n  (infer2): Infer(\n    (module): Linear(in_features=320, out_features=50, bias=True)\n  )\n  (re_lu2): ReLU(inplace=True)\n  (infer3): Infer(\n    (module): Linear(in_features=50, out_features=10, bias=True)\n  )\n  (activation): LogSoftmax()\n)\n\n\nDATASET\n=======\n\nTransformDataset(\nDataset: Dataset MNIST\n    Number of datapoints: 60000\n    Root location: /Users/michal/.torch/datasets/MNIST\n    Split: Train\nTransforms: (Compose(\n    ToTensor()\n    Normalize(mean=(0.1307,), std=(0.3081,))\n),)\n)\n\n\nLOADER\n======\n\n\u003ctorch.utils.data.dataloader.DataLoader object at 0x1a45cc8940\u003e\n\nLOSS\n====\n\n\u003cfunction nll_loss at 0x120b700d0\u003e\n\n\nOPTIMIZER\n=========\n\nSGD (\nParameter Group 0\n    dampening: 0\n    lr: 0.01\n    momentum: 0\n    nesterov: False\n    weight_decay: 0.0001\n)\n\nSCHEDULER\n=========\n\nNone\n\n\nPROGRESS\n========\nepochs: 0\nsteps: 0\nsamples: 0\n\n\nStarting epoch 0\n\nOPTIMIZER\n=========\n\nSGD (\nParameter Group 0\n    dampening: 0\n    lr: 0.01\n    momentum: 0\n    nesterov: False\n    weight_decay: 0.0001\n)\n\n\nPROGRESS\n========\nepochs: 0\nsteps: 0\nsamples: 0\n\n\nBatch inputs shape: (16, 1, 28, 28)\nBatch targets shape: (16,)\nBatch outputs 
shape: (16, 10)\n\nbatch:        0\taccuracy: 0.1875\tloss: 2.3783\nbatch:      128\taccuracy: 0.6250\tloss: 2.0528\nbatch:      256\taccuracy: 0.6875\tloss: 0.6222\n```","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmichalwols%2Fyann","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fmichalwols%2Fyann","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmichalwols%2Fyann/lists"}