{"id":13635810,"url":"https://github.com/google-deepmind/kfac-jax","last_synced_at":"2025-12-14T20:40:18.760Z","repository":{"id":37073978,"uuid":"471322163","full_name":"google-deepmind/kfac-jax","owner":"google-deepmind","description":"Second Order Optimization and Curvature Estimation with K-FAC in JAX.","archived":false,"fork":false,"pushed_at":"2025-06-12T19:16:41.000Z","size":880,"stargazers_count":278,"open_issues_count":22,"forks_count":26,"subscribers_count":11,"default_branch":"main","last_synced_at":"2025-06-12T20:47:45.997Z","etag":null,"topics":["bayesian-deep-learning","machine-learning","optimization"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/google-deepmind.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":"CONTRIBUTING.md","funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2022-03-18T10:19:24.000Z","updated_at":"2025-06-12T19:16:43.000Z","dependencies_parsed_at":"2023-09-07T20:34:51.950Z","dependency_job_id":"73e7497d-325a-4238-9c32-10ee15359fbe","html_url":"https://github.com/google-deepmind/kfac-jax","commit_stats":null,"previous_names":["google-deepmind/kfac-jax","deepmind/kfac-jax"],"tags_count":6,"template":false,"template_full_name":null,"purl":"pkg:github/google-deepmind/kfac-jax","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-deepmind%2Fkfac-jax","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-deepmind%2Fkfac-jax/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-deepmind%2Fkf
ac-jax/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-deepmind%2Fkfac-jax/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/google-deepmind","download_url":"https://codeload.github.com/google-deepmind/kfac-jax/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-deepmind%2Fkfac-jax/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":260268635,"owners_count":22983601,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["bayesian-deep-learning","machine-learning","optimization"],"created_at":"2024-08-02T00:00:52.314Z","updated_at":"2025-12-14T20:40:18.690Z","avatar_url":"https://github.com/google-deepmind.png","language":"Python","readme":"# KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX\n\n[**Installation**](#installation)\n| [**Quickstart**](#quickstart)\n| [**Documentation**](https://kfac-jax.readthedocs.io/)\n| [**Examples**](https://github.com/google-deepmind/kfac-jax/tree/main/examples/)\n| [**Citing KFAC-JAX**](#citing-kfac-jax)\n\n![CI status](https://github.com/google-deepmind/kfac-jax/workflows/ci/badge.svg)\n![docs](https://readthedocs.org/projects/kfac-jax/badge/?version=latest)\n![pypi](https://img.shields.io/pypi/v/kfac-jax)\n\nKFAC-JAX is a library built on top of [JAX] for second-order optimization of\nneural networks and for computing scalable curvature approximations.\nThe main goal of the library is to provide researchers with an easy-to-use\nimplementation of the [K-FAC] 
optimizer and curvature estimator.\n\n## Installation\u003ca id=\"installation\"\u003e\u003c/a\u003e\n\nKFAC-JAX is written in pure Python, but depends on C++ code via JAX.\n\nFirst, follow [these instructions](https://github.com/google/jax#installation)\nto install JAX with the relevant accelerator support.\n\nThen, install KFAC-JAX using pip:\n\n```bash\n$ pip install git+https://github.com/google-deepmind/kfac-jax\n```\n\nAlternatively, you can install via PyPI:\n\n```bash\n$ pip install -U kfac-jax\n```\n\nOur examples rely on additional libraries, all of which you can install using:\n\n```bash\n$ pip install kfac-jax[examples]\n```\n\n## Quickstart\u003ca id=\"quickstart\"\u003e\u003c/a\u003e\n\nLet's take a look at a simple example of training a neural network, defined\nusing [Haiku], with the K-FAC optimizer:\n\n```python\nimport haiku as hk\nimport jax\nimport jax.numpy as jnp\nimport kfac_jax\n\n# Hyperparameters\nNUM_CLASSES = 10\nL2_REG = 1e-3\nNUM_BATCHES = 100\n\n\ndef make_dataset_iterator(batch_size):\n  # Dummy dataset; in practice this should be your dataset pipeline.\n  for _ in range(NUM_BATCHES):\n    yield jnp.zeros([batch_size, 100]), jnp.ones([batch_size], dtype=\"int32\")\n\n\ndef softmax_cross_entropy(logits: jnp.ndarray, targets: jnp.ndarray):\n  \"\"\"Softmax cross entropy loss.\"\"\"\n  # We assume integer labels.\n  assert logits.ndim == targets.ndim + 1\n\n  # Tell KFAC-JAX this model represents a classifier.\n  # See https://kfac-jax.readthedocs.io/en/latest/overview.html#supported-losses\n  kfac_jax.register_softmax_cross_entropy_loss(logits, targets)\n  log_p = jax.nn.log_softmax(logits, axis=-1)\n  return -jax.vmap(lambda x, y: x[y])(log_p, targets)\n\n\ndef model_fn(x):\n  \"\"\"A Haiku MLP model function: a three-hidden-layer network with tanh.\"\"\"\n  return hk.nets.MLP(\n    output_sizes=(50, 50, 50, NUM_CLASSES),\n    with_bias=True,\n    activation=jax.nn.tanh,\n  )(x)\n\n\n# The Haiku-transformed model\nhk_model = 
hk.without_apply_rng(hk.transform(model_fn))\n\n\ndef loss_fn(model_params, model_batch):\n  \"\"\"The loss function to optimize.\"\"\"\n  x, y = model_batch\n  logits = hk_model.apply(model_params, x)\n  loss = jnp.mean(softmax_cross_entropy(logits, y))\n\n  # The optimizer assumes that the function you provide has already added\n  # the L2 regularizer to its gradients.\n  return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0\n\n\n# Create the optimizer\noptimizer = kfac_jax.Optimizer(\n  value_and_grad_func=jax.value_and_grad(loss_fn),\n  l2_reg=L2_REG,\n  value_func_has_aux=False,\n  value_func_has_state=False,\n  value_func_has_rng=False,\n  use_adaptive_learning_rate=True,\n  use_adaptive_momentum=True,\n  use_adaptive_damping=True,\n  initial_damping=1.0,\n  multi_device=False,\n)\n\ninput_dataset = make_dataset_iterator(128)\nrng = jax.random.PRNGKey(42)\ndummy_images, dummy_labels = next(input_dataset)\nrng, key = jax.random.split(rng)\nparams = hk_model.init(key, dummy_images)\nrng, key = jax.random.split(rng)\nopt_state = optimizer.init(params, key, (dummy_images, dummy_labels))\n\n# Training loop\nfor i, batch in enumerate(input_dataset):\n  rng, key = jax.random.split(rng)\n  params, opt_state, stats = optimizer.step(\n      params, opt_state, key, batch=batch, global_step_int=i)\n  print(i, stats)\n```\n\n### Do not stage (``jit`` or ``pmap``) the optimizer\n\nYou should not apply `jax.jit` or `jax.pmap` to the call to `Optimizer.step`.\nThis is already done for you automatically by the optimizer class.\nTo control the staging behaviour of the optimizer, set the flag ``multi_device``\nto ``True`` for ``pmap`` and to ``False`` for ``jit``.\n\n### Do not stage (``jit`` or ``pmap``) the loss function\n\nThe ``value_and_grad_func`` argument provided to the optimizer should compute\nthe loss function value and its gradients. 
Since the optimizer already stages\nits step function internally, applying ``jax.jit`` to ``value_and_grad_func`` is\n**NOT** recommended.\nImportantly, applying ``jax.pmap`` is **WRONG** and will most likely lead to\nerrors.\n\n### Registering the model loss function\n\nIn order for KFAC-JAX to correctly approximate the curvature matrix\nof the model, it needs to know the precise loss function that you want to\noptimize.\nThis is done via registration with certain functions provided by the library.\nFor instance, in the example above this is done via the call to\n``kfac_jax.register_softmax_cross_entropy_loss``, which tells the optimizer that\nthe loss is the standard softmax cross-entropy.\nIf you don't do this, you will get an error when you try to call the optimizer.\nFor all supported loss functions, please read the [documentation].\n\n``Important:`` The optimizer assumes that the loss is averaged over examples in\nthe minibatch. It is crucial that you follow this convention.\n\n### Other model function options\n\nOftentimes, one will want to output some auxiliary statistics or metrics in\naddition to the loss value.\nThis can already be done in the ``value_and_grad_func``, in which case we follow\nthe same conventions as JAX and expect the output to be ``(loss, aux), grads``.\nSimilarly, the loss function can take an additional function state (batch norm\nlayers usually have this) or a PRNG key (used in stochastic layers). 
All of\nthese, however, need to be explicitly told to the optimizer via its arguments\n``value_func_has_aux``, ``value_func_has_state`` and ``value_func_has_rng``.\n\n### Verify optimizer registrations\n\nWe strongly encourage the user to pay attention to the logging messages produced\nby the automatic registration system, in order to ensure that it has correctly\nunderstood your model.\nFor the example above, the output looks like this:\n\n```\n==================================================\nGraph parameter registrations:\n{'mlp/~/linear_0': {'b': 'Auto[dense_with_bias_3]',\n                    'w': 'Auto[dense_with_bias_3]'},\n 'mlp/~/linear_1': {'b': 'Auto[dense_with_bias_2]',\n                    'w': 'Auto[dense_with_bias_2]'},\n 'mlp/~/linear_2': {'b': 'Auto[dense_with_bias_1]',\n                    'w': 'Auto[dense_with_bias_1]'},\n 'mlp/~/linear_3': {'b': 'Auto[dense_with_bias_0]',\n                    'w': 'Auto[dense_with_bias_0]'}}\n==================================================\n```\n\nAs can be seen from this message, the library has correctly detected all\nparameters of the model to be part of dense layers.\n\n### Further reading\nFor a high-level overview of the optimizer, the different curvature\napproximations, and the supported layers, please see the [documentation].\n\n## Citing KFAC-JAX\u003ca id=\"citing-kfac-jax\"\u003e\u003c/a\u003e\n\nTo cite this repository:\n\n```\n@software{kfac-jax2022github,\n  author = {Aleksandar Botev and James Martens},\n  title = {{KFAC-JAX}},\n  url = {https://github.com/google-deepmind/kfac-jax},\n  version = {0.0.2},\n  year = {2022},\n}\n```\n\nIn this BibTeX entry, the version number is intended to be from\n[`kfac_jax/__init__.py`](https://github.com/google-deepmind/kfac-jax/blob/main/kfac_jax/__init__.py),\nand the year corresponds to the project's open-source release.\n\n\n[K-FAC]: https://arxiv.org/abs/1503.05671\n[JAX]: https://github.com/google/jax\n[Haiku]: 
https://github.com/google-deepmind/dm-haiku\n[documentation]: https://kfac-jax.readthedocs.io/\n","funding_links":[],"categories":["Mathematical tools","Implementation in JAX","Libraries"],"sub_categories":["Other"],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgoogle-deepmind%2Fkfac-jax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fgoogle-deepmind%2Fkfac-jax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgoogle-deepmind%2Fkfac-jax/lists"}