<h1 align="center">jax-js: JAX in pure JavaScript</h1>

<p align="center"><strong>
  <a href="https://jax-js.com">Website</a> |
  <a href="https://jax-js.com/docs/">API Reference</a> |
  <a href="./FEATURES.md">Compatibility Table</a> |
  <a href="https://discord.gg/BW6YsCd4Tf">Discord</a>
</strong></p>

**jax-js** is a machine learning framework for the browser. It aims to bring
[JAX](https://jax.dev)-style, high-performance CPU and GPU kernels to JavaScript, so you can run
numerical applications on the web.

```bash
npm i @jax-js/jax
```

Under the hood, it translates array operations into a compiler representation, then synthesizes
kernels in WebAssembly and WebGPU.

The library is written from scratch, with zero external dependencies. It maintains close API
compatibility with NumPy/JAX. Since everything runs client-side, jax-js is likely the most portable
GPU ML framework: it runs anywhere a browser can run.

## Quickstart

```js
import { numpy as np } from "@jax-js/jax";

// Array operations, compatible with JAX/NumPy.
const x = np.array([1, 2, 3]);
const y = x.mul(4); // [4, 8, 12]
```

### Web usage (CDN)

In vanilla JavaScript (without a bundler), just import from a module script tag. This is the easiest
way to get started on a blank HTML page.

```html
<script type="module">
  import { numpy as np } from "https://esm.sh/@jax-js/jax";
</script>
```

### Platforms

This table refers to the latest versions of each browser.
WebGPU has gained wide support in browsers as of late 2025.

| Platform            | CPU (Wasm) | GPU (WebGPU)   | GPU (WebGL) |
| ------------------- | ---------- | -------------- | ----------- |
| Chrome / Edge       | ✅         | ✅             | ✅          |
| Firefox             | ✅         | ✅ - macOS 26+ | ✅          |
| Safari              | ✅         | ✅ - macOS 26+ | ✅          |
| iOS                 | ✅         | ✅ - iOS 26+   | ✅          |
| Chrome for Android  | ✅         | ✅             | ✅          |
| Firefox for Android | ✅         | ❌             | ✅          |
| Node.js             | ✅         | ❌             | ❌          |
| Deno                | ✅         | ✅ - async     | ❌          |

## Examples

Community usage:

- [**autoresearch-webgpu**: autoresearch, in the browser](https://autoresearch.lucasgelfond.online/)
- [**tanh.xyz**: Interactive ML visualizations](https://tanh.xyz/)
- [**jax-js-bayes**: Declarative Bayesian modeling library](https://github.com/StefanSko/jax-js-bayes)

Demos on the jax-js website:

- [Training neural networks on MNIST](https://jax-js.com/mnist)
- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
- [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
- [Fluid simulation (Navier-Stokes)](https://jax-js.com/fluid-sim)
- [In-browser REPL](https://jax-js.com/repl)
- [Matmul benchmark](https://jax-js.com/bench/matmul)
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
- [Mandelbrot set](https://jax-js.com/mandelbrot)

## Feature comparison

Here's a quick, high-level comparison with other popular web ML runtimes:

| Feature                         | jax-js     | TensorFlow.js   | onnxruntime-web    |
| ------------------------------- | ---------- | --------------- | ------------------ |
| **Overview**                    |            |                 |                    |
| API style                       | JAX/NumPy  | TensorFlow-like | Static ONNX graphs |
| Latest release                  | 2026       | ⚠️ 2024         | 2026               |
| Speed                           | Fastest    | Fast            | Fastest            |
| Bundle size (gzip)              | 80 KB      | 269 KB          | 90 KB + 24 MB Wasm |
| **Autodiff & JIT**              |            |                 |                    |
| Gradients                       | ✅         | ✅              | ❌                 |
| Jacobian and Hessian            | ✅         | ❌              | ❌                 |
| `jvp()` forward differentiation | ✅         | ❌              | ❌                 |
| `jit()` kernel fusion           | ✅         | ❌              | ❌                 |
| `vmap()` auto-vectorization     | ✅         | ❌              | ❌                 |
| Graph capture                   | ✅         | ❌              | ✅                 |
| **Backends & Data**             |            |                 |                    |
| WebGPU backend                  | ✅         | 🟡 Preview      | ✅                 |
| WebGL backend                   | ✅         | ✅              | ✅                 |
| Wasm (CPU) backend              | ✅         | ✅              | ✅                 |
| Eager array API                 | ✅         | ✅              | ❌                 |
| Run ONNX models                 | 🟡 Partial | ❌              | ✅                 |
| Read safetensors                | ✅         | ❌              | ❌                 |
| Float64                         | ✅         | ❌              | ❌                 |
| Float32                         | ✅         | ✅              | ✅                 |
| Float16                         | ✅         | ❌              | ✅                 |
| BFloat16                        | ❌         | ❌              | ❌                 |
| Packed Uint8                    | ❌         | ❌              | 🟡 Partial         |
| Mixed precision                 | ✅         | ❌              | ✅                 |
| Mixed devices                   | ✅         | ❌              | ❌                 |
| **Ops & Numerics**              |            |                 |                    |
| Arithmetic functions            | ✅         | ✅              | ✅                 |
| Matrix multiplication           | ✅         | ✅              | ✅                 |
| General einsum                  | ✅         | 🟡 Partial      | 🟡 Partial         |
| Sorting                         | ✅         | ❌              | ❌                 |
| Activation functions            | ✅         | ✅              | ✅                 |
| NaN/Inf numerics                | ✅         | ✅              | ✅                 |
| Basic convolutions              | ✅         | ✅              | ✅                 |
| n-d convolutions                | ✅         | ❌              | ✅                 |
| Strided/dilated convolution     | ✅         | ✅              | ✅                 |
| Cholesky, Lstsq                 | ✅         | ❌              | ❌                 |
| LU, Solve, Determinant          | ✅         | ❌              | ❌                 |
| SVD                             | ❌         | ❌              | ❌                 |
| FFT                             | ✅         | ✅              | ✅                 |
| Basic RNG (Uniform, Normal)     | ✅         | ✅              | ✅                 |
| Advanced RNG                    | ✅         | ❌              | ❌                 |

## Tutorial

Programming in `jax-js` looks [very similar to JAX](https://docs.jax.dev/en/latest/jax-101.html),
just in JavaScript.

### Arrays

Create an array with `np.array()`:

```ts
import { numpy as np } from "@jax-js/jax";

const ar = np.array([1, 2, 3]);
```

By default, this is a float32 array, but you can specify a different dtype:

```ts
const ar = np.array([1, 2, 3], { dtype: np.int32 });
```

For more efficient construction, create an array from a JS `TypedArray` buffer:

```ts
const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
const ar = np.array(buf).reshape([2, 3]);
```

Once you're done with it, you can unwrap a `jax.Array` back into JavaScript. This will also apply
any pending operations or lazy updates:

```ts
// 1) Returns a possibly nested JavaScript array.
ar.js();
await ar.jsAsync(); // Faster, non-blocking

// 2) Returns a flat TypedArray data buffer.
ar.dataSync();
await ar.data(); // Fastest, non-blocking
```

Arrays can have mathematical operations applied to them. For example:

```ts
import { numpy as np, scipySpecial as special } from "@jax-js/jax";

const x = np.arange(100).astype(np.float32); // array of integers [0..99]

const y1 = x.ref.add(x.ref); // x + x
const y2 = np.sin(x.ref); // sin(x)
const y3 = np.tanh(x.ref).mul(5); // 5 * tanh(x)
const y4 = special.erfc(x.ref); // erfc(x)
```

Notice that in the above code, we used `x.ref`. This is because of the memory model: jax-js uses
reference-counted _ownership_ to track when the memory of an Array can be freed. More on this below.

### Reference counting

Big Arrays take up a lot of memory. Python ML libraries override the `__del__()` method to free
memory, but JavaScript has no such API for running object destructors
([cf.](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry)).
This means that you have to track references manually. jax-js tries to make this as ergonomic as
possible, so you don't accidentally leak memory in a loop.

Every `jax.Array` has a reference count. This satisfies the following rules:

- Whenever you create an Array, its reference count starts at `1`.
- When an Array's reference count reaches `0`, it is freed and can no longer be used.
- Given an Array `a`:
  - Accessing `a.ref` returns `a` and changes its reference count by `+1`.
  - Passing `a` into any function as an argument changes its reference count by `-1`.
  - Calling `a.dispose()` also changes its reference count by `-1`.

What this means is that all functions in jax-js must _take ownership_ of their arguments as
references. Whenever you would like to pass an Array as an argument, you can pass it directly to
dispose of it, or use `.ref` if you'd like to use it again later.

**You must follow these rules in your own functions as well!** All combinators like `jvp`, `grad`,
and `jit` assume that you are following these conventions on how arguments are passed, and they will
respect them as well.

```ts
// Bad: Uses `x` twice, decrementing its reference count twice.
function foo_bad(x: np.Array, y: np.Array) {
  return x.add(x.mul(y));
}

// Good: The first usage of `x` is `x.ref`, adding +1 to refcount.
function foo_good(x: np.Array, y: np.Array) {
  return x.ref.add(x.mul(y));
}
```

Here's another example:

```ts
// Bad: Doesn't consume `x` in the `if`-branch.
function bar_bad(x: np.Array, skip: boolean) {
  if (skip) return np.zeros(x.shape);
  return x;
}

// Good: Consumes `x` exactly once in each branch.
function bar_good(x: np.Array, skip: boolean) {
  if (skip) {
    const ret = np.zeros(x.shape);
    x.dispose();
    return ret;
  }
  return x;
}
```

You can assume that every function in jax-js takes ownership properly, except for a couple of very
rare exceptions that are documented.

### grad(), vmap() and jit()

JAX's signature composable transformations are also supported in jax-js.
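As background, a "transformation" is just a higher-order function: it takes a function and returns a new function. Here is a toy sketch of that idea in plain TypeScript, with no jax-js involved; `numericalGrad` is our own illustrative helper (not part of the jax-js API) that approximates a derivative with a central finite difference:

```ts
// Toy sketch: a transformation takes a function and returns a new function.
// `numericalGrad` is illustrative only -- jax-js's real grad() uses autodiff
// (tracing the computation), not finite differences.
type Fn = (x: number) => number;

function numericalGrad(f: Fn, eps = 1e-5): Fn {
  // Central finite difference: (f(x + eps) - f(x - eps)) / (2 * eps).
  return (x) => (f(x + eps) - f(x - eps)) / (2 * eps);
}

const dSquare = numericalGrad((x) => x * x);
console.log(dSquare(3)); // ≈ 6, since d/dx x^2 = 2x
```

The real `grad()` below has the same shape (function in, function out), which is what makes it composable with `vmap()` and `jit()`.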
Here is a simple example of
using `grad` and `vmap` to compute the derivative of a function:

```ts
import { numpy as np, grad, vmap } from "@jax-js/jax";

const x = np.linspace(-10, 10, 1000);

const y1 = vmap(grad(np.sin))(x.ref); // d/dx sin(x) = cos(x)
const y2 = np.cos(x);

np.allclose(y1, y2); // => true
```

The `jit` function is especially useful when running long sequences of primitives on the GPU, since
it fuses operations together into a single kernel dispatch. This
[improves memory bandwidth usage](https://substack.com/home/post/p-163548742) on hardware
accelerators, where memory bandwidth, rather than raw FLOPs, is the bottleneck. For instance:

```ts
export const hypot = jit(function hypot(x1: np.Array, x2: np.Array) {
  return np.sqrt(np.square(x1).add(np.square(x2)));
});
```

Without JIT, the `hypot()` function would require four kernel dispatches: two multiplies, one add,
and one sqrt. JIT fuses these together into a single kernel that does it all at once.

All functional transformations can take a typed `JsTree` of inputs and outputs. These are similar to
[JAX's pytrees](https://docs.jax.dev/en/latest/pytrees.html): basically just a structure of
nested JavaScript objects and arrays. For instance:

```ts
import { grad, numpy as np } from "@jax-js/jax";

type Params = {
  foo: np.Array;
  bar: np.Array[];
};

function getSums(p: Params) {
  const fooSum = p.foo.sum();
  const barSum = p.bar.map((x) => x.sum()).reduce(np.add);
  return fooSum.add(barSum);
}

grad(getSums)({
  foo: np.array([1, 2, 3]),
  bar: [np.array([10]), np.array([11, 12])],
});
// => { foo: [1, 1, 1], bar: [[1], [1, 1]] }
```

Note that you need to use `type` alias syntax rather than `interface` to define fine-grained
`JsTree` types.

### Devices

Similar to JAX, jax-js has a concept of "devices": backends that store Arrays in memory and
determine how to execute compiled operations on them.

There are currently 4 devices in jax-js:

- `cpu`: Slow, interpreted JS, only meant for debugging.
- `wasm`: [WebAssembly](https://webassembly.org/), multi-threaded when
  [`SharedArrayBuffer`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/SharedArrayBuffer)
  is available.
- `webgpu`: [WebGPU](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API), available on
  [supported browsers](https://caniuse.com/webgpu) (Chrome, Firefox, Safari, iOS).
- `webgl`: [WebGL2](https://developer.mozilla.org/en-US/docs/Web/API/WebGL2RenderingContext), via
  fragment shaders. This is an older graphics API that runs on almost all browsers, but it is much
  slower than WebGPU.
  It's offered on a best-effort basis and is not as well supported.

**We recommend `webgpu` for best performance, especially when running neural networks.** The default
device is `wasm`, but you can change this at startup time:

```ts
import { defaultDevice, init } from "@jax-js/jax";

const devices = await init(); // Starts all available backends.

if (devices.includes("webgpu")) {
  defaultDevice("webgpu");
} else {
  console.warn("WebGPU is not supported, falling back to Wasm.");
}
```

You can also place individual arrays on specific devices:

```ts
import { devicePut, numpy as np } from "@jax-js/jax";

const ar = np.array([1, 2, 3]); // Starts with device="wasm"
await devicePut(ar, "webgpu"); // Now device="webgpu"
```

### Helper libraries

There are other libraries in the `@jax-js` namespace that can work with jax-js, or be used in a
self-contained way in other projects.

- [**`@jax-js/loaders`**](packages/loaders) can load tensors from various formats like Safetensors,
  includes a fast and compliant implementation of BPE, and caches HTTP requests for large assets
  like model weights in OPFS.
- [**`@jax-js/onnx`**](packages/onnx) loads models in the [ONNX](https://onnx.ai/) format as
  native jax-js functions.
- [**`@jax-js/optax`**](packages/optax) provides implementations of optimizers like Adam and SGD.

### Performance

To see per-kernel traces in browser development tools, call `jax.profiler.startTrace()`.

The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual
browsers. This library also uniquely offers the `jit()` feature, which fuses operations together and
records an execution graph.
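For context on matmul throughput figures: multiplying two N×N matrices performs roughly 2·N³ floating-point operations (one multiply and one add per inner-product term), and GFLOP/s is that count divided by wall time. A minimal sketch of the arithmetic in plain TypeScript (no jax-js required; `matmulGflops` and the 20 ms timing are purely illustrative):

```ts
// Rough GFLOP/s accounting for an N x N matmul.
// An N x N matmul does about 2 * N^3 FLOPs: N^3 multiplies and N^3 adds.
function matmulGflops(n: number, seconds: number): number {
  const flops = 2 * n ** 3;
  return flops / seconds / 1e9;
}

// e.g. a hypothetical 4096 x 4096 matmul finishing in 20 ms:
console.log(matmulGflops(4096, 0.02).toFixed(0)); // ~6872 GFLOP/s
```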
jax-js achieves **over 7000 GFLOP/s** for matrix multiplication on an
Apple M4 Max chip ([try it](https://jax-js.com/bench/matmul)).

In that benchmark, it's significantly faster than both
[TensorFlow.js](https://github.com/tensorflow/tfjs) and
[ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), both of which use handwritten
libraries of custom kernels.

It's still early, though. There's a lot of low-hanging fruit for continuing to optimize the library,
as well as unique optimizations such as FlashAttention variants.

### API Reference

That's all for this short tutorial. Please see the generated
[API reference](https://jax-js.com/docs) for detailed documentation.

## Development

_The following technical details are for contributing to jax-js and modifying its internals._

This repository is managed by [`pnpm`](https://pnpm.io/). You can compile and build all packages in
watch mode with:

```bash
pnpm install
pnpm run build:watch
```

Then you can run tests in a headless browser using [Vitest](https://vitest.dev/):

```bash
pnpm exec playwright install
pnpm test
```

We are currently pinned to an older version of Playwright that supports WebGPU in headless mode;
newer versions skip the WebGPU tests.

To start a Vite dev server running the website, demos, and REPL:

```bash
pnpm -C website dev
```

You can run the linter, code formatter, and type checker with:

```bash
pnpm lint          # Run ESLint
pnpm format        # Format all files with Prettier
pnpm format:check  # Check formatting without writing
pnpm check         # Run TypeScript type checking
```

## Future work / help wanted

Contributions are welcome! Some fruitful areas to look into:

- Adding support for more JAX functions and operations; see the [compatibility table](./FEATURES.md).
- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
  and multithreading. (Even single-threaded Wasm could be ~20x faster.)
- Helping the JIT compiler fuse operations in more cases, like `tanh` branches.
- Making a fast transformer inference engine, comparing against onnxruntime-web.

You can join our [Discord server](https://discord.gg/BW6YsCd4Tf) to chat with the community.