The Transformer from "Attention is All You Need"
- Host: GitHub
- URL: https://github.com/hrbigelow/transformer-aiayn
- Owner: hrbigelow
- Created: 2023-04-19T22:48:09.000Z (over 1 year ago)
- Default Branch: master
- Last Pushed: 2024-10-29T06:01:34.000Z (18 days ago)
- Last Synced: 2024-10-29T07:17:05.994Z (18 days ago)
- Topics: artificial-intelligence, deep-learning, haiku, jax
- Language: Python
- Homepage:
- Size: 5.28 MB
- Stars: 1
- Watchers: 3
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# Introduction
An original implementation of the paper [Attention is All You
Need](https://arxiv.org/pdf/1706.03762.pdf) by Vaswani et al.

![Loss (Conditional KL-Divergence in bits)](assets/metrics.png)
![Loss (Conditional KL-Divergence in bits, zoom)](assets/metrics_zoom.png)

Shown above is a training run of about 260k steps (about 16 epochs per 100k steps).
Blue: perplexity on the WMT14 en-de training set (4.5M sentence pairs), with dropout.
Orange: perplexity on newstest2013, with dropout. Green: perplexity on newstest2013,
no dropout. Red: BLEU score on newstest2013, no dropout. newstest2013 has 3000
sentences.

![Learning rate](assets/jul18-lr-40k.png)
The learning rate schedule is as given in the paper, section 5.3, page 7, equation 3.
```python
import jax

def make_learning_rate_fn(warmup_steps, M):
    # M is the model dimension (d_model); see section 5.3, page 7, equation 3
    def lr_fn(step):
        factor = jax.lax.min(step ** -0.5, step * warmup_steps ** -1.5)
        new_lr = M ** -0.5 * factor
        # jax.debug.print('learn_rate: {}', new_lr)
        return new_lr
    return lr_fn
```
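
The function above can be instantiated and inspected directly. This usage sketch is
not from the repo; it assumes the paper's base-model settings (`d_model = 512`,
`warmup_steps = 4000`):

```python
# Hypothetical usage of make_learning_rate_fn (not part of the repo).
lr_fn = make_learning_rate_fn(warmup_steps=4000, M=512)
for step in (1.0, 1000.0, 4000.0, 100000.0):
    print(f'step {step:>8.0f}: learning rate {float(lr_fn(step)):.6f}')
```
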
# Getting Started

## Introduction
This repo is written using JAX and Haiku, and tested on Google Colab TPUs. On the
German-English dataset, consisting of 4.5 million sentence pairs, the `base model`
trains to 100k steps in about 16 hours. It achieves a BLEU score of 25.5 and a
perplexity (PPL) of 4.95 at 100k training steps, very similar to the reported values
of 25.8 and 4.92 for the same model and training stage.

I tried to stay as close as possible to the original architecture. The one major
change is that I used Pre-LN instead of the original Post-LN. I implemented
everything from scratch, including data packing, the entire model, beam search, and
incremental inference using a KV-cache. However, I adapted the BLEU score
calculation function from the original
[tensor2tensor](https://github.com/tensorflow/tensor2tensor) repo, so as to be sure
I was using the same metric. A companion blog article,
[transformer-from-scratch](https://mlcrumbs.com/transformer-from-scratch), documents
many details of the code design and various problems.
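
For readers unfamiliar with the Pre-LN/Post-LN distinction, the sketch below
contrasts the two residual arrangements. It is schematic only (plain functions
standing in for the repo's Haiku modules), where `sublayer` is self-attention or the
feed-forward block and `norm` is layer normalization:

```python
# Schematic contrast between the two variants; not the repo's actual code.
def post_ln_block(x, sublayer, norm):
    # Post-LN, as in the original paper: normalize after the residual addition.
    return norm(x + sublayer(x))

def pre_ln_block(x, sublayer, norm):
    # Pre-LN, as used in this repo: normalize before the sublayer.
    return x + sublayer(norm(x))
```
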
## Installation

Install the package with:

    pip install git+https://github.com/hrbigelow/transformer-aiayn.git
## Training the BPE Tokenizer
Train a byte-pair encoding (BPE) tokenizer on the English-German sentence-pair
dataset. The first time this is launched, the dataset will be downloaded to
`DOWNLOAD_DIR`; subsequent runs will use the cached data stored there.

You must choose the desired vocabulary size. Other datasets can be found with
`tfds.list_builders()`.

NOTE: It is ultimately much faster to run this and the next command locally rather
than using the combination of Colab and Google Cloud Storage. Once the dataset is
downloaded locally, training the tokenizer takes about 2 minutes, whereas my attempt
to train it on Colab using the dataset downloaded to a GCS bucket ran for over 40
minutes without finishing. Also, HuggingFace's progress meter doesn't work in Colab,
but it does work when running locally.

```bash
# python aiayn/preprocess.py train_tokenizer DOWNLOAD_DIR DATASET_NAME VOCAB_SIZE OUT_FILE
python aiayn/preprocess.py train_tokenizer \
~/tensorflow_datasets \
huggingface:wmt14/de-en \
36500 \
de-en-bpe.36500.json
```
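
Once training finishes, you can optionally sanity-check the tokenizer file. The
snippet below is a hypothetical example, not part of the repo; it assumes the JSON
file is in the HuggingFace `tokenizers` format:

```python
# Hypothetical sanity check of the trained tokenizer (assumes the HuggingFace
# `tokenizers` package, and that the JSON file uses its format).
from tokenizers import Tokenizer

tok = Tokenizer.from_file('de-en-bpe.36500.json')
print('vocab size:', tok.get_vocab_size())

enc = tok.encode('Ein kleiner Test der Tokenisierung.')
print(enc.tokens)
print(tok.decode(enc.ids))
```
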
## Tokenizing the dataset

Now that you have a trained tokenizer (the `tokenizer_file`), use it to convert the
text-based sentence-pair dataset into token sequences (integer arrays) and save them
to `.tfrecord` files. It is best practice to save this dataset in multiple shards,
which simplifies parallelized reads during training. Since TPUs are so fast, the
bottleneck is often data loading speed.

Perform this step for both the `train` and `validation` splits of the dataset.
```bash
# python aiayn/preprocess.py tokenize_dataset DOWNLOAD_DIR DATASET_NAME SPLIT \
# TOKENIZER_FILE NPROC NUM_SHARDS OUT_TEMPLATE INPUT_LANG TARGET_LANG
python aiayn/preprocess.py tokenize_dataset \
~/tensorflow_datasets huggingface:wmt14/de-en train de-en-bpe.36500.json \
8 ~/de-en-train/{}.tfrecord en de
```
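
Before moving on, you can verify that the shards were written. This is a
hypothetical check, not part of the repo; it assumes TensorFlow is installed
locally:

```python
# Hypothetical check that the tokenized shards exist and contain records.
import glob
import os

import tensorflow as tf

shards = sorted(glob.glob(os.path.expanduser('~/de-en-train/*.tfrecord')))
print(f'{len(shards)} shard files found')

num_records = sum(1 for _ in tf.data.TFRecordDataset(shards))
print(f'{num_records} records in total')
```
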
Once finished, upload the `.tfrecord` files to Google Cloud Storage using the
`gcloud storage cp` command. You will need to set things up with Google Cloud.

NOTE: I have also tried using Google Drive for persistence. It is possible to mount
it into a Colab. However, it is not as reliable or performant, and I do not
recommend it, even though the initial setup with GCS takes some work.

## Train the model
Train the Encoder-Decoder model (design based on Attention Is All You Need) on the
data. The settings below work for a TPU.

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hrbigelow/transformer-aiayn/blob/master/notebooks/jax_aiayn.ipynb)
```bash
python3 aiayn/train.py \
    --dataset_glob 'de-en-train/*.tfrecord' \
    --val_dataset_glob 'de-en-val/*.tfrecord' \
    --batch_dim0 96 \
    --accum_steps 2 \
    --ckpt_every 3000 \
    --eval_every 100 \
    --val_loop_elem 32 \
    --ckpt_dir ~/checkpoints \
    --resume_ckpt None \
    --report_every 10 \
    --max_source_len 320 \
    --max_target_len 320 \
    --swap_source_target True \
    --shuffle_size 100000 \
    --label_smooth_eps 0.1 \
    --tokenizer_file de-en-bpe.36500.json
    # --streamvis_run_name test \
    # --streamvis_path svlog \
    # --streamvis_buffer_items 100

# Flag reference:
#   --batch_dim0          Total number of sentence pairs per SGD batch
#   --accum_steps         Number of gradient accumulation steps
#   --eval_every          Compute scores on the validation dataset every __ steps
#   --val_loop_elem       Process this many validation sentence pairs per loop
#   --resume_ckpt         Supply an integer here to resume from a saved checkpoint
#   --report_every        Interval for printing metrics to stdout
#   --max_source_len      Maximum length in tokens for source sentences
#   --max_target_len      Maximum length in tokens for target sentences
#   --swap_source_target  If True, swap the source and target sentences
#   --shuffle_size        Buffer size for randomizing data element order
#   --label_smooth_eps    Factor for mixing a uniform distribution into the labels
#   --streamvis_run_name  Scoping name for visualizing data from different runs
#   --streamvis_path      Log file for logging visualization data
#   --streamvis_buffer_items  How many logging data points to visualize
```
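
For reference, `--label_smooth_eps` is the mixing factor for standard label
smoothing: the one-hot target distribution is blended with a uniform distribution
over the vocabulary. A minimal sketch of the idea (not the repo's exact
implementation):

```python
# Minimal label-smoothing sketch; the repo's implementation may differ in detail.
import jax

def smooth_labels(targets, vocab_size, eps=0.1):
    # Blend the one-hot targets with a uniform distribution over the vocabulary.
    one_hot = jax.nn.one_hot(targets, vocab_size)
    return (1.0 - eps) * one_hot + eps / vocab_size
```
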
## Run the model

Run the trained model on some input sentences, in this case `newstest2013.en`, and
write the translation results to `results.out`. It is crucial that
`pos_encoding_factor` is set to the same value as was used for training. (In fact,
it should instead be saved as part of the model checkpoint.)

```bash
python3 aiayn/sample.py sample \
--ckpt_dir checkpoints \
--resume_ckpt 85000 \
--tokenizer_file de-en-bpe.36500.json \
--batch_file newstest2013.en \
--out_file results.out \
--batch_dim0 64 \
--max_source_len 150 \
--max_target_len 150 \
--random_seed 12345 \
--beam_size 4 \
--beam_search_beta 0.0 \
--pos_encoding_factor 1.0
```
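
For context, the positional encodings in the paper (section 3.5) are fixed sinusoids
added to the token embeddings. The sketch below reproduces that formula with a
multiplicative `factor` standing in for what `pos_encoding_factor` presumably
controls; this is an illustrative guess, not the repo's code, so consult the `aiayn`
source for the exact meaning:

```python
# Standard sinusoidal positional encoding from the paper, scaled by a factor.
# Illustrative only; the repo's pos_encoding_factor may be applied differently.
import jax.numpy as jnp

def sinusoidal_encoding(num_positions, d_model, factor=1.0):
    # Assumes d_model is even.
    pos = jnp.arange(num_positions)[:, None]      # (P, 1)
    i = jnp.arange(0, d_model, 2)[None, :]        # (1, d_model // 2)
    angles = pos / (10000.0 ** (i / d_model))     # (P, d_model // 2)
    pe = jnp.zeros((num_positions, d_model))
    pe = pe.at[:, 0::2].set(jnp.sin(angles))
    pe = pe.at[:, 1::2].set(jnp.cos(angles))
    return factor * pe
```
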
## Evaluate the results

```bash
python3 aiayn/sample.py evaluate newstest2013.de results.out
```
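
If you want to cross-check the adapted tensor2tensor BLEU against an independent
implementation, the third-party `sacrebleu` package (not used by this repo) can
score the same files. Its default tokenization differs, so the number will not match
exactly:

```python
# Optional cross-check with sacreBLEU (pip install sacrebleu); not part of the repo.
import sacrebleu

with open('results.out') as f:
    hypotheses = [line.strip() for line in f]
with open('newstest2013.de') as f:
    references = [line.strip() for line in f]

print(sacrebleu.corpus_bleu(hypotheses, [references]).score)
```
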
## Notes

`batch_dim0` is so named to emphasize that it is just one dimension of the actual
batch, which consists of the individual tokens from each target sentence. The
original paper trained the 'base model' with ~25,000 target tokens per batch. With a
`max_target_len` of 320 and 96 sentences in a batch, this leaves room for a maximum
of 30,720 target tokens. However, tokens are packed, and the average occupancy of
the packed data is around 25,000.

When training on a TPU, for technical reasons the quantity `batch_dim0 / accum_steps`
must be evenly divisible by the number of cores (8 on a TPU). In the example above,
that quantity is 48, which means that each core handles a batch of 6 sentence pairs
in each of the two gradient accumulation steps.
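
The arithmetic above can be summarized in a few lines of plain Python (not repo
code), using the example settings from the training command:

```python
# Sanity check of the batch-size arithmetic, using the example training settings.
batch_dim0 = 96        # sentence pairs per SGD batch
accum_steps = 2        # gradient accumulation steps
num_cores = 8          # TPU cores
max_target_len = 320   # maximum tokens per target sentence

per_accum_batch = batch_dim0 // accum_steps       # 48; must be divisible by num_cores
assert per_accum_batch % num_cores == 0
per_core_batch = per_accum_batch // num_cores     # 6 sentence pairs per core
max_target_tokens = batch_dim0 * max_target_len   # 30,720; ~25,000 after packing
print(per_accum_batch, per_core_batch, max_target_tokens)
```
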
Memory consumption of the attention modules scales as N^2 in the context length N.
With TPU v3 memory, a context of 320 fits well and is sufficient to cover quite long
sentences in the training set.

The validation dataset provided in `val_dataset_glob` is loaded entirely into memory,
so it is expected to be only a few thousand examples. During training, every
`eval_every` SGD steps, this entire set is evaluated twice: once with the 'training
mode' model (the same model that is training under SGD, with dropout enabled) and
once with the 'testing mode' model (dropout disabled). The validation dataset is
evaluated in batches of `val_loop_elem`, which should be set as high as fits in
device memory.