Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/kmkolasinski/tensorflow-nanogpt
Example of how to train GPT-2 (XLA + AMP), export it to a SavedModel, and serve it with TensorFlow Serving
Example of how to train GPT-2 (XLA + AMP), export it to a SavedModel, and serve it with TensorFlow Serving
- Host: GitHub
- URL: https://github.com/kmkolasinski/tensorflow-nanogpt
- Owner: kmkolasinski
- License: mit
- Created: 2023-04-07T10:46:44.000Z (almost 2 years ago)
- Default Branch: main
- Last Pushed: 2023-05-06T06:15:37.000Z (over 1 year ago)
- Last Synced: 2024-10-29T18:23:43.589Z (2 months ago)
- Topics: gpt-2, tensorflow, xla
- Language: Jupyter Notebook
- Homepage:
- Size: 602 KB
- Stars: 5
- Watchers: 4
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# tensorflow-nanoGPT
Example code showing how to fine-tune a GPT-2 model using TensorFlow and the Keras-NLP library, then export it as a fully end-to-end model, i.e. text-in => text-out, in the SavedModel format, which can later be served with TensorFlow Serving. The whole processing pipeline is stored in the TensorFlow graph, so no extra libraries or tokenizers are needed to run inference.

The goal of this project was to train a generative model that extracts Named Entities (NE) from the input prompt text, using a model trained on the [conll2003](https://huggingface.co/datasets/conll2003) dataset. The output of this model can later be post-processed for further logic.

At the end of the notebook you can run the trained model in the following way:
```python
import tensorflow as tf

predictor = tf.saved_model.load('/path/to/gpt2/model')

prompt = "CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY ."
prediction = predictor(prompt)
prediction['outputs'].numpy().decode() == "LEICESTERSHIRE//ORG\n"
```
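For context, the export side could look roughly like this: a `tf.Module` whose `__call__` wraps tokenization, generation, and detokenization behind a string signature. This is a minimal sketch, not the repo's actual export code; `generate_ids` is a placeholder for the real generation loop, and the tokenizer preset name is an assumption:

```python
import tensorflow as tf
import keras_nlp

# The tokenizer is built from TF ops, so it serializes into the SavedModel graph.
tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en")

def generate_ids(token_ids):
    # Placeholder for the actual autoregressive generation loop.
    return token_ids

class Predictor(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
    def __call__(self, prompt):
        token_ids = tokenizer(prompt)          # string -> token ids
        output_ids = generate_ids(token_ids)   # ids -> generated ids
        return {"outputs": tokenizer.detokenize(output_ids)}  # ids -> string

tf.saved_model.save(Predictor(), "/path/to/gpt2/model")
```

Because the signature is a scalar string and the output is a scalar string, the saved artifact is text-in => text-out with no client-side tokenization.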
# Main Features
* fast training using **mixed precision** (see the sketches after this list)
* even faster training with **XLA enabled (jit_compile)**
* partial model freezing and a basic implementation of **LoRA** (sketch below)
* **fast data preparation** using the tokenizer from the keras-nlp package, fully compatible with tf.data.Dataset (sketch below)
* **faster token generation with cached key/value** tensors of the attention heads
* export of the trained model to SavedModel; the whole processing pipeline is stored inside the TF graph (preprocessing, tokenization and prediction)
* example of how to serve the model using **TensorFlow Serving** (see the sketch after the Installation section)
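For the first two features, the relevant Keras configuration is small. A minimal sketch under the assumption of a standard Keras training setup (TF 2.11-era API; the tiny model here is just a stand-in for GPT-2):

```python
import tensorflow as tf
from tensorflow import keras

# Mixed precision (AMP): float16 compute, float32 variables.
keras.mixed_precision.set_global_policy("mixed_float16")

# Stand-in model; the repo fine-tunes GPT-2 from keras-nlp instead.
model = keras.Sequential([
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(2),
])
model.compile(
    optimizer=keras.optimizers.Adam(1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    jit_compile=True,  # XLA-compile the train/eval/predict steps
)
```

Under the `mixed_float16` policy Keras automatically wraps the optimizer with loss scaling, so no extra code is needed for numerically stable AMP training.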
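The repo describes its LoRA implementation as basic. A minimal sketch of the idea, freezing a pretrained `Dense` layer and adding a trainable low-rank update (names and the rank value are illustrative, not the repo's actual API):

```python
import tensorflow as tf
from tensorflow import keras

class LoRADense(keras.layers.Layer):
    """Wraps a frozen Dense layer with a trainable low-rank A @ B update."""

    def __init__(self, base_layer: keras.layers.Dense, rank: int = 4, **kwargs):
        super().__init__(**kwargs)
        self.base = base_layer
        self.base.trainable = False  # freeze the pretrained weights
        self.rank = rank

    def build(self, input_shape):
        self.lora_a = self.add_weight(
            name="lora_a", shape=(input_shape[-1], self.rank),
            initializer="glorot_uniform")
        self.lora_b = self.add_weight(
            name="lora_b", shape=(self.rank, self.base.units),
            initializer="zeros")

    def call(self, inputs):
        # Frozen path + low-rank correction; B starts at zero, so the wrapped
        # layer initially behaves exactly like the original one.
        update = tf.matmul(tf.matmul(inputs, self.lora_a), self.lora_b)
        return self.base(inputs) + update
```

Only `lora_a` and `lora_b` receive gradients, which is what makes fine-tuning cheap relative to updating the full weight matrices.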
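Because the keras-nlp tokenizer is built from TF ops, it can run inside a `tf.data` pipeline. A sketch, assuming the `gpt2_base_en` preset (the actual preset and batching used by the notebook may differ):

```python
import tensorflow as tf
import keras_nlp

tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset("gpt2_base_en")

texts = ["CRICKET - LEICESTERSHIRE TAKE OVER AT TOP",
         "SOCCER - JAPAN GET LUCKY WIN"]
ds = (
    tf.data.Dataset.from_tensor_slices(texts)
    .map(tokenizer, num_parallel_calls=tf.data.AUTOTUNE)  # tokenize in-graph
    .padded_batch(2)   # pad variable-length token sequences to a common length
    .prefetch(tf.data.AUTOTUNE)
)
```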
# Some numbers on Google Colab
* Tested on a Tesla T4
* I used a single, identical prompt for each configuration of the exported model
* GPT-2 with a 256-token sequence length

| Run Type                | Generation time  |
|-------------------------|------------------|
| Baseline | 579 ms ± 38.5 ms |
| Baseline + XLA | 369 ms ± 4.37 ms |
| Cached Key/Values | 688 ms ± 125 ms |
| Cached Key/Values + XLA | 245 ms ± 7.72 ms |
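These numbers come from the repo. A sketch of how similar timings could be taken on the exported model, assuming the model path from the usage example above; the first call is excluded because it triggers tracing and, with `jit_compile`, XLA compilation:

```python
import timeit
import tensorflow as tf

predictor = tf.saved_model.load('/path/to/gpt2/model')
prompt = "CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY ."

predictor(prompt)  # warm-up call: tracing + (optional) XLA compilation
runs = timeit.repeat(lambda: predictor(prompt), number=1, repeat=10)
print(f"best of 10: {min(runs) * 1000:.0f} ms")
```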
# Installation
* See the [requirements.txt](requirements.txt) file
* Tested with TensorFlow 2.11
* Check out the example notebook [gpt_2_finetune_conll2003.ipynb](gpt_2_finetune_conll2003.ipynb)
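A sketch of the serving step, using the standard TensorFlow Serving Docker image and its REST API; the model name and paths are illustrative, and the repo's notebook may wire this differently:

```python
# Start the server first, e.g.:
#   docker run -p 8501:8501 \
#     -v "$PWD/models/gpt2:/models/gpt2" -e MODEL_NAME=gpt2 \
#     tensorflow/serving
import json
import requests

prompt = "CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY ."
resp = requests.post(
    "http://localhost:8501/v1/models/gpt2:predict",
    data=json.dumps({"instances": [prompt]}),
)
resp.raise_for_status()
print(resp.json()["predictions"])
```

Since the SavedModel takes raw strings and returns raw strings, the client needs no tokenizer; it just posts text and reads text back.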
# Disclaimer
* the aim of this project was not to create any form of SOTA model,
* this is just a test / demo of various features of the TensorFlow library,
* the notebook shows how to go from data to a production-ready, servable model,
* I'm aware that there are still many things one could try to improve the throughput and memory usage.