{"id":15426950,"url":"https://github.com/juharris/train-pytorch-in-js","last_synced_at":"2025-08-11T18:08:56.277Z","repository":{"id":39003086,"uuid":"464048479","full_name":"juharris/train-pytorch-in-js","owner":"juharris","description":"Convert a PyTorch model and train it in JavaScript in your browser using ONNX Runtime Web","archived":false,"fork":false,"pushed_at":"2022-10-26T21:38:40.000Z","size":87534,"stargazers_count":15,"open_issues_count":2,"forks_count":3,"subscribers_count":4,"default_branch":"main","last_synced_at":"2025-04-19T18:54:16.943Z","etag":null,"topics":["onnx","onnxruntime","onnxruntime-web","pytorch"],"latest_commit_sha":null,"homepage":"https://juharris.github.io/train-pytorch-in-js","language":"JavaScript","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/juharris.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2022-02-27T05:40:06.000Z","updated_at":"2024-07-19T07:21:27.000Z","dependencies_parsed_at":"2023-01-19T17:04:00.003Z","dependency_job_id":null,"html_url":"https://github.com/juharris/train-pytorch-in-js","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/juharris/train-pytorch-in-js","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/juharris%2Ftrain-pytorch-in-js","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/juharris%2Ftrain-pytorch-in-js/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/juharris%2Ftrain-pytorch-in-js/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/juharris%2Ftrain-pytorch-in-js/manifests",
"owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/juharris","download_url":"https://codeload.github.com/juharris/train-pytorch-in-js/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/juharris%2Ftrain-pytorch-in-js/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":269931473,"owners_count":24498722,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","status":"online","status_checked_at":"2025-08-11T02:00:10.019Z","response_time":75,"last_error":null,"robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":true,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["onnx","onnxruntime","onnxruntime-web","pytorch"],"created_at":"2024-10-01T17:58:24.649Z","updated_at":"2025-08-11T18:08:56.246Z","avatar_url":"https://github.com/juharris.png","language":"JavaScript","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Train PyTorch Models in JavaScript/TypeScript\nConvert a [PyTorch](https://https://pytorch.org) model and train it in JavaScript using [ONNX Runtime Web](https://github.com/microsoft/onnxruntime/tree/master/js/web).\n\nTry it yourself at https://juharris.github.io/train-pytorch-in-js.\n\nExample:\n![example of how training looks in the browser](./assets/training.gif)\n\n# Overview\nSteps:\n\n0. Define your PyTorch model. You probably already did this.\n1. Use the new utility method to export an ONNX gradient graph for the model.\n2. Set up an optimizer graph.\n3. 
Load the graphs in JavaScript (this project uses TypeScript).\n4. Use the graphs to train the model.\n\nDetails:\n\n## 0. Define your PyTorch model\nYou probably already did this.\nHere's our simple example:\n```python\nimport torch\n\nclass MyModel(torch.nn.Module):\n    def __init__(self,\n                 input_size: int,\n                 hidden_size: int,\n                 num_classes: int):\n        super(MyModel, self).__init__()\n        self.fc1 = torch.nn.Linear(input_size, hidden_size)\n        self.relu = torch.nn.ReLU()\n        self.fc2 = torch.nn.Linear(hidden_size, num_classes)\n\n    def forward(self, x):\n        out = self.fc1(x)\n        out = self.relu(out)\n        out = self.fc2(out)\n        return out\n```\n\nYou can train it in Python to get some good initial weights but that's not required to export it and then train it in JavaScript.\n\n## 1. Export the model's gradient and optimizer graphs\nWe're going to create an ONNX graph that can compute gradients when given training data.\n\nYou can follow along here or see the full example in [example.py](./export/example.py) or [mnist/example.py](./export/mnist/example.py).\n\n### 1. 
Install some dependencies\n*I did this in Windows Subsystem for Linux (WSL).*\n\n* PyTorch\n\nIf you don't already have PyTorch installed, see [pytorch.org](https://pytorch.org/get-started/locally/) for how to install it on your system.\nFor example:\n```bash\nconda install pytorch torchvision torchaudio cpuonly -c pytorch\n```\n\n* ONNX Runtime\n\nSee [onnxruntime.ai](https://onnxruntime.ai) for all installation options.\nThe utility method we'll use is new in version 1.11 so you'll need at least that version.\nMake sure that the version you use is the same as the version of ONNX Runtime Web that you'll use later.\nThis repository includes a pre-built ONNX Runtime Web build for version 1.11 so we'll use that version for our Python onnxruntime dependencies.\n\nExample:\n```bash\npip install onnx 'onnxruntime==1.11.*' 'onnxruntime-training==1.11.*'\n```\n\n### 2. Export the model's gradient graph\n\n```python\nimport torch\nfrom onnxruntime.training.experimental import export_gradient_graph\n\n# We need a custom loss function to load the graph in an InferenceSession in ONNX Runtime Web.\n# You can still make the gradient graph with torch.nn.CrossEntropyLoss() and this part will work, but you'll get problems later when trying to use the graph in JavaScript.\ndef binary_cross_entropy_loss(output, target):\n    return -torch.sum(target * torch.log2(output[:, 0]) +\n        (1-target) * torch.log2(output[:, 1]))\n\n\nloss_fn = binary_cross_entropy_loss\n\ninput_size = 10\nnum_classes = 2\nmodel = MyModel(input_size=input_size, hidden_size=5, num_classes=num_classes)\n\n# File path for where to save the ONNX graph.\ngradient_graph_path = 'gradient_graph.onnx'\n\n# We need example input for the ONNX model.\n# It doesn't matter what values are filled in, but the dimensions need to be correct.\nbatch_size = 32\nexample_input = torch.randn(\n    batch_size, input_size, requires_grad=True)\nexample_labels = torch.randint(0, num_classes, 
(batch_size,))\n\nexport_gradient_graph(\n    model, loss_fn, example_input, example_labels, gradient_graph_path)\n```\n\nYou now have an ONNX graph at `gradient_graph.onnx`.\nIf you want to validate it, see [orttraining_test_experimental_gradient_graph.py](https://github.com/microsoft/onnxruntime/blob/master/orttraining/orttraining/test/python/orttraining_test_experimental_gradient_graph.py) for examples.\n\n### 3. Set up an optimizer and export it\nWe'll run another ONNX graph to compute the weight updates.\nThis repo has an example for an [Adam](https://arxiv.org/abs/1412.6980) optimizer [here](./export/optim/adam.py).\n\nThe optimizer graph is kept separate from the gradient graph for a few reasons:\n* You can easily swap the optimizer for a different optimizer while using the same gradient graph.\n* Historically, putting the model's gradient graph and the optimizer graph together was too complex to support many different types of optimizers.\n\nExport the optimizer graph:\n```python\nimport onnx\nfrom optim.adam import AdamOnnxGraphBuilder\n\noptimizer = AdamOnnxGraphBuilder(model.named_parameters())\nonnx_optimizer = optimizer.export()\nonnx.save(onnx_optimizer, 'optimizer_graph.onnx')\n```\n\n### 4. MNIST Example\nThose were just examples that you could follow in your own project.\nThis browser example project will load a model that classifies digits from the [MNIST dataset][mnist].\n\nNext, we'll prepare the model's gradient graph and optimizer graph for the example JavaScript project.\nGo to the export folder:\n```bash\ncd export\n```\n\nTo export the MNIST example:\n```bash\npython -m mnist.example\n```\n\n(Optional) Train the model in Python to verify that it should work:\n```bash\npython -m mnist.train\n```\n\n## 2. 
Load the model in JavaScript\nWe'll use [ONNX Runtime Web](https://github.com/microsoft/onnxruntime/tree/master/js/web) to load the gradient graph.\n\nAt this time (May 2022), this only works with custom ONNX Runtime Web builds that have training operators enabled, but the required files are included in this repository.\nThe officially published ONNX Runtime Web doesn't support certain gradient-calculation operators in our exported gradient graph, such as `GatherGrad`, when using an InferenceSession.\n\n### 0. (Optional) Build ONNX Runtime Web with training operators enabled.\n\nFor your convenience, we included a build of ONNX Runtime Web with training operators enabled for ONNX Runtime version 1.11.\nYou can see other versions [here](https://github.com/microsoft/onnxruntime/releases).\n\nIf you would like to build it yourself, here are some commands that should help, assuming you're using Linux and have CMake and `conda` set up:\n```bash\nconda create --name ort-dev python=3.8 numpy h5py\nconda activate ort-dev\nconda install -c anaconda libstdcxx-ng\nconda install pytorch torchvision torchaudio cpuonly -c pytorch\npip install flake8 pytest\n# This is a specific commit that should work. You can try other versions, but this tutorial will work best if the version matches the onnxruntime and onnxruntime-training versions you installed for Python earlier.\ncommit=\"2dfd81b9bb097c90388010e5b7d298498274f8d9\"\ngit clone --recursive git@github.com:microsoft/onnxruntime.git\ncd onnxruntime\ngit checkout ${commit}\ngit submodule update --init --recursive\npip install -r requirements-dev.txt\n```\n\nFor the build command, there are instructions at [ONNX Runtime Web](https://github.com/microsoft/onnxruntime/tree/master/js/web) which currently links to specific instructions [here](https://github.com/microsoft/onnxruntime/blob/master/js/README.md#Build-2).\nWhen you get to the \"Build ONNX Runtime WebAssembly\" step, you'll need to add `--enable_training 
--enable_training_ops` to the build command.\nFor example:\n```bash\n./build.sh --build_wasm --parallel $(expr `nproc` - 1) --enable_training --enable_training_ops --skip_submodule_sync --skip_tests\n./build.sh --build_wasm --enable_wasm_simd --parallel $(expr `nproc` - 1) --enable_training --enable_training_ops --skip_submodule_sync --skip_tests\n./build.sh --build_wasm --enable_wasm_threads --parallel $(expr `nproc` - 1) --enable_training --enable_training_ops --skip_submodule_sync --skip_tests\n./build.sh --build_wasm --enable_wasm_simd --enable_wasm_threads --parallel $(expr `nproc` - 1) --enable_training --enable_training_ops --skip_submodule_sync --skip_tests\ncp build/Linux/Debug/ort-wasm*.wasm js/web/dist/\ncp build/Linux/Debug/ort-wasm*.js js/web/lib/wasm/binding/\ncd js/web\nNODE_OPTIONS=--max-old-space-size=4096 npm run build\n```\n\nYou might get some errors, but if you see ort.js and ort-web.js in the dist/ folder, it should work.\n\n### 1. Set up the example project.\n\n   0. (If you built ONNX Runtime Web yourself)\n   Put the files from the ONNX Runtime Web build (ort.js and others such as the wasm files, if needed) in `training/public/onnxruntime_web_build_inference_with_training_ops/`:\n   ```bash\n   # In the onnxruntime root directory, do:\n   rm \u003cyour workspace\u003e/train-pytorch-in-js/training/public/onnxruntime_web_build_inference_with_training_ops/*.{js,wasm}\n   cp js/web/dist/* js/web/lib/wasm/binding/* \u003cyour workspace\u003e/train-pytorch-in-js/training/public/onnxruntime_web_build_inference_with_training_ops\n   # Get the declaration files.\n   cp js/common/dist/lib/*.d.ts \u003cyour workspace\u003e/train-pytorch-in-js/training/src/ort\n   ```\n   1. Copy some files to `training/public/`:\n   ```bash\n   cp *_graph.onnx training/public\n   ```\n\n   Copy the MNIST data:\n   ```bash\n   cd export\n   cp -R ../data training/public/\n   ```\n   2. Go to the `training` folder:\\\n   `cd training`\n   3. 
Run `yarn install`\n   4. Run `yarn start`\\\n   Your browser should open.\n   Click \"TRAIN\" to train the model.\n\n[mnist]: https://deepai.org/dataset/mnist\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fjuharris%2Ftrain-pytorch-in-js","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fjuharris%2Ftrain-pytorch-in-js","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fjuharris%2Ftrain-pytorch-in-js/lists"}