Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/fastai/tf-fit
Fit your TensorFlow model using fastai and PyTorch
- Host: GitHub
- URL: https://github.com/fastai/tf-fit
- Owner: fastai
- License: apache-2.0
- Created: 2018-10-08T19:08:59.000Z (about 6 years ago)
- Default Branch: master
- Last Pushed: 2019-10-19T11:48:15.000Z (about 5 years ago)
- Last Synced: 2024-08-01T15:29:38.550Z (3 months ago)
- Language: Python
- Size: 117 KB
- Stars: 91
- Watchers: 11
- Forks: 16
- Open Issues: 4
- Metadata Files:
- Readme: README.md
- License: LICENSE.txt
README
# fastai-tf-fit
Fit your TensorFlow model using fastai and PyTorch

## Installation
```bash
pip install git+https://github.com/fastai/tf-fit.git
```

## Features
This project extends fastai to allow training TensorFlow models with an interface similar to fastai's. It uses fastai `DataBunch` objects, so loading data works exactly as it does in fastai. For training, the `TfLearner` supports many of the same features as the fastai `Learner`. The currently supported features are listed below.
* Training TensorFlow models with a constant learning rate and weight decay
* Training using the [1cycle policy](https://docs.fast.ai/train.html#fit_one_cycle)
* Learning rate finder
* Fitting with callbacks that have access to hyperparameter updates
* Discriminative learning rates
* Freezing layers so their parameters are excluded from training
* [True weight decay option](https://arxiv.org/abs/1711.05101) (see the sketch after this list)
* L2 regularization (true_wd=False)
* [Removing weight decay from batchnorm layers option (bn_wd=False)](https://arxiv.org/abs/1706.02677)
* Momentum
* Option to train batchnorm layers even if the layer is frozen (train_bn=True)
* Model saving and loading
* Default image data format is channels × height × width
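To make the weight-decay options concrete, here is a minimal NumPy sketch of one Adam step under each setting (an illustration of the decoupled-decay idea from the paper linked above, not the library's actual optimizer code):

```python
import numpy as np

def adam_step(w, grad, m, v, t, lr=3e-3, wd=0.4,
              beta1=0.9, beta2=0.999, eps=1e-8, true_wd=True):
    # Hypothetical helper for illustration; not part of fastai-tf-fit.
    if not true_wd:
        # L2 regularization (true_wd=False): the penalty joins the gradient,
        # so Adam's adaptive denominator rescales it along with the loss term.
        grad = grad + wd * w
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)        # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)        # bias-corrected second moment
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    if true_wd:
        # Decoupled ("true") weight decay (arXiv:1711.05101): the decay is
        # applied directly to the weights, outside the adaptive statistics.
        w = w - lr * wd * w
    return w, m, v
```

With plain SGD the two variants coincide; the distinction only matters once adaptive statistics or momentum rescale the gradient, which is exactly the observation behind decoupled weight decay.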
## To do

This project is a work in progress, so there may be missing features or obscure bugs.
* Get predictions function
* TensorFlow train/eval functionality for dropout and batchnorm in eager mode (see the sketch after this list)
* Pip and conda packages
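For context on the train/eval to-do item: in eager mode, Keras layers toggle their train/eval behavior through a `training` argument, which is roughly the switch a learner needs to flip between fitting and validation. A minimal self-contained illustration (assuming TF 2.x-style eager execution; this is not the project's code):

```python
import tensorflow as tf

drop = tf.keras.layers.Dropout(0.5)
x = tf.ones((1, 4))

print(drop(x, training=True))   # about half the units zeroed, the rest scaled by 2
print(drop(x, training=False))  # identity: dropout is disabled at eval time
```

Batchnorm behaves analogously, using batch statistics when training and moving averages when evaluating.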
## Examples

### Setup
Set up the fastai data bunch, optimizer, loss function, and metrics.
```python
from fastai.vision import *
from fastai_tf_fit import *

# Download CIFAR-10 and build a standard fastai DataBunch.
path = untar_data(URLs.CIFAR)
ds_tfms = ([*rand_pad(4, 32), flip_lr(p=0.5)], [])
data = ImageDataBunch.from_folder(path, valid='test', ds_tfms=ds_tfms, bs=512).normalize(cifar_stats)

# TensorFlow optimizer and loss function.
opt_fn = tf.train.AdamOptimizer
loss_fn = tf.losses.sparse_softmax_cross_entropy

# Accuracy metric: compare integer labels against the argmax of the predictions.
def categorical_accuracy(y_pred, y_true):
    return tf.keras.backend.mean(tf.keras.backend.equal(y_true, tf.keras.backend.argmax(y_pred, axis=-1)))

metrics = [categorical_accuracy]
```

### Using tf.keras.Model
```python
class Simple_CNN(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization(axis=1)
        self.conv2 = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization(axis=1)
        self.conv3 = tf.keras.layers.Conv2D(10, kernel_size=3, strides=(2,2), padding='same')
        self.bn3 = tf.keras.layers.BatchNormalization(axis=1)

    def call(self, xb):
        xb = tf.nn.relu(self.bn1(self.conv1(xb)))
        xb = tf.nn.relu(self.bn2(self.conv2(xb)))
        xb = tf.nn.relu(self.bn3(self.conv3(xb)))
        # Average-pool the remaining 4x4 feature map (channels-first layout).
        xb = tf.nn.pool(xb, (4,4), 'AVG', 'VALID', data_format="NCHW")
        xb = tf.reshape(xb, (-1, 10))
        return xb

model = Simple_CNN()
```

### Using Keras functional API
```python
inputs = tf.keras.layers.Input(shape=(3,32,32))
x = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')(inputs)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')(x)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.Conv2D(10, kernel_size=3, strides=(2,2), padding='same')(x)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.AveragePooling2D(pool_size=(4, 4), padding='same')(x)
x = tf.keras.layers.Reshape((10,))(x)
predictions = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.models.Model(inputs=inputs, outputs=predictions)
```

### Training
Create a `TfLearner` object.
```python
learn = TfLearner(data, model, opt_fn, loss_fn, metrics=metrics, true_wd=True, bn_wd=True, wd=defaults.wd, train_bn=True)
```

Run the learning rate finder and plot the results.
```python
learn.lr_find()
learn.recorder.plot()
```

Train the model for 3 epochs with a learning rate of 3e-3 and weight decay of 0.4.
```python
learn.fit(3, lr=3e-3, wd=0.4)
```

Fit the model using the 1cycle policy with a cycle length of 10 and discriminative learning rates.
```python
learn.fit_one_cycle(10, max_lr=slice(6e-3, 3e-3))
```
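For intuition about what `fit_one_cycle` does to the learning rate, here is a rough NumPy sketch of the 1cycle schedule's shape (cosine warm-up and annealing with fastai-v1-style defaults assumed; an illustration, not TfLearner's implementation):

```python
import numpy as np

def one_cycle_lr(n_iter, max_lr=3e-3, div_factor=25.0, pct_start=0.3):
    # Hypothetical helper for illustration: warm up from max_lr/div_factor
    # to max_lr over the first pct_start of iterations, then anneal toward 0.
    n_up = int(n_iter * pct_start)
    low = max_lr / div_factor
    warm = low + (max_lr - low) * (1 - np.cos(np.linspace(0, np.pi, n_up))) / 2
    cool = max_lr * (1 + np.cos(np.linspace(0, np.pi, n_iter - n_up))) / 2
    return np.concatenate([warm, cool])

lrs = one_cycle_lr(1000)
print(lrs[0], lrs.max(), lrs[-1])  # starts low, peaks at max_lr, ends near 0
```

fastai's 1cycle also cycles momentum inversely to the learning rate, which is what `learn.recorder.plot_lr(show_moms=True)` under Metrics below visualizes.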
Freeze the model, unfreeze it, or freeze it up to a given layer group.
```python
learn.freeze()
```
```python
learn.unfreeze()
```
```python
learn.freeze_to(-1)
```

Save and load model weights.
```python
learn.save('cnn-1')
```
```python
learn.load('cnn-1')
```

### Metrics
Plot learning rate and momentum schedules.
```python
learn.recorder.plot_lr(show_moms=True)
```

Plot train and validation losses.
```python
learn.recorder.plot_losses()
```

Plot metrics.
```python
learn.recorder.plot_metrics()
```