# fastai-tf-fit
Fit your TensorFlow model using fastai and PyTorch

## Installation
```bash
pip install git+https://github.com/fastai/tf-fit.git
```

## Features
This project extends fastai so that TensorFlow models can be trained through an interface similar to fastai's. It uses fastai `DataBunch` objects, so data loading works exactly as it does in fastai. For training, `TfLearner` offers many of the same features as the fastai `Learner`. The currently supported features are:
* Training TensorFlow models with a constant learning rate and weight decay
* Training with the [1cycle policy](https://docs.fast.ai/train.html#fit_one_cycle)
* Learning rate finder
* Fitting with callbacks that have access to hyperparameter updates
* Discriminative learning rates
* Freezing layers so that their parameters are not trained
* [True weight decay option](https://arxiv.org/abs/1711.05101)
* L2 regularization (`true_wd=False`)
* [Option to remove weight decay from batchnorm layers (`bn_wd=False`)](https://arxiv.org/abs/1706.02677)
* Momentum
* Option to train batchnorm layers even when their layer is frozen (`train_bn=True`)
* Model saving and loading
* Default image data format: channels × height × width
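
The list distinguishes true weight decay from L2 regularization (`true_wd=False`). The difference is easiest to see on a single weight. Below is a plain-Python sketch of the two update rules for SGD with momentum; it illustrates the concept and is not `TfLearner`'s implementation, and the function name `sgd_momentum_step` is hypothetical. With L2 regularization the penalty `wd * w` is added to the gradient and therefore accumulates in the momentum buffer, while true (decoupled) weight decay shrinks the weight directly.

```python
def sgd_momentum_step(w, grad, buf, lr, mom, wd, true_wd):
    """One SGD-with-momentum update of a single scalar weight (hypothetical helper).

    Returns the updated weight and momentum buffer.
    true_wd=False: L2 regularization, wd * w is added to the gradient,
                   so the penalty also accumulates in the momentum buffer.
    true_wd=True:  decoupled ("true") weight decay, the weight is shrunk
                   directly and the buffer only sees the loss gradient.
    """
    if true_wd:
        w = w * (1 - lr * wd)             # decay applied to the weight itself
        buf = mom * buf + grad            # momentum tracks the loss gradient only
    else:
        buf = mom * buf + grad + wd * w   # L2 penalty folded into the gradient
    return w - lr * buf, buf


# A zero loss gradient isolates the effect of the regularizer.
w_true, buf_true = 1.0, 0.0
w_l2, buf_l2 = 1.0, 0.0
for _ in range(2):
    w_true, buf_true = sgd_momentum_step(w_true, 0.0, buf_true, lr=0.1, mom=0.9, wd=0.1, true_wd=True)
    w_l2, buf_l2 = sgd_momentum_step(w_l2, 0.0, buf_l2, lr=0.1, mom=0.9, wd=0.1, true_wd=False)
print(round(w_true, 4), round(w_l2, 4))  # 0.9801 0.9711
```

Both rules agree on the first step and diverge afterwards because the L2 term is carried along by momentum. With adaptive optimizers such as Adam, the L2 term is also rescaled per parameter, which is the issue the linked paper addresses.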

## To do
This project is a work in progress, so some features may be missing and obscure bugs may remain.
* A function to get predictions
* TensorFlow train/eval mode switching for dropout and batchnorm in eager mode
* Pip and conda packages

## Examples

### Setup
Set up the fastai `DataBunch`, optimizer, loss function, and metrics.
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.train, tf.losses)
from fastai.vision import *
from fastai_tf_fit import *

# CIFAR-10; training augmentation pads by 4 and randomly crops back to 32,
# plus random horizontal flips; no transforms on the validation set
path = untar_data(URLs.CIFAR)
ds_tfms = ([*rand_pad(4, 32), flip_lr(p=0.5)], [])
data = ImageDataBunch.from_folder(path, valid='test', ds_tfms=ds_tfms, bs=512).normalize(cifar_stats)

opt_fn = tf.train.AdamOptimizer

loss_fn = tf.losses.sparse_softmax_cross_entropy

def categorical_accuracy(y_pred, y_true):
    # Fraction of samples whose highest-scoring class matches the integer label
    return tf.keras.backend.mean(
        tf.keras.backend.equal(y_true, tf.keras.backend.argmax(y_pred, axis=-1)))
metrics = [categorical_accuracy]
```
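
`categorical_accuracy` takes the argmax of the model outputs along the last axis and compares it with the integer labels. The plain-Python version below performs the same computation on nested lists, which can help when checking the metric by hand. It is an illustrative re-statement rather than part of this project, and the name `accuracy_reference` is hypothetical. Because softmax preserves the ordering of scores, the result is the same whether the model outputs logits or probabilities.

```python
def accuracy_reference(y_pred, y_true):
    """Plain-Python equivalent of the categorical_accuracy metric (hypothetical helper).

    y_pred: one list of class scores (logits or probabilities) per sample
    y_true: one integer class label per sample
    """
    # Index of the highest score in each row (argmax over the last axis)
    preds = [max(range(len(row)), key=row.__getitem__) for row in y_pred]
    # 1.0 where the prediction matches the label, 0.0 otherwise, then average
    return sum(float(p == t) for p, t in zip(preds, y_true)) / len(y_true)


scores = [[0.1, 2.0, 0.3],   # predicts class 1
          [1.5, 0.2, 0.1],   # predicts class 0
          [0.0, 0.1, 3.0],   # predicts class 2
          [0.9, 0.8, 0.1]]   # predicts class 0
labels = [1, 0, 2, 1]
print(accuracy_reference(scores, labels))  # 0.75
```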

### Using tf.keras.Model
```python
class Simple_CNN(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Three stride-2 convolutions shrink 32x32 inputs to 16x16, 8x8, then 4x4.
        # BatchNormalization(axis=1) normalizes over channels (channels-first data).
        self.conv1 = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization(axis=1)
        self.conv2 = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization(axis=1)
        self.conv3 = tf.keras.layers.Conv2D(10, kernel_size=3, strides=(2,2), padding='same')
        self.bn3 = tf.keras.layers.BatchNormalization(axis=1)
    def call(self, xb):
        xb = tf.nn.relu(self.bn1(self.conv1(xb)))
        xb = tf.nn.relu(self.bn2(self.conv2(xb)))
        xb = tf.nn.relu(self.bn3(self.conv3(xb)))
        # Average each of the 10 channels over its 4x4 map: one score per class
        xb = tf.nn.pool(xb, (4,4), 'AVG', 'VALID', data_format="NCHW")
        xb = tf.reshape(xb, (-1, 10))
        return xb

model = Simple_CNN()
```

### Using Keras functional API
```python
inputs = tf.keras.layers.Input(shape=(3,32,32))  # channels-first: (channels, height, width)
x = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')(inputs)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')(x)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.Conv2D(10, kernel_size=3, strides=(2,2), padding='same')(x)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.AveragePooling2D(pool_size=(4, 4), padding='same')(x)
x = tf.keras.layers.Reshape((10,))(x)
# Output raw logits: sparse_softmax_cross_entropy applies the softmax itself
predictions = tf.keras.layers.Dense(10)(x)
model = tf.keras.models.Model(inputs=inputs, outputs=predictions)
```

### Training
Create a `TfLearner` object.
```python
learn = TfLearner(data, model, opt_fn, loss_fn, metrics=metrics, true_wd=True, bn_wd=True, wd=defaults.wd, train_bn=True)
```

Run the learning rate finder and plot the loss against the learning rate.
```python
learn.lr_find()
learn.recorder.plot()
```

Train the model for 3 epochs with a learning rate of 3e-3 and weight decay of 0.4.
```python
learn.fit(3, lr=3e-3, wd=0.4)
```

Fit the model with the 1cycle policy over a cycle length of 10 epochs, using discriminative learning rates.
```python
learn.fit_one_cycle(10, max_lr=slice(6e-3, 3e-3))
```
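
Passing a `slice` as `max_lr` requests discriminative learning rates, one per layer group. In fastai v1, `Learner.lr_range` spreads `slice(start, stop)` geometrically, from `start` for the first layer group to `stop` for the last. The sketch below assumes `TfLearner` follows the same convention; `spread_lr` is a hypothetical helper written for illustration, and the choice of three layer groups is also an assumption.

```python
def spread_lr(start, stop, n_groups):
    """One learning rate per layer group, geometrically spaced (hypothetical helper).

    The first (earliest) group gets `start`, the last group gets `stop`,
    and neighbouring groups differ by a constant ratio.
    """
    if n_groups == 1:
        return [stop]  # a single group simply uses the final value
    ratio = (stop / start) ** (1 / (n_groups - 1))
    return [start * ratio ** i for i in range(n_groups)]


# slice(6e-3, 3e-3) from the example above, assuming three layer groups
print([f"{lr:.2e}" for lr in spread_lr(6e-3, 3e-3, 3)])
# ['6.00e-03', '4.24e-03', '3.00e-03']
```

Under this convention, `slice(6e-3, 3e-3)` gives the earliest group the larger rate. fastai examples more commonly put the smaller value first, so that early (often pretrained) layers change more slowly.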

Freeze layers from training (`freeze`), unfreeze them (`unfreeze`), or freeze the layer groups up to a given index (`freeze_to`).
```python
learn.freeze()
```
```python
learn.unfreeze()
```
```python
learn.freeze_to(-1)
```

Save and load model weights.
```python
learn.save('cnn-1')
```
```python
learn.load('cnn-1')
```

### Metrics
Plot the learning rate and momentum schedules.
```python
learn.recorder.plot_lr(show_moms=True)
```

Plot the training and validation losses.
```python
learn.recorder.plot_losses()
```

Plot the metrics.
```python
learn.recorder.plot_metrics()
```
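
After `fit_one_cycle`, `plot_lr(show_moms=True)` shows the 1cycle schedules. The learning rate warms up to `max_lr` and then anneals to a very small value, while momentum moves in the opposite direction. The sketch below generates such a schedule in plain Python. It assumes fastai v1's default hyperparameters (`div_factor=25`, `pct_start=0.3`, `moms=(0.95, 0.85)`, final learning rate `max_lr / (div_factor * 1e4)`) and cosine annealing in both phases. `TfLearner`'s exact schedule may differ, and the names `cos_anneal` and `one_cycle` are hypothetical.

```python
import math


def cos_anneal(start, end, pct):
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))


def one_cycle(n_iter, max_lr, div_factor=25.0, pct_start=0.3,
              moms=(0.95, 0.85), final_div=None):
    """Per-iteration (lr, momentum) pairs for a 1cycle schedule (hypothetical helper)."""
    if final_div is None:
        final_div = div_factor * 1e4
    n_up = max(1, round(n_iter * pct_start))  # warm-up iterations
    n_down = n_iter - n_up                    # annealing iterations
    schedule = []
    for i in range(n_iter):
        if i < n_up:
            # Warm-up: lr rises from max_lr / div_factor to max_lr, momentum falls
            pct = i / n_up
            lr = cos_anneal(max_lr / div_factor, max_lr, pct)
            mom = cos_anneal(moms[0], moms[1], pct)
        else:
            # Annealing: lr decays toward max_lr / final_div, momentum recovers
            pct = (i - n_up) / max(1, n_down)
            lr = cos_anneal(max_lr, max_lr / final_div, pct)
            mom = cos_anneal(moms[1], moms[0], pct)
        schedule.append((lr, mom))
    return schedule


sched = one_cycle(100, max_lr=3e-3)
for i in (0, 30):
    lr, mom = sched[i]
    print(i, f"{lr:.2e}", f"{mom:.3f}")
# 0 1.20e-04 0.950
# 30 3.00e-03 0.850
```

The learning rate peaks at the end of warm-up (iteration 30 of 100 here) exactly when momentum reaches its minimum, which is the shape the momentum curve of `plot_lr(show_moms=True)` is meant to reveal.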