https://github.com/sayakpaul/bit-jax2tf
This repository hosts the code to port NumPy model weights of BiT-ResNets to TensorFlow SavedModel format.
bit-resnet computer-vision jax tensorflow
- Host: GitHub
- URL: https://github.com/sayakpaul/bit-jax2tf
- Owner: sayakpaul
- License: apache-2.0
- Created: 2021-08-25T04:25:00.000Z
- Default Branch: main
- Last Pushed: 2021-12-21T04:26:49.000Z
- Last Synced: 2025-03-31T04:41:12.330Z
- Topics: bit-resnet, computer-vision, jax, tensorflow
- Language: Jupyter Notebook
- Homepage: https://tfhub.dev/sayakpaul/collections/bit-resnet/1
- Size: 14.6 KB
- Stars: 14
- Watchers: 1
- Forks: 2
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# BiT-jax2tf
This repository hosts the code to port NumPy model weights of BiT-ResNets [1] to the TensorFlow SavedModel format. These models are the result of [2]; the original model weights come from [3].
Huge thanks to [Willi Gierke](https://ch.linkedin.com/in/willi-gierke) (of Google) for helping with the porting.
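The core porting idea can be sketched on a toy network. This is a minimal sketch under stated assumptions, not the notebook's code: it assumes the JAX weights arrive as a flat dict of NumPy arrays keyed by "/"-joined names (what `dict(np.load(...))` returns for an `.npz` file), and the parameter names and the two-layer `TinyNet` are hypothetical stand-ins for the real BiT-ResNet. Assuming the convolution kernels use the (height, width, in, out) layout of Flax convolutions and dense kernels use (in, out), the arrays can be copied into `tf.Variable`s without transposition.

```python
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for dict(np.load("weights.npz")): a flat dict of
# NumPy arrays keyed by "/"-joined names. Real checkpoint keys may differ.
rng = np.random.default_rng(0)
jax_params = {
    "stem/conv/kernel": rng.standard_normal((3, 3, 3, 8)).astype(np.float32),  # (H, W, in, out)
    "head/dense/kernel": rng.standard_normal((8, 10)).astype(np.float32),  # (in, out)
    "head/dense/bias": np.zeros(10, dtype=np.float32),
}


class TinyNet(tf.Module):
    """Hypothetical toy network whose variables are filled from the JAX arrays."""

    def __init__(self, params):
        super().__init__()
        # The assumed layouts already match what tf.nn.conv2d and tf.matmul
        # expect, so the arrays are copied over unchanged.
        self.conv_kernel = tf.Variable(params["stem/conv/kernel"], name="conv_kernel")
        self.dense_kernel = tf.Variable(params["head/dense/kernel"], name="dense_kernel")
        self.dense_bias = tf.Variable(params["head/dense/bias"], name="dense_bias")

    @tf.function(input_signature=[tf.TensorSpec([None, 32, 32, 3], tf.float32)])
    def __call__(self, images):
        x = tf.nn.relu(tf.nn.conv2d(images, self.conv_kernel, strides=1, padding="SAME"))
        x = tf.reduce_mean(x, axis=[1, 2])  # global average pooling
        return tf.matmul(x, self.dense_kernel) + self.dense_bias


# Export as a SavedModel, then reload it and check that outputs match.
export_dir = tempfile.mkdtemp()
model = TinyNet(jax_params)
tf.saved_model.save(model, export_dir)
reloaded = tf.saved_model.load(export_dir)

images = tf.random.uniform([2, 32, 32, 3])
np.testing.assert_allclose(model(images).numpy(), reloaded(images).numpy(), rtol=1e-5, atol=1e-5)
```

Comparing the reloaded SavedModel against the in-memory module is a cheap sanity check; a full port would also compare against the original JAX forward pass on the same inputs.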
The TensorFlow SavedModels are available on TensorFlow Hub as a collection: https://tfhub.dev/sayakpaul/collections/bit-resnet/1. A total of 8 models are available:
| Model Name | Input Resolution (px) | Classifier | Feature Extractor |
|:---------------: |:-------------------: |:--------------------------------------------------------------------------: |:--------------------------------------------------------------------------: |
| BiT-ResNet152x2 | 384 | [Link](https://tfhub.dev/sayakpaul/bit_resnet152x2_384_classification/1) | [Link](https://tfhub.dev/sayakpaul/bit_r152x2_384_feature_extraction/1) |
| BiT-ResNet152x2 | 224 | [Link](https://tfhub.dev/sayakpaul/bit_resnet152x2_224_classification/1) | [Link](https://tfhub.dev/sayakpaul/bit_r152x2_224_feature_extraction/1) |
| BiT-ResNet50x1 | 224 | [Link](https://tfhub.dev/sayakpaul/distill_bit_r50x1_224_classification/1) | [Link](https://tfhub.dev/sayakpaul/distill_bit_r50x1_224_classification/1) |
| BiT-ResNet50x1 | 160 | [Link](https://tfhub.dev/sayakpaul/distill_bit_r50x1_160_classification/1) | [Link](https://tfhub.dev/sayakpaul/distill_bit_r50x1_160_classification/1) |
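Loading one of the classifiers can be sketched as follows, assuming `tensorflow_hub` is installed and the handles still resolve. The handles are copied from the table; `CLASSIFIER_HANDLES`, `classifier_handle`, and `classify` are hypothetical helper names, and the inputs are assumed to be float32 batches at the listed resolution, preprocessed as described on each model's page.

```python
# Classifier handles from the table, keyed by (model name, input resolution).
CLASSIFIER_HANDLES = {
    ("BiT-ResNet152x2", 384): "https://tfhub.dev/sayakpaul/bit_resnet152x2_384_classification/1",
    ("BiT-ResNet152x2", 224): "https://tfhub.dev/sayakpaul/bit_resnet152x2_224_classification/1",
    ("BiT-ResNet50x1", 224): "https://tfhub.dev/sayakpaul/distill_bit_r50x1_224_classification/1",
    ("BiT-ResNet50x1", 160): "https://tfhub.dev/sayakpaul/distill_bit_r50x1_160_classification/1",
}


def classifier_handle(model_name: str, resolution: int) -> str:
    """Return the TF Hub handle for a classifier listed in the table."""
    try:
        return CLASSIFIER_HANDLES[(model_name, resolution)]
    except KeyError:
        raise ValueError(f"No classifier for {model_name} at {resolution}px") from None


def classify(model_name: str, resolution: int, images):
    """Run a batch of images through the classifier (downloads the model)."""
    # Imported lazily so the handle lookup above works without TensorFlow Hub.
    import tensorflow_hub as hub

    layer = hub.KerasLayer(classifier_handle(model_name, resolution), trainable=False)
    return layer(images)
```

For example, `classify("BiT-ResNet50x1", 160, images)` with `images` of shape `[batch, 160, 160, 3]` would return the classifier's outputs for each image. Keying on the resolution as well as the model name makes it harder to feed a 224px batch to a model exported for 160px inputs.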
You can use the `convert_jax_weights_tf.ipynb` notebook to understand how model porting works between JAX and TensorFlow. The JAX team also provides an experimental tool called `jax2tf`, which you can find [here](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
## References
[1] [Big Transfer (BiT): General Visual Representation Learning by Kolesnikov et al.](https://arxiv.org/abs/1912.11370)
[2] [Knowledge distillation: A good teacher is patient and consistent by Beyer et al.](https://arxiv.org/abs/2106.05237)
[3] [BiT GitHub](https://github.com/google-research/big_transfer)