{"id":13653463,"url":"https://github.com/google-research/big_transfer","last_synced_at":"2025-05-15T18:06:13.630Z","repository":{"id":37660128,"uuid":"263163293","full_name":"google-research/big_transfer","owner":"google-research","description":"Official repository for the \"Big Transfer (BiT): General Visual Representation Learning\" paper.","archived":false,"fork":false,"pushed_at":"2024-07-30T21:21:39.000Z","size":832,"stargazers_count":1523,"open_issues_count":43,"forks_count":177,"subscribers_count":40,"default_branch":"master","last_synced_at":"2025-03-31T21:51:18.357Z","etag":null,"topics":["convolutional-neural-networks","deep-learning","imagenet","jax","pytorch","tensorflow2","transfer-learning"],"latest_commit_sha":null,"homepage":"https://arxiv.org/abs/1912.11370","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/google-research.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":"CONTRIBUTING.md","funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2020-05-11T21:34:13.000Z","updated_at":"2025-03-27T09:33:56.000Z","dependencies_parsed_at":"2024-09-21T04:40:26.122Z","dependency_job_id":null,"html_url":"https://github.com/google-research/big_transfer","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-research%2Fbig_transfer","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-research%2Fbig_transfer/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-research%2Fbig_tra
nsfer/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-research%2Fbig_transfer/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/google-research","download_url":"https://codeload.github.com/google-research/big_transfer/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247737788,"owners_count":20987721,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["convolutional-neural-networks","deep-learning","imagenet","jax","pytorch","tensorflow2","transfer-learning"],"created_at":"2024-08-02T02:01:10.801Z","updated_at":"2025-04-07T22:11:17.749Z","avatar_url":"https://github.com/google-research.png","language":"Python","readme":"## Big Transfer (BiT): General Visual Representation Learning\n*by Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, Neil Houlsby*\n\n\n**Update 18/06/2021:** We release new high performing BiT-R50x1 models, which were distilled from BiT-M-R152x2, see [this section](#distilled-models). 
More details in our [paper \"Knowledge distillation: A good teacher is patient and consistent\"](https://arxiv.org/abs/2106.05237).\n\n**Update 08/02/2021:** We also release ALL BiT-M models fine-tuned on ALL 19 VTAB-1k datasets, see below.\n\n## Introduction\n\nIn this repository we release multiple models from the [Big Transfer (BiT): General Visual Representation Learning](https://arxiv.org/abs/1912.11370) paper that were pre-trained on the [ILSVRC-2012](http://www.image-net.org/challenges/LSVRC/2012/) and [ImageNet-21k](http://www.image-net.org/) datasets.\nWe provide the code to fine-tune the released models in the major deep learning frameworks [TensorFlow 2](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/) and [Jax](https://jax.readthedocs.io/en/latest/index.html)/[Flax](http://flax.readthedocs.io).\n\nWe hope that the computer vision community will benefit by employing more powerful ImageNet-21k pretrained models as opposed to conventional models pre-trained on the ILSVRC-2012 dataset.\n\nWe also provide colabs for more exploratory, interactive use:\na [TensorFlow 2 colab](https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_tf2.ipynb),\na [PyTorch colab](https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_pytorch.ipynb),\nand a [Jax colab](https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_jax.ipynb).\n\n## Installation\n\nMake sure you have `Python\u003e=3.6` installed on your machine.\n\nTo set up [TensorFlow 2](https://github.com/tensorflow/tensorflow), [PyTorch](https://github.com/pytorch/pytorch) or [Jax](https://github.com/google/jax), follow the instructions provided in the corresponding repository linked here.\n\nIn addition, install Python dependencies by running (please select `tf2`, `pytorch` or `jax` in the command below):\n```\npip install -r 
bit_{tf2|pytorch|jax}/requirements.txt\n```\n\n## How to fine-tune BiT\nFirst, download the BiT model. We provide models pre-trained on ILSVRC-2012 (BiT-S) or ImageNet-21k (BiT-M) for 5 different architectures: ResNet-50x1, ResNet-101x1, ResNet-50x3, ResNet-101x3, and ResNet-152x4.\n\nFor example, if you would like to download the ResNet-50x1 pre-trained on ImageNet-21k, run the following command:\n```\nwget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}\n```\nOther models can be downloaded accordingly by plugging the name of the model (BiT-S or BiT-M) and architecture in the above command.\nNote that we provide models in two formats: `npz` (for PyTorch and Jax) and `h5` (for TF2). By default we expect that model weights are stored in the root folder of this repository.\n\nThen, you can run fine-tuning of the downloaded model on your dataset of interest in any of the three frameworks. All frameworks share the command-line interface\n```\npython3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10\n```\nCurrently, all frameworks will automatically download the CIFAR-10 and CIFAR-100 datasets. Other public or custom datasets can be easily integrated: in TF2 and JAX we rely on the extensible [tensorflow datasets library](https://github.com/tensorflow/datasets/). In PyTorch, we use [torchvision’s data input pipeline](https://pytorch.org/docs/stable/torchvision/index.html).\n\nNote that our code uses all available GPUs for fine-tuning.\n\nWe also support training in the low-data regime: the `--examples_per_class \u003cK\u003e` option will randomly draw K samples per class for training.\n\nTo see a detailed list of all available flags, run `python3 -m bit_{pytorch|jax|tf2}.train --help`.\n\n### BiT-M models fine-tuned on ILSVRC-2012\n\nFor convenience, we provide BiT-M models that were already fine-tuned on the\nILSVRC-2012 dataset. 
The models can be downloaded by adding the `-ILSVRC2012`\nsuffix, e.g.\n```\nwget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz\n```\n\n### Available architectures\n\nWe release all architectures mentioned in the paper, so that you may choose between accuracy and speed: R50x1, R101x1, R50x3, R101x3, R152x4.\nIn the above path to the model file, simply replace `R50x1` by your architecture of choice.\n\nWe further investigated more architectures after the paper's publication and found R152x2 to have a nice trade-off between speed and accuracy, hence we also include this in the release and provide a few numbers below.\n\n\n### BiT-M models fine-tuned on the 19 VTAB-1k tasks\n\nWe also release the fine-tuned models for each of the 19 tasks included in the VTAB-1k benchmark. We ran each model three times and release each of these runs. This means we release a total of 5x19x3=285 models, and hope these can be useful in further analysis of transfer learning.\n\nThe files can be downloaded via the following pattern:\n\n```\nwget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz\n```\n\nWe did not convert these models to TF2 (hence there is no corresponding `.h5` file); however, we also uploaded [TFHub](http://tfhub.dev) models which can be used in TF1 and TF2. 
An example sequence of commands for downloading one such model is:\n\n```\nmkdir BiT-M-R50x1-run0-caltech101.tfhub \u0026\u0026 cd BiT-M-R50x1-run0-caltech101.tfhub\nwget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}\nmkdir variables \u0026\u0026 cd variables\nwget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}\n```\n\n### Hyper-parameters\n\nFor reproducibility, our training script uses the hyper-parameters (BiT-HyperRule) that were used in the original paper.\nNote, however, that BiT models were trained and fine-tuned on Cloud TPU hardware, so for a typical GPU setup our default hyper-parameters could require too much memory or result in very slow progress.\nMoreover, BiT-HyperRule is designed to generalize across many datasets, so it is typically possible to devise more efficient application-specific hyper-parameters.\nThus, we encourage the user to try more light-weight settings, as they require far fewer resources and often result in similar accuracy.\n\nFor example, we tested our code on an 8xV100 GPU machine on the CIFAR-10 and CIFAR-100 datasets, while reducing batch size from 512 to 128 and learning rate from 0.003 to 0.001.\nThis setup resulted in nearly identical performance (see [Expected results](#expected-results) below) in comparison to BiT-HyperRule, despite being less computationally demanding.\n\nBelow, we provide more suggestions on how to optimize our paper's setup.\n\n### Tips for optimizing memory or speed\n\nThe default BiT-HyperRule was developed on Cloud TPUs and is quite memory-hungry.\nThis is mainly due to the large batch-size (512) and image resolution (up to 480x480).\nHere are some tips if you are running out of memory:\n\n  1. In `bit_hyperrule.py` we specify the input resolution.\n     By reducing it, one can save a lot of memory and compute, at the expense of accuracy.\n  2. 
The batch-size can be reduced in order to reduce memory consumption.\n     However, one then also needs to play with learning-rate and schedule (steps) in order to maintain the desired accuracy.\n  3. The PyTorch codebase supports a batch-splitting technique (\"micro-batching\") via `--batch_split` option.\n     For example, running the fine-tuning with `--batch_split 8` reduces memory requirement by a factor of 8.\n\n## Expected results\n\nWe verified that when using the BiT-HyperRule, the code in this repository reproduces the paper's results.\n\n### CIFAR results (few-shot and full)\n\nFor these common benchmarks, the aforementioned changes to the BiT-HyperRule (`--batch 128 --base_lr 0.001`) lead to the following, very similar results.\nThe table shows the min←**median**→max result of at least five runs.\n**NOTE**: This is not a comparison of frameworks, just evidence that all code-bases can be trusted to reproduce results.\n\n#### BiT-M-R101x3\n\n| Dataset  | Ex/cls |          TF2           |          Jax           |         PyTorch        |\n| :---     | :---:  |         :---:          |         :---:          |          :---:         |\n| CIFAR10  |   1    | 52.5 ← **55.8** → 60.2 | 48.7 ← **53.9** → 65.0 | 56.4 ← **56.7** → 73.1 |\n| CIFAR10  |   5    | 85.3 ← **87.2** → 89.1 | 80.2 ← **85.8** → 88.6 | 84.8 ← **85.8** → 89.6 |\n| CIFAR10  |  full  |        **98.5**        |        **98.4**        | 98.5 ← **98.6** → 98.6 |\n| CIFAR100 |   1    | 34.8 ← **35.7** → 37.9 | 32.1 ← **35.0** → 37.1 | 31.6 ← **33.8** → 36.9 |\n| CIFAR100 |   5    | 68.8 ← **70.4** → 71.4 | 68.6 ← **70.8** → 71.6 | 70.6 ← **71.6** → 71.7 |\n| CIFAR100 |  full  |        **90.8**        |        **91.2**        | 91.1 ← **91.2** → 91.4 |\n\n#### BiT-M-R152x2\n\n| Dataset  | Ex/cls |           Jax          |         PyTorch        |\n| :---     | :---:  |          :---:         |          :---:         |\n| CIFAR10  |   1    | 44.0 ← **56.7** → 65.0 | 50.9 ← **55.5** → 59.5 |\n| 
CIFAR10  |   5    | 85.3 ← **87.0** → 88.2 | 85.3 ← **85.8** → 88.6 |\n| CIFAR10  |  full  |        **98.5**        | 98.5 ← **98.5** → 98.6 |\n| CIFAR100 |   1    | 36.4 ← **37.2** → 38.9 | 34.3 ← **36.8** → 39.0 |\n| CIFAR100 |   5    | 69.3 ← **70.5** → 72.0 | 70.3 ← **72.0** → 72.3 |\n| CIFAR100 |  full  |        **91.2**        | 91.2 ← **91.3** → 91.4 |\n\n(TF2 models not yet available.)\n\n#### BiT-M-R50x1\n\n| Dataset  | Ex/cls |          TF2           |          Jax           |         PyTorch        |\n| :---     | :---:  |         :---:          |         :---:          |          :---:         |\n| CIFAR10  |   1    | 49.9 ← **54.4** → 60.2 | 48.4 ← **54.1** → 66.1 | 45.8 ← **57.9** → 65.7 |\n| CIFAR10  |   5    | 80.8 ← **83.3** → 85.5 | 76.7 ← **82.4** → 85.4 | 80.3 ← **82.3** → 84.9 |\n| CIFAR10  |  full  |        **97.2**        |        **97.3**        |        **97.4**        |\n| CIFAR100 |   1    | 35.3 ← **37.1** → 38.2 | 32.0 ← **35.2** → 37.8 | 34.6 ← **35.2** → 38.6 |\n| CIFAR100 |   5    | 63.8 ← **65.0** → 66.5 | 63.4 ← **64.8** → 66.5 | 64.7 ← **65.5** → 66.0 |\n| CIFAR100 |  full  |        **86.5**        |        **86.4**        |        **86.6**        |\n\n### ImageNet results\n\nThese results were obtained using BiT-HyperRule.\nHowever, because this results in large batch-size and large resolution, memory can be an issue.\nThe PyTorch code supports batch-splitting, and hence we can still run things there without resorting to Cloud TPUs by adding the `--batch_split N` command where `N` is a power of two.\nFor instance, the following command produces a validation accuracy of `80.68` on a machine with 8 V100 GPUs:\n\n```\npython3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4\n```\n\nFurther increase to `--batch_split 8` when running with 4 V100 GPUs, etc.\n\nFull results achieved that way in some test runs were:\n\n| Ex/cls | R50x1 | R152x2 | R101x3 
|\n| :---:  | :---: | :---:  | :---:  |\n|   1    | 18.36 | 24.5   | 25.55  |\n|   5    | 50.64 | 64.5   | 64.18  |\n|  full  | 80.68 | 85.15  | WIP    |\n\n### VTAB-1k results\n\nThese are re-runs and not the exact paper models. The expected VTAB scores for two of the models are:\n\n| Model         | Full  | Natural | Structured | Specialized |\n| :---          | :---: |  :---:  |   :---:    |    :---:    |\n| BiT-M-R152x4  | 73.51 |  80.77  |    61.08   |    85.67    |\n| BiT-M-R101x3  | 72.65 |  80.29  |    59.40   |    85.75    |\n\n## Out of context dataset\n\nIn Appendix G of our paper, we investigate whether BiT improves out-of-context robustness.\nTo do this, we created a dataset comprising foreground objects corresponding to 21 ILSVRC-2012 classes pasted onto 41 miscellaneous backgrounds.\n\nTo download the dataset, run\n\n```\nwget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip\n```\n\nImages from each of the 21 classes are kept in a directory with the name of the class.\n\n## Distilled models\n\nWe release top-performing compressed BiT models from our [paper \"Knowledge distillation: A good teacher is patient and consistent\"](https://arxiv.org/abs/2106.05237) on knowledge distillation.\nIn particular, we distill the BiT-M-R152x2 model (which was pre-trained on ImageNet-21k) to BiT-R50x1 models.\nAs a result, we obtain compact models with very competitive performance.\n\n| Model      | Download link | Resolution  | ImageNet top-1 acc. 
(paper) | \n| :---       | :---:         | :---:       |  :---:                      |\n| BiT-R50x1  | [link](https://storage.googleapis.com/bit_models/distill/R50x1_224.npz)      | 224 |  82.8 |\n| BiT-R50x1  | [link](https://storage.googleapis.com/bit_models/distill/R50x1_160.npz)      | 160 |  80.5 |\n\nFor reproducibility, we also release weights of two BiT-M-R152x2 teacher models: pretrained at [resolution 224](https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz) and [resolution 384](https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz). See the paper for details on how these teachers were used.\n\n### Distillation code\n\nWe have no concrete plans for publishing the distillation code, as the recipe is simple and we imagine most people would integrate it in their existing training code.\nHowever, [Sayak Paul](https://sayak.dev/) has independently [re-implemented the distillation setup in TensorFlow](https://github.com/sayakpaul/FunMatch-Distillation) and nearly reproduced our results in several settings.\n","funding_links":[],"categories":["Sensor Processing","Python","DLA","Models and Projects"],"sub_categories":["Image Processing","BiT","Flax"],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgoogle-research%2Fbig_transfer","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fgoogle-research%2Fbig_transfer","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgoogle-research%2Fbig_transfer/lists"}