{"id":19400993,"url":"https://github.com/google-research/jestimator","last_synced_at":"2025-07-16T09:40:43.037Z","repository":{"id":62129980,"uuid":"550539939","full_name":"google-research/jestimator","owner":"google-research","description":"Amos optimizer with JEstimator lib.","archived":false,"fork":false,"pushed_at":"2024-05-15T15:48:29.000Z","size":3182,"stargazers_count":82,"open_issues_count":1,"forks_count":6,"subscribers_count":4,"default_branch":"main","last_synced_at":"2025-06-11T01:53:39.882Z","etag":null,"topics":["deep-learning","flax","jax","language-model","machine-learning","mnist","optimization","optimization-algorithms","optimizer","transformer"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/google-research.png","metadata":{"files":{"readme":"README.md","changelog":"CHANGELOG.md","contributing":"CONTRIBUTING.md","funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-10-12T23:59:00.000Z","updated_at":"2025-03-25T13:43:43.000Z","dependencies_parsed_at":"2024-01-15T00:05:44.354Z","dependency_job_id":"139ceb00-9334-455b-b59b-fb679308a8cd","html_url":"https://github.com/google-research/jestimator","commit_stats":{"total_commits":40,"total_committers":6,"mean_commits":6.666666666666667,"dds":"0.32499999999999996","last_synced_commit":"6a57d35539f5193a9756a7cb846654e9b221b2e7"},"previous_names":[],"tags_count":2,"template":false,"template_full_name":null,"purl":"pkg:github/google-research/jestimator","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-research%2Fjestimator","tags_url":"https://repos.ecosyst
e.ms/api/v1/hosts/GitHub/repositories/google-research%2Fjestimator/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-research%2Fjestimator/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-research%2Fjestimator/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/google-research","download_url":"https://codeload.github.com/google-research/jestimator/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-research%2Fjestimator/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":265500620,"owners_count":23777520,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","flax","jax","language-model","machine-learning","mnist","optimization","optimization-algorithms","optimizer","transformer"],"created_at":"2024-11-10T11:16:34.158Z","updated_at":"2025-07-16T09:40:43.014Z","avatar_url":"https://github.com/google-research.png","language":"Python","readme":"# Amos and JEstimator\n\n[![Unittests \u0026 Auto-publish](https://github.com/google-research/jestimator/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/google-research/jestimator/actions/workflows/pytest_and_autopublish.yml)\n[![PyPI version](https://badge.fury.io/py/jestimator.svg)](https://badge.fury.io/py/jestimator)\n\n*This is not an officially supported Google product.*\n\nThis is the source code for the paper \"[Amos: An Adam-style Optimizer with\nAdaptive Weight Decay towards 
Model-Oriented\nScale](https://arxiv.org/abs/2210.11693)\".\n\nIt implements **Amos**, an optimizer compatible with the\n[optax](https://github.com/deepmind/optax) library, and **JEstimator**, a\nlight-weight library with a `tf.Estimator`-like interface to manage\n[T5X](https://github.com/google-research/t5x)-compatible checkpoints for machine\nlearning programs in [JAX](https://github.com/google/jax), which we use to run\nexperiments in the paper.\n\n## Quickstart\n\n```\npip install jestimator\n```\n\nIt will install the Amos optimizer implemented in the jestimator lib.\n\n## Usage of Amos\n\nThis implementation of Amos is used with [JAX](https://github.com/google/jax), a\nhigh-performance numerical computing library with automatic differentiation, for\nmachine learning research. The API of Amos is compatible with\n[optax](https://github.com/deepmind/optax), a library of JAX optimizers\n(hopefully Amos will be integrated into optax in the near future).\n\nIn order to demonstrate the usage, we will apply Amos to MNIST. It is based on\nFlax's official\n[MNIST Example](https://github.com/google/flax/tree/main/examples/mnist), and\nyou can find the code in a jupyter notebook\n[here](https://github.com/google-research/jestimator/tree/main/jestimator/models/mnist/mnist.ipynb).\n\n### 1. Imports\n\n```\nimport jax\nimport jax.numpy as jnp                # JAX NumPy\nfrom jestimator import amos            # The Amos optimizer implementation\nfrom jestimator import amos_helper     # Helper module for Amos\n\nfrom flax import linen as nn           # The Linen API\nfrom flax.training import train_state  # Useful dataclass to keep train state\n\nimport math\nimport tensorflow_datasets as tfds     # TFDS for MNIST\nfrom sklearn.metrics import accuracy_score\n```\n\n### 2. 
Load data\n\n```\ndef get_datasets():\n  \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n\n  ds_builder = tfds.builder('mnist')\n  ds_builder.download_and_prepare()\n  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))\n  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))\n  train_ds['image'] = jnp.float32(train_ds['image']) / 255.\n  test_ds['image'] = jnp.float32(test_ds['image']) / 255.\n  return train_ds, test_ds\n```\n\n### 3. Build model\n\n```\nclass CNN(nn.Module):\n  \"\"\"A simple CNN model.\"\"\"\n\n  @nn.compact\n  def __call__(self, x):\n    x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n    x = nn.relu(x)\n    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n    x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n    x = nn.relu(x)\n    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n    x = x.reshape((x.shape[0], -1))  # flatten\n    x = nn.Dense(features=256)(x)\n    x = nn.relu(x)\n    x = nn.Dense(features=10)(x)\n    return x\n\n  def classify_xe_loss(self, x, labels):\n    # Labels read from the tfds MNIST are integers from 0 to 9.\n    # Logits are arrays of size 10.\n    logits = self(x)\n    logits = jax.nn.log_softmax(logits)\n    labels_ = jnp.expand_dims(labels, -1)\n    llh_ = jnp.take_along_axis(logits, labels_, axis=-1)\n    loss = -jnp.sum(llh_)\n    return loss\n```\n\n### 4. 
Create train state\n\nA `TrainState` object keeps the model parameters and optimizer states, and can\nbe checkpointed into files.\n\nWe create the model and optimizer in this function.\n\n**For the optimizer, we use Amos here.** The following hyper-parameters are set:\n\n*   *learning_rate*:       The global learning rate.\n*   *eta_fn*:              The model-specific 'eta'.\n*   *shape_fn*:            Memory reduction setting.\n*   *beta*:                Rate for running average of gradient squares.\n*   *clip_value*:          Gradient clipping for stable training.\n\nThe global learning rate is usually set to 1/sqrt(N), where N is the number\nof batches in the training data. For MNIST, we have 60k training examples and\nthe batch size is 32, so learning_rate=1/sqrt(60000/32).\n\nThe model-specific 'eta_fn' requires a function that, given a variable name and\nshape, returns a float indicating the expected scale of that variable. Hopefully\nin the near future we will have libraries that can automatically calculate this\n'eta_fn' from the modeling code; but for now we have to specify it manually.\n\nOne can use the amos_helper.params_fn_from_assign_map() helper function to\ncreate 'eta_fn' from an assign_map. An assign_map is a dict which maps regex\nrules to a value or simple Python expression. The helper finds the first regex rule\nthat matches the name of a variable, and evaluates the Python expression if\nnecessary to return the value. See our example below.\n\nThe 'shape_fn' similarly requires a function that, given a variable name and\nshape, returns a reduced shape for the corresponding slot variables. We can use\nthe amos_helper.params_fn_from_assign_map() helper function to create 'shape_fn'\nfrom an assign_map as well.\n\n'beta' is the exponential decay rate for the running average of gradient squares. We\nset it to 0.98 here.\n\n'clip_value' is the gradient clipping value, which should match the magnitude of\nthe loss function. 
If the loss function is a sum of cross-entropy, then we\nshould set 'clip_value' to the sqrt of the number of labels.\n\nPlease refer to our [paper](https://arxiv.org/abs/2210.11693) for more details\nof the hyper-parameters.\n\n```\ndef get_train_state(rng):\n  model = CNN()\n  dummy_x = jnp.ones([1, 28, 28, 1])\n  params = model.init(rng, dummy_x)\n\n  eta_fn = amos_helper.params_fn_from_assign_map(\n      {\n          '.*/bias': 0.5,\n          '.*Conv_0/kernel': 'sqrt(8/prod(SHAPE[:-1]))',\n          '.*Conv_1/kernel': 'sqrt(2/prod(SHAPE[:-1]))',\n          '.*Dense_0/kernel': 'sqrt(2/SHAPE[0])',\n          '.*Dense_1/kernel': 'sqrt(1/SHAPE[0])',\n      },\n      eval_str_value=True,\n  )\n  shape_fn = amos_helper.params_fn_from_assign_map(\n      {\n          '.*Conv_[01]/kernel': '(1, 1, 1, SHAPE[-1])',\n          '.*Dense_0/kernel': '(1, SHAPE[1])',\n          '.*': (),\n      },\n      eval_str_value=True,\n  )\n  optimizer = amos.amos(\n      learning_rate=1/math.sqrt(60000/32),\n      eta_fn=eta_fn,\n      shape_fn=shape_fn,\n      beta=0.98,\n      clip_value=math.sqrt(32),\n  )\n  return train_state.TrainState.create(\n      apply_fn=model.apply, params=params, tx=optimizer)\n```\n\n### 5. Train step\n\nUse JAX’s @jit decorator to just-in-time compile the function for better\nperformance.\n\n```\n@jax.jit\ndef train_step(batch, state):\n  grad_fn = jax.grad(state.apply_fn)\n  grads = grad_fn(\n      state.params,\n      batch['image'],\n      batch['label'],\n      method=CNN.classify_xe_loss)\n  return state.apply_gradients(grads=grads)\n```\n\n### 6. Infer step\n\nUse JAX’s @jit decorator to just-in-time compile the function for better\nperformance.\n\n```\n@jax.jit\ndef infer_step(batch, state):\n  logits = state.apply_fn(state.params, batch['image'])\n  return jnp.argmax(logits, -1)\n```\n\n### 7. 
Main\n\nRun the training loop and evaluate on test set.\n\n```\ntrain_ds, test_ds = get_datasets()\n\nrng = jax.random.PRNGKey(0)\nrng, init_rng = jax.random.split(rng)\nstate = get_train_state(init_rng)\ndel init_rng  # Must not be used anymore.\n\nnum_epochs = 9\nfor epoch in range(1, num_epochs + 1):\n  # Use a separate PRNG key to permute image data during shuffling\n  rng, input_rng = jax.random.split(rng)\n  perms = jax.random.permutation(input_rng, 60000)\n  del input_rng\n  perms = perms.reshape((60000 // 32, 32))\n  for perm in perms:\n    batch = {k: v[perm, ...] for k, v in train_ds.items()}\n    state = train_step(batch, state)\n\n  pred = jax.device_get(infer_step(test_ds, state))\n  accuracy = accuracy_score(test_ds['label'], pred)\n  print('epoch: %d, test accuracy: %.2f' % (epoch, accuracy * 100))\n```\n\nAfter 9 epochs, we should get 99.26 test accuracy. If you made it, congrats!\n\n## JEstimator\n\nWith JEstimator, you can build your model mostly similar to the MNIST example\nabove, but without writing code for the \"Main\" section; JEstimator will serve as\nthe entry point for your model, automatically handle checkpointing in a\ntrain/eval-once/eval-while-training-and-save-the-best/predict mode, and set up\nprofiling, tensorboard, and logging.\n\nIn addition, JEstimator supports model partitioning which is required for\ntraining very large models across multiple TPU pods. 
It supports a\n[T5X](https://github.com/google-research/t5x)-compatible checkpoint format that\nsaves and restores checkpoints in a distributed manner, which is suitable for\nlarge multi-pod models.\n\nIn order to run models with JEstimator, we need to install\n[T5X](https://github.com/google-research/t5x#installation) and\n[FlaxFormer](https://github.com/google/flaxformer):\n\n```\ngit clone --branch=main https://github.com/google-research/t5x\ncd t5x\npython3 -m pip install -e .\ncd ..\n\ngit clone --branch=main https://github.com/google/flaxformer\ncd flaxformer\npip3 install .\ncd ..\n```\n\nThen, clone this repo to get the JEstimator code:\n\n```\ngit clone --branch=main https://github.com/google-research/jestimator\ncd jestimator\n```\n\nNow, we can test a toy linear regression model:\n\n```\nPYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py\n```\n\n## MNIST Example in JEstimator\n\nWe provide this\n[MNIST Example](https://github.com/google-research/jestimator/tree/main/jestimator/models/mnist/mnist.py)\nto demonstrate how to write modeling code with JEstimator. It is much like the\nexample above, with one big advantage: a config object is passed around\nto collect information from global flags and the dataset, in order to\nset up the model dynamically. This makes it easier to apply the model to different datasets; for example, one can immediately try the [emnist](https://www.tensorflow.org/datasets/catalog/emnist) or [eurosat](https://www.tensorflow.org/datasets/catalog/eurosat) datasets simply by changing a command-line argument, without modifying the code.\n\nWith the following command, we can start a job to train on MNIST, log every 100\nsteps, and save the checkpoints to $HOME/experiments/mnist/models:\n\n```\nPYTHONPATH=. 
python3 jestimator/estimator.py \\\n  --module_imp=\"jestimator.models.mnist.mnist\" \\\n  --module_config=\"jestimator/models/mnist/mnist.py\" \\\n  --train_pattern=\"tfds://mnist/split=train\" \\\n  --model_dir=\"$HOME/experiments/mnist/models\" \\\n  --train_batch_size=32 \\\n  --train_shuffle_buf=4096 \\\n  --train_epochs=9 \\\n  --check_every_steps=100 \\\n  --max_ckpt=20 \\\n  --save_every_steps=1000 \\\n  --module_config.warmup=2000 \\\n  --module_config.amos_beta=0.98\n```\n\nMeanwhile, we can start a job to monitor the $HOME/experiments/mnist/models\nfolder, evaluate on MNIST test set, and save the model with the highest\naccuracy:\n\n```\nPYTHONPATH=. python3 jestimator/estimator.py \\\n  --module_imp=\"jestimator.models.mnist.mnist\" \\\n  --module_config=\"jestimator/models/mnist/mnist.py\" \\\n  --eval_pattern=\"tfds://mnist/split=test\" \\\n  --model_dir=\"$HOME/experiments/mnist/models\" \\\n  --eval_batch_size=32 \\\n  --mode=\"eval_wait\" \\\n  --check_ckpt_every_secs=1 \\\n  --save_high=\"test_accuracy\"\n```\n\nAt the same time, we can start a tensorboard to monitor the process:\n\n```\ntensorboard --logdir $HOME/experiments/mnist/models\n```\n\n## LSTM on PTB\n\nWe can use the following command to train a single layer LSTM on PTB:\n\n```\nPYTHONPATH=. 
python3 jestimator/estimator.py \\\n  --module_imp=\"jestimator.models.lstm.lm\" \\\n  --module_config=\"jestimator/models/lstm/lm.py\" \\\n  --module_config.vocab_path=\"jestimator/models/lstm/ptb/vocab.txt\" \\\n  --train_pattern=\"jestimator/models/lstm/ptb/ptb.train.txt\" \\\n  --model_dir=\"$HOME/models/ptb_lstm\" \\\n  --train_batch_size=64 \\\n  --train_consecutive=113 \\\n  --train_shuffle_buf=4096 \\\n  --max_train_steps=200000 \\\n  --check_every_steps=1000 \\\n  --max_ckpt=20 \\\n  --module_config.opt_config.optimizer=\"amos\" \\\n  --module_config.opt_config.learning_rate=0.01 \\\n  --module_config.opt_config.beta=0.98 \\\n  --module_config.opt_config.momentum=0.0\n```\n\nand evaluate:\n\n```\nPYTHONPATH=. python3 jestimator/estimator.py \\\n  --module_imp=\"jestimator.models.lstm.lm\" \\\n  --module_config=\"jestimator/models/lstm/lm.py\" \\\n  --module_config.vocab_path=\"jestimator/models/lstm/ptb/vocab.txt\" \\\n  --eval_pattern=\"jestimator/models/lstm/ptb/ptb.valid.txt\" \\\n  --model_dir=\"$HOME/models/ptb_lstm\" \\\n  --eval_batch_size=1\n```\n\nThis setup is suitable for running on a single-GPU machine.\n\n## More JEstimator Models\n\nHere are some simple guides to pre-train and fine-tune BERT-like models, using TPUs on Google Cloud Platform (GCP). One can start with a Web browser with zero setup, by connecting to a Virtual Machine via Google Cloud console, without installing anything locally. 
If this is the first time, one is covered by enough credits to try the commands for free.\n\n* [My experience of pre-training a BERT-base model on GCP](docs/gcp-pretrain.md)\n* [My experience of fine-tuning MNLI on GCP](docs/gcp-finetune.md)\n","funding_links":[],"categories":[],"sub_categories":[],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgoogle-research%2Fjestimator","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fgoogle-research%2Fjestimator","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgoogle-research%2Fjestimator/lists"}