{"id":13409157,"url":"https://github.com/google/neural-tangents","last_synced_at":"2025-05-14T11:08:50.973Z","repository":{"id":37270831,"uuid":"180192894","full_name":"google/neural-tangents","owner":"google","description":"Fast and Easy Infinite Neural Networks in Python","archived":true,"fork":false,"pushed_at":"2024-03-01T17:17:03.000Z","size":10932,"stargazers_count":2341,"open_issues_count":69,"forks_count":231,"subscribers_count":60,"default_branch":"main","last_synced_at":"2025-05-13T09:08:10.220Z","etag":null,"topics":["bayesian-inference","bayesian-networks","deep-networks","gaussian-processes","gradient-descent","gradient-flow","infinite-networks","jax","kernel","kernel-computation","neural-networks","neural-tangents","training-dynamics"],"latest_commit_sha":null,"homepage":"https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html","language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/google.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":"CONTRIBUTING.md","funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":"CITATION","codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2019-04-08T16:48:48.000Z","updated_at":"2025-05-11T14:08:39.000Z","dependencies_parsed_at":"2024-01-07T10:51:01.550Z","dependency_job_id":"eab8de26-c15f-40f0-b67e-32590cca0d5a","html_url":"https://github.com/google/neural-tangents","commit_stats":{"total_commits":583,"total_committers":25,"mean_commits":23.32,"dds":0.3276157804459692,"last_synced_commit":"9cfdc2878f7270bb02973cc2b438c81a7a39c315"},"previous_names":[],"tags_count":17,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repo
sitories/google%2Fneural-tangents","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google%2Fneural-tangents/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google%2Fneural-tangents/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google%2Fneural-tangents/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/google","download_url":"https://codeload.github.com/google/neural-tangents/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":254129481,"owners_count":22019628,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["bayesian-inference","bayesian-networks","deep-networks","gaussian-processes","gradient-descent","gradient-flow","infinite-networks","jax","kernel","kernel-computation","neural-networks","neural-tangents","training-dynamics"],"created_at":"2024-07-30T20:00:58.431Z","updated_at":"2025-05-14T11:08:50.938Z","avatar_url":"https://github.com/google.png","language":"Jupyter Notebook","funding_links":[],"categories":["Toolbox","Jupyter Notebook","机器学习框架","其他_机器学习与深度学习","\u003cspan id=\"head41\"\u003e3.5. Machine Learning and Deep Learning\u003c/span\u003e","JAX Models","Libraries","Computational Fluid Dynamics","DeepCNN"],"sub_categories":["Libraries","\u003cspan id=\"head48\"\u003e3.5.7. Neural Tangent\u003c/span\u003e","Reinforcement Learning","Neural Networks for PDE","Interpretation"],"readme":"# **Stand with Ukraine!** 🇺🇦\n\nFreedom of thought is fundamental to all of science. 
Right now, our freedom is being suppressed with bombing of civilians in Ukraine. **Don't be against the war - fight against the war! [supportukrainenow.org](https://supportukrainenow.org/)**.\n\n# Neural Tangents\n[**ICLR 2020 Video**](https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html)\n| [**Paper**](https://arxiv.org/abs/1912.02803)\n| [**Quickstart**](#colab-notebooks)\n| [**Install guide**](#installation)\n| [**Reference docs**](https://neural-tangents.readthedocs.io/en/latest/)\n| [**Release notes**](https://github.com/google/neural-tangents/releases)\n\n[![PyPI](https://img.shields.io/pypi/v/neural-tangents)](https://pypi.org/project/neural-tangents/) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/neural-tangents)](https://pypi.org/project/neural-tangents/)\n[![Linux](https://github.com/google/neural-tangents/actions/workflows/linux.yml/badge.svg)](https://github.com/google/neural-tangents/actions/workflows/linux.yml)\n[![macOS](https://github.com/google/neural-tangents/actions/workflows/macos.yml/badge.svg)](https://github.com/google/neural-tangents/actions/workflows/macos.yml)\n[![Pytype](https://github.com/google/neural-tangents/actions/workflows/pytype.yml/badge.svg)](https://github.com/google/neural-tangents/actions/workflows/pytype.yml)\n[![Coverage](https://codecov.io/gh/google/neural-tangents/branch/main/graph/badge.svg)](https://codecov.io/gh/google/neural-tangents)\n[![Readthedocs](https://readthedocs.org/projects/neural-tangents/badge/?version=latest)](https://neural-tangents.readthedocs.io/en/latest/?badge=latest)\n\n[//]: # ([![PyPI - License]\u0026#40;https://img.shields.io/pypi/l/neural_tangents\u0026#41;]\u0026#40;https://github.com/google/neural-tangents/blob/main/LICENSE\u0026#41;)\n\n\n## Overview\n\nNeural Tangents is a high-level neural network API for specifying complex, hierarchical, neural networks of both finite and _infinite_ width. 
Neural Tangents allows researchers to define, train, and evaluate infinite networks as easily as finite ones. The library has been used in [\u003e100 papers](https://scholar.google.com/scholar?oi=bibs\u0026hl=en\u0026cites=4030630874639258770,4161931758707925692,2891750348147928089,8612471018033907356,10117604240015578443,4178323439418493877).\n\nInfinite (in width or channel count) neural networks are Gaussian Processes (GPs) with a kernel function determined by their architecture. See [this listing](https://github.com/google/neural-tangents/wiki/Overparameterized-Neural-Networks:-Theory-and-Empirics) of papers written by the creators of Neural Tangents which study the infinite width limit of neural networks.\n\nNeural Tangents allows you to construct a neural network model from common building blocks like convolutions, pooling, residual connections, nonlinearities, and more, and obtain not only the finite model, but also the kernel function of the respective GP.\n\nThe library is written in python using [JAX](https://github.com/google/jax) and leveraging [XLA](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/index.md) to run out-of-the-box on CPU, GPU, or TPU. Kernel computation is highly optimized for speed and memory efficiency, and can be automatically distributed over multiple accelerators with near-perfect scaling.\n\nNeural Tangents is a work in progress.\nWe happily welcome contributions!\n\n\n\n\n## Contents\n* [Colab Notebooks](#colab-notebooks)\n* [Installation](#installation)\n* [5-Minute intro](#5-minute-intro)\n* [Package description](#package-description)\n* [Technical gotchas](#technical-gotchas)\n* [Training dynamics of wide but finite networks](#training-dynamics-of-wide-but-finite-networks)\n* [Performance](#performance)\n* [Citation](#citation)\n\n## Colab Notebooks\n\nAn easy way to get started with Neural Tangents is by playing around with the following interactive notebooks in Colaboratory. 
They demo the major features of Neural Tangents and show how it can be used in research.\n\n- [Neural Tangents Cookbook](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/neural_tangents_cookbook.ipynb)\n- [Weight Space Linearization](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/weight_space_linearization.ipynb)\n- [Function Space Linearization](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/function_space_linearization.ipynb)\n- [Neural Network Phase Diagram](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/phase_diagram.ipynb)\n- [Performance Benchmark](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/myrtle_kernel_with_neural_tangents.ipynb): simple benchmark for [Myrtle kernels](https://arxiv.org/abs/2003.02237). See also [Performance](#myrtle-network)\n- [**New**] Empirical NTK:\n  - [Fully-connected network](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_fcn.ipynb)\n  - [FLAX ResNet18](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_resnet.ipynb)\n  - [Experimental: Tensorflow ResNet50](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/experimental/empirical_ntk_resnet_tf.ipynb)\n- [**New**] [Automatic NNGP/NTK of elementwise nonlinearities](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/elementwise.ipynb)\n\n\n## Installation\n\nTo use GPU, first follow [JAX's](https://www.github.com/google/jax/) GPU installation instructions. 
Otherwise, install JAX on CPU by running\n\n```\npip install jax jaxlib --upgrade\n```\n\nOnce JAX is installed, install Neural Tangents by running\n\n```\npip install neural-tangents\n```\nor, to use the bleeding-edge version from GitHub source,\n\n```\ngit clone https://github.com/google/neural-tangents; cd neural-tangents\npip install -e .\n```\n\nYou can now run the examples and tests by calling:\n\n```\npip install .[testing]\nset -e; for f in examples/*.py; do python $f; done  # Run examples\nset -e; for f in tests/*.py; do python $f; done  # Run tests\n```\n\n\n## 5-Minute intro\n\n\u003cb\u003eSee this [Colab](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/neural_tangents_cookbook.ipynb) for a detailed tutorial. Below is a very quick introduction.\u003c/b\u003e\n\nOur library closely follows JAX's API for specifying neural networks, [`stax`](https://github.com/google/jax/blob/main/jax/example_libraries/stax.py). In `stax`, a network is defined by a pair of functions `(init_fn, apply_fn)`, which initialize the trainable parameters and compute the outputs of the network, respectively. Below is an example of defining a 3-layer network and computing its outputs `y` given inputs `x`.\n\n```python\nfrom jax import random\nfrom jax.example_libraries import stax\n\ninit_fn, apply_fn = stax.serial(\n    stax.Dense(512), stax.Relu,\n    stax.Dense(512), stax.Relu,\n    stax.Dense(1)\n)\n\nkey = random.PRNGKey(1)\nx = random.normal(key, (10, 100))\n_, params = init_fn(key, input_shape=x.shape)\n\ny = apply_fn(params, x)  # (10, 1) jnp.ndarray outputs of the neural network\n```\n\nNeural Tangents is designed to serve as a drop-in replacement for `stax`, extending the `(init_fn, apply_fn)` tuple to a triple `(init_fn, apply_fn, kernel_fn)`, where `kernel_fn` is the kernel function of the infinite network (GP) of the given architecture. 
Below is an example of computing the covariances of the GP between two batches of inputs `x1` and `x2`.\n\n```python\nfrom jax import random\nfrom neural_tangents import stax\n\ninit_fn, apply_fn, kernel_fn = stax.serial(\n    stax.Dense(512), stax.Relu(),\n    stax.Dense(512), stax.Relu(),\n    stax.Dense(1)\n)\n\nkey1, key2 = random.split(random.PRNGKey(1))\nx1 = random.normal(key1, (10, 100))\nx2 = random.normal(key2, (20, 100))\n\nkernel = kernel_fn(x1, x2, 'nngp')\n```\n\nNote that `kernel_fn` can compute _two_ covariance matrices corresponding to the [Neural Network Gaussian Process (NNGP)](https://en.wikipedia.org/wiki/Neural_network_Gaussian_process) and [Neural Tangent (NT)](https://en.wikipedia.org/wiki/Neural_tangent_kernel) kernels, respectively. The NNGP kernel corresponds to the _Bayesian_ infinite neural network. The NTK corresponds to the _(continuous) gradient descent trained_ infinite network. In the above example, we compute the NNGP kernel, but we could compute the NTK or both:\n\n```python\n# Get kernel of a single type\nnngp = kernel_fn(x1, x2, 'nngp')  # (10, 20) jnp.ndarray\nntk = kernel_fn(x1, x2, 'ntk')  # (10, 20) jnp.ndarray\n\n# Get kernels as a namedtuple\nboth = kernel_fn(x1, x2, ('nngp', 'ntk'))\nboth.nngp == nngp  # True\nboth.ntk == ntk  # True\n\n# Unpack the kernels namedtuple\nnngp, ntk = kernel_fn(x1, x2, ('nngp', 'ntk'))\n```\n\nAdditionally, if no third argument is specified, `kernel_fn` will return a `Kernel` namedtuple that contains additional metadata. 
This can be useful for composing applications of `kernel_fn` as follows:\n\n```python\nkernel = kernel_fn(x1, x2)\nkernel = kernel_fn(kernel)\nprint(kernel.nngp)\n```\n\nDoing inference with infinite networks trained on MSE loss reduces to classical GP inference, for which we also provide convenient tools:\n\n```python\nimport neural_tangents as nt\n\nx_train, x_test = x1, x2\ny_train = random.uniform(key1, shape=(10, 1))  # training targets\n\npredict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,\n                                                      y_train)\n\ny_test_nngp = predict_fn(x_test=x_test, get='nngp')\n# (20, 1) jnp.ndarray test predictions of an infinite Bayesian network\n\ny_test_ntk = predict_fn(x_test=x_test, get='ntk')\n# (20, 1) jnp.ndarray test predictions of an infinite continuous\n# gradient descent trained network at convergence (t = inf)\n\n# Get predictions as a namedtuple\nboth = predict_fn(x_test=x_test, get=('nngp', 'ntk'))\nboth.nngp == y_test_nngp  # True\nboth.ntk == y_test_ntk  # True\n\n# Unpack the predictions namedtuple\ny_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))\n```\n\n\n### Infinitely WideResnet\n\nWe can define a more complex, (infinitely) [Wide Residual Network](https://arxiv.org/abs/1605.07146) using the same `nt.stax` building blocks:\n\n```python\nfrom neural_tangents import stax\n\ndef WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):\n  Main = stax.serial(\n      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),\n      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))\n  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(\n      channels, (3, 3), strides, padding='SAME')\n  return stax.serial(stax.FanOut(2),\n                     stax.parallel(Main, Shortcut),\n                     stax.FanInSum())\n\ndef WideResnetGroup(n, channels, strides=(1, 1)):\n  blocks = []\n  blocks += [WideResnetBlock(channels, strides, 
channel_mismatch=True)]\n  for _ in range(n - 1):\n    blocks += [WideResnetBlock(channels, (1, 1))]\n  return stax.serial(*blocks)\n\ndef WideResnet(block_size, k, num_classes):\n  return stax.serial(\n      stax.Conv(16, (3, 3), padding='SAME'),\n      WideResnetGroup(block_size, int(16 * k)),\n      WideResnetGroup(block_size, int(32 * k), (2, 2)),\n      WideResnetGroup(block_size, int(64 * k), (2, 2)),\n      stax.AvgPool((8, 8)),\n      stax.Flatten(),\n      stax.Dense(num_classes, 1., 0.))\n\ninit_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)\n```\n\n\n## Package description\n\nThe `neural_tangents` (`nt`) package contains the following modules and functions:\n\n* `stax` - primitives to construct neural networks like `Conv`, `Relu`, `serial`, `parallel` etc.\n\n* `predict` - predictions with infinite networks:\n\n  * `predict.gradient_descent_mse` - inference with a single infinite width / linearized network trained on MSE loss with continuous gradient descent for an arbitrary finite or infinite (`t=None`) time. Computed in closed form.\n\n  * `predict.gradient_descent` - inference with a single infinite width / linearized network trained on arbitrary loss with continuous (momentum) gradient descent for an arbitrary finite time. Computed using an ODE solver.\n\n  * `predict.gradient_descent_mse_ensemble` - inference with an infinite ensemble of infinite width networks, either fully Bayesian (`get='nngp'`) or inference with MSE loss using continuous gradient descent (`get='ntk'`). Finite-time Bayesian inference (e.g. `t=1., get='nngp'`) is interpreted as [gradient descent on the top layer only](https://arxiv.org/abs/1902.06720), since it converges to exact Gaussian process inference with NNGP (`t=None, get='nngp'`). Computed in closed form.\n\n  * `predict.gp_inference` - exact closed form Gaussian process inference using NNGP (`get='nngp'`), NTK (`get='ntk'`), or both (`get=('nngp', 'ntk')`). 
Equivalent to `predict.gradient_descent_mse_ensemble` with `t=None` (infinite training time), but has a slightly different API (accepting precomputed kernel matrix `k_train_train` instead of `kernel_fn` and `x_train`).\n\n* `monte_carlo_kernel_fn` - compute a Monte Carlo kernel estimate  of _any_ `(init_fn, apply_fn)`, not necessarily specified via `nt.stax`, enabling the kernel computation of infinite networks without closed-form expressions.\n\n* Tools to investigate training dynamics of _wide but finite_ neural networks, like `linearize`, `taylor_expand`, `empirical_kernel_fn` and more. See [Training dynamics of wide but finite networks](#training-dynamics-of-wide-but-finite-networks) for details.\n\n\n## Technical gotchas\n\n\n### [`nt.stax`](https://github.com/google/neural-tangents/blob/main/neural_tangents/stax.py) vs [`jax.example_libraries.stax`](https://github.com/google/jax/blob/main/jax/example_libraries/stax.py)\nWe remark the following differences between our library and the JAX one.\n\n* All `nt.stax` layers are instantiated with a function call, i.e. `nt.stax.Relu()` vs `jax.example_libraries.stax.Relu`.\n* All layers with trainable parameters use the [_NTK parameterization_](https://arxiv.org/abs/1806.07572) by default. However, `Dense` and `Conv` layers also support the [_standard parameterization_](https://arxiv.org/abs/2001.07301) via a `parameterization` keyword argument.\n* `nt.stax` and `jax.example_libraries.stax` may have different layers and options available (for example `nt.stax` layers support `CIRCULAR` padding, have `LayerNorm`, but no `BatchNorm`.).\n\n\n### CPU and TPU performance\n\nFor CNNs w/ pooling, our CPU and TPU performance is suboptimal due to low core\nutilization (10-20%, looks like an XLA:CPU issue), and excessive padding\nrespectively. We will look into improving performance, but recommend NVIDIA GPUs\nin the meantime. 
See [Performance](#performance).\n\n\n## Training dynamics of wide but finite networks\n\nThe kernel of an infinite network, `kernel_fn(x1, x2).ntk`, combined with `nt.predict.gradient_descent_mse`, lets one analytically track the outputs of an infinitely wide neural network trained on MSE loss throughout training. Here we discuss the implications for _wide but finite_ neural networks and present tools to study their evolution in _weight space_ (trainable parameters of the network) and _function space_ (outputs of the network).\n\n### Weight space\n\nContinuous gradient descent in an infinite network [has been shown](https://arxiv.org/abs/1902.06720) to correspond to training a _linear_ (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.\n\nFor this, we provide two convenient functions:\n\n* `nt.linearize`, and\n* `nt.taylor_expand`,\n\nwhich allow us to linearize or get an arbitrary-order Taylor expansion of any function `apply_fn(params, x)` around some initial parameters `params_0` as `apply_fn_lin = nt.linearize(apply_fn, params_0)`.\n\nYou can use `apply_fn_lin(params, x)` exactly as you would any other function\n(including as an input to JAX optimizers). 
This makes it easy to compare the\ntraining trajectory of a neural network with that of its linearization.\nPrior theory and experiments have examined the linearization of neural\nnetworks from inputs to logits or pre-activations, rather than from inputs to\npost-activations, which are substantially more nonlinear.\n\n#### Example:\n\n```python\nimport jax.numpy as jnp\nimport neural_tangents as nt\n\ndef apply_fn(params, x):\n  W, b = params\n  return jnp.dot(x, W) + b\n\nW_0 = jnp.array([[1., 0.], [0., 1.]])\nb_0 = jnp.zeros((2,))\n\napply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))\nW = jnp.array([[1.5, 0.2], [0.1, 0.9]])\nb = b_0 + 0.2\n\nx = jnp.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])\nlogits = apply_fn_lin((W, b), x)  # (3, 2) jnp.ndarray\n```\n\n### Function space\n\nOutputs of a linearized model [evolve identically to those of an infinite one](https://arxiv.org/abs/1902.06720) but with a different kernel - precisely, the [Neural Tangent Kernel](https://arxiv.org/abs/1806.07572) evaluated on the specific `apply_fn` of the finite network given the specific `params_0` that the network is initialized with. 
For this, we provide the `nt.empirical_kernel_fn` function, which accepts any `apply_fn` and returns a `kernel_fn(x1, x2, get, params)` that computes the empirical NTK and/or NNGP kernels (selected by `get`) at specific `params`.\n\n#### Example:\n\n```python\nimport jax.random as random\nimport jax.numpy as jnp\nimport neural_tangents as nt\n\n\ndef apply_fn(params, x):\n  W, b = params\n  return jnp.dot(x, W) + b\n\n\nW_0 = jnp.array([[1., 0.], [0., 1.]])\nb_0 = jnp.zeros((2,))\nparams = (W_0, b_0)\n\nkey1, key2 = random.split(random.PRNGKey(1), 2)\nx_train = random.normal(key1, (3, 2))\nx_test = random.normal(key2, (4, 2))\ny_train = random.uniform(key1, shape=(3, 2))\n\nkernel_fn = nt.empirical_kernel_fn(apply_fn)\nntk_train_train = kernel_fn(x_train, None, 'ntk', params)\nntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)\nmse_predictor = nt.predict.gradient_descent_mse(ntk_train_train, y_train)\n\nt = 5.\ny_train_0 = apply_fn(params, x_train)\ny_test_0 = apply_fn(params, x_test)\ny_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0, ntk_test_train)\n# (3, 2) and (4, 2) jnp.ndarray train and test outputs after `t` units of time\n# training with continuous gradient descent\n```\n\n### What to Expect\n\nThe success or failure of the linear approximation is highly\narchitecture-dependent. However, some rules of thumb that we've observed are:\n\n1. Convergence as the network size increases.\n\n   * For fully-connected networks one generally observes very strong\n     agreement by the time the layer-width is 512 (RMSE of about 0.05 at the\n     end of training).\n\n   * For convolutional networks one generally observes reasonable\n     agreement by the time the number of channels is 512.\n\n2. 
Convergence at small learning rates.\n\nWith a new model it is therefore advisable to start with large width on a small dataset using a small learning rate.\n\n\n## Performance\n\nIn the table below we measure time to compute a single NTK\nentry in a 21-layer CNN (`3x3` filters, no strides, `SAME` padding, `ReLU`) on inputs of shape `3x32x32`. Precisely:\n\n```python\nlayers = []\nfor _ in range(21):\n  layers += [stax.Conv(1, (3, 3), (1, 1), 'SAME'), stax.Relu()]\n```\n\n\n### CNN with pooling\n\nTop layer is `stax.GlobalAvgPool()`:\n\n```\n_, _, kernel_fn = stax.serial(*(layers + [stax.GlobalAvgPool()]))\n```\n\n| Platform                    | Precision | Milliseconds / NTK entry | Max batch size (`NxN`) |\n|-----------------------------|-----------|--------------------------|------------------------|\n| CPU, \u003e56 cores, \u003e700 Gb RAM | 32        |  112.90                  | \u003e= 128                 |\n| CPU, \u003e56 cores, \u003e700 Gb RAM | 64        |  258.55                  |    95 (fastest - 72)   |\n| TPU v2                      | 32/16     |  3.2550                  |    16                  |\n| TPU v3                      | 32/16     |  2.3022                  |    24                  |\n| NVIDIA P100                 | 32        |  5.9433                  |    26                  |\n| NVIDIA P100                 | 64        |  11.349                  |    18                  |\n| NVIDIA V100                 | 32        |  2.7001                  |    26                  |\n| NVIDIA V100                 | 64        |  6.2058                  |    18                  |\n\n\n### CNN without pooling\n\nTop layer is `stax.Flatten()`:\n\n```\n_, _, kernel_fn = stax.serial(*(layers + [stax.Flatten()]))\n```\n\n| Platform                    | Precision | Milliseconds / NTK entry | Max batch size (`NxN`)            |\n|-----------------------------|-----------|--------------------------|-----------------------------------|\n| CPU, \u003e56 cores, 
\u003e700 Gb RAM | 32        |  0.12013                 |  2048 \u003c= N \u003c 4096 (fastest - 512) |\n| CPU, \u003e56 cores, \u003e700 Gb RAM | 64        |  0.3414                  |  2048 \u003c= N \u003c 4096 (fastest - 256) |\n| TPU v2                      | 32/16     |  0.0015722               |  512  \u003c= N \u003c 1024                 |\n| TPU v3                      | 32/16     |  0.0010647               |  512  \u003c= N \u003c 1024                 |\n| NVIDIA P100                 | 32        |  0.015171                |  512  \u003c= N \u003c 1024                 |\n| NVIDIA P100                 | 64        |  0.019894                |  512  \u003c= N \u003c 1024                 |\n| NVIDIA V100                 | 32        |  0.0046510               |  512  \u003c= N \u003c 1024                 |\n| NVIDIA V100                 | 64        |  0.010822                |  512  \u003c= N \u003c 1024                 |\n\n\n\n\nTested using version `0.2.1`. All GPU results are per single accelerator.\nNote that runtime is proportional to the depth of your network.\nIf your performance differs significantly,\nplease [file a bug](https://github.com/google/neural-tangents/issues/new)!\n\n\n\n### Myrtle network\n\nThe Colab notebook [Performance Benchmark](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/myrtle_kernel_with_neural_tangents.ipynb)\ndemonstrates how one would construct and benchmark kernels. To demonstrate\nflexibility, we took the [Myrtle architecture](https://arxiv.org/abs/2003.02237)\nas an example. 
With `NVIDIA V100` 64-bit precision, `nt` took 316/330/508 GPU-hours on full 60k CIFAR-10 dataset for Myrtle-5/7/10 kernels.\n\n\n## Citation\n\nIf you use the code in a publication, please cite our papers:\n\n```bibtex\n# Infinite width NTK/NNGP:\n@inproceedings{neuraltangents2020,\n    title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python},\n    author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz},\n    booktitle={International Conference on Learning Representations},\n    year={2020},\n    pdf={https://arxiv.org/abs/1912.02803},\n    url={https://github.com/google/neural-tangents}\n}\n\n# Finite width, empirical NTK/NNGP:\n@inproceedings{novak2022fast,\n    title={Fast Finite Width Neural Tangent Kernel},\n    author={Roman Novak and Jascha Sohl-Dickstein and Samuel S. Schoenholz},\n    booktitle={International Conference on Machine Learning},\n    year={2022},\n    pdf={https://arxiv.org/abs/2206.08720},\n    url={https://github.com/google/neural-tangents}\n}\n\n# Attention and variable-length inputs:\n@inproceedings{hron2020infinite,\n    title={Infinite attention: NNGP and NTK for deep attention networks},\n    author={Jiri Hron and Yasaman Bahri and Jascha Sohl-Dickstein and Roman Novak},\n    booktitle={International Conference on Machine Learning},\n    year={2020},\n    pdf={https://arxiv.org/abs/2006.10540},\n    url={https://github.com/google/neural-tangents}\n}\n\n# Infinite-width \"standard\" parameterization:\n@misc{sohl2020on,\n    title={On the infinite width limit of neural networks with a standard parameterization},\n    author={Jascha Sohl-Dickstein and Roman Novak and Samuel S. 
Schoenholz and Jaehoon Lee},\n    publisher = {arXiv},\n    year={2020},\n    pdf={https://arxiv.org/abs/2001.07301},\n    url={https://github.com/google/neural-tangents}\n}\n\n# Elementwise nonlinearities and sketching:\n@inproceedings{han2022fast,\n    title={Fast Neural Kernel Embeddings for General Activations},\n    author={Insu Han and Amir Zandieh and Jaehoon Lee and Roman Novak and Lechao Xiao and Amin Karbasi},\n    booktitle = {Advances in Neural Information Processing Systems},\n    year={2022},\n    pdf={https://arxiv.org/abs/2209.04121},\n    url={https://github.com/google/neural-tangents}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgoogle%2Fneural-tangents","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fgoogle%2Fneural-tangents","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgoogle%2Fneural-tangents/lists"}