https://github.com/cosmoquester/transformers-bart-pretrain
Script to pre-train Hugging Face transformers BART with TensorFlow 2
bart gpu huggingface-transformers pretraining tensorflow tpu
- Host: GitHub
- URL: https://github.com/cosmoquester/transformers-bart-pretrain
- Owner: cosmoquester
- License: mit
- Created: 2021-06-18T09:31:37.000Z (about 5 years ago)
- Default Branch: master
- Last Pushed: 2023-04-13T05:14:44.000Z (about 3 years ago)
- Last Synced: 2025-09-05T01:42:06.833Z (10 months ago)
- Topics: bart, gpu, huggingface-transformers, pretraining, tensorflow, tpu
- Language: Python
- Homepage:
- Size: 1.43 MB
- Stars: 33
- Watchers: 1
- Forks: 6
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
- Citation: CITATION.cff
# transformers TF BART pre-training
- Script to pre-train Hugging Face transformers BART
- Implements the pre-training described in [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461)
- The `Text Infilling` and `Sentence Permutation` noising functions are currently available
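As a rough illustration of the two noising functions named above, here is a minimal pure-Python sketch. It is not the repository's TensorFlow implementation. It assumes that text infilling replaces each sampled span with a single mask token, with span lengths drawn from a Poisson distribution (λ = 3, as in the BART paper), and that sentence permutation shuffles sentence order. The function names, the `lam` parameter, the per-position start probability, and the clamping of every span to at least one token are simplifications made for this example.

```python
import math
import random


def sample_poisson(lam, rng):
    """Draw one Poisson(lam) sample using Knuth's multiplication method."""
    threshold = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def text_infilling(tokens, mask_token="[MASK]", masking_rate=0.3, lam=3.0, rng=None):
    """Replace random spans of tokens with a single mask token each.

    Assumption: at most round(len(tokens) * masking_rate) tokens are hidden.
    """
    rng = rng or random.Random()
    budget = round(len(tokens) * masking_rate)
    noised, i = [], 0
    while i < len(tokens):
        if budget > 0 and rng.random() < masking_rate:
            # Clamp the span to [1, budget] and to the tokens that remain.
            span = min(max(sample_poisson(lam, rng), 1), budget, len(tokens) - i)
            noised.append(mask_token)  # the whole span collapses to one mask
            i += span
            budget -= span
        else:
            noised.append(tokens[i])
            i += 1
    return noised


def sentence_permutation(sentences, rng=None):
    """Return the sentences in a random order, leaving the input untouched."""
    rng = rng or random.Random()
    shuffled = list(sentences)
    rng.shuffle(shuffled)
    return shuffled
```

Calling `text_infilling("a b c d e f".split(), masking_rate=0.3)` returns a list no longer than the input in which some spans appear as `[MASK]`. The model is then trained to reconstruct the original sequence from this noised input.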
# Train
You can train a Hugging Face transformers BART model as in the example below.
(The example runs as-is, without changes, using the bundled sample data.)
```sh
$ CUDA_VISIBLE_DEVICES=1 python -m scripts.train \
--model-config-path configs/base.json \
--train-dataset-path tests/data/sample1.txt \
--dev-dataset-path tests/data/sample1.txt \
--sp-model-path sp_model/sp_model_unigram_8K.model \
--device GPU \
--auto-encoding \
--batch-size 2 \
--steps-per-epoch 100 \
--mask-token "[MASK]" \
--mixed-precision
```
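The `--model-config-path` file in the command above is a Hugging Face BART config in JSON form. The sketch below builds such a file with the standard library only. The keys are `BartConfig` field names from the transformers library. Every value, including the vocabulary size of 8000 inferred from the `sp_model_unigram_8K` file name, is an illustrative assumption and not the content of `configs/base.json`.

```python
import json


def build_bart_config(vocab_size=8000, d_model=256, layers=2, heads=4):
    """Return a small BART config as a plain dict (all sizes are hypothetical)."""
    if d_model % heads != 0:
        raise ValueError("d_model must be divisible by the number of attention heads")
    return {
        "vocab_size": vocab_size,  # should match the sentencepiece vocabulary
        "d_model": d_model,
        "encoder_layers": layers,
        "decoder_layers": layers,
        "encoder_attention_heads": heads,
        "decoder_attention_heads": heads,
        "encoder_ffn_dim": 4 * d_model,
        "decoder_ffn_dim": 4 * d_model,
        "max_position_embeddings": 512,
        # Special token ids are omitted here; they must agree with the tokenizer.
    }


config_json = json.dumps(build_bart_config(), indent=2)
```

Writing `config_json` to a file and passing that path as `--model-config-path` should work if the script loads the config the way `BartConfig.from_json_file` does, which is an assumption here rather than something the README states.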
## Arguments
```sh
File Paths:
--model-config-path MODEL_CONFIG_PATH
model config file
--train-dataset-path TRAIN_DATASET_PATH
training dataset, a text file or multiple files ex)
*.txt
--dev-dataset-path DEV_DATASET_PATH
dev dataset, a text file or multiple files ex) *.txt
--pretrained-checkpoint PRETRAINED_CHECKPOINT
pretrained checkpoint path
--output-path OUTPUT_PATH
output directory to save log and model checkpoints
--sp-model-path SP_MODEL_PATH
sentencepiece model path to tokenizer
Training Parameters:
--mask-token MASK_TOKEN
mask token ex) [MASK]
--mask-token-id MASK_TOKEN_ID
mask token id of vocab
--epochs EPOCHS
--steps-per-epoch STEPS_PER_EPOCH
--learning-rate LEARNING_RATE
--min-learning-rate MIN_LEARNING_RATE
--warmup-steps WARMUP_STEPS
--warmup-rate WARMUP_RATE
--batch-size BATCH_SIZE
total training batch size of all devices
--dev-batch-size DEV_BATCH_SIZE
--num-total-dataset NUM_TOTAL_DATASET
--shuffle-buffer-size SHUFFLE_BUFFER_SIZE
--prefetch-buffer-size PREFETCH_BUFFER_SIZE
--max-sequence-length MAX_SEQUENCE_LENGTH
--weight-decay WEIGHT_DECAY
use weight decay
--clipnorm CLIPNORM clips gradients to a maximum norm.
--disable-text-infilling
disable input noising
--disable-sentence-permutation
disable input noising
--masking-rate MASKING_RATE
text infilling masking rate
--permutation-segment-token-id PERMUTATION_SEGMENT_TOKEN_ID
segment token id for sentence permutation
Other settings:
--tensorboard-update-freq TENSORBOARD_UPDATE_FREQ
log losses and metrics every after this value step
--mixed-precision Use mixed precision FP16
--auto-encoding train by auto encoding with text lines dataset
--use-tfrecord train using tfrecord dataset
--repeat-each-file repeat each dataset and uniform sample for train
example
--debug-nan-loss Trainin with this flag, print the number of Nan loss
(not supported on TPU)
--seed SEED random seed
--skip-epochs SKIP_EPOCHS
skip this number of epochs
--device {CPU,GPU,TPU}
device to train model
--max-over-sequence-policy {filter,slice}
Policy for sequences of which length is over the max
```
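The two `--max-over-sequence-policy` choices can be pictured with the short sketch below. It assumes that `filter` drops any sequence longer than `--max-sequence-length` and that `slice` keeps only the first `max_length` tokens. The help text does not spell this out, so treat the function and its behavior as an interpretation, not the script's exact logic.

```python
def apply_max_length_policy(sequences, max_length, policy):
    """Handle token sequences that exceed max_length.

    Assumed semantics: "filter" discards over-length sequences,
    "slice" truncates them to their first max_length tokens.
    """
    if policy == "filter":
        return [seq for seq in sequences if len(seq) <= max_length]
    if policy == "slice":
        return [seq[:max_length] for seq in sequences]
    raise ValueError(f"unknown policy: {policy!r}")
```

Under this reading, `filter` shrinks the dataset but never feeds the model a cut-off sentence, while `slice` keeps every example at the cost of losing the tail of long ones.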
- `model-config-path` is the path to a Hugging Face BART model config file.
- `pretrained-checkpoint` is the path to a previously trained model checkpoint.
- `sp-model-path` is the path to the sentencepiece tokenizer model.
- With the `repeat-each-file` flag, each dataset file is repeated indefinitely and training examples are sampled uniformly across files, so sampling continues even after one file has been read to its end.
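The behavior of `repeat-each-file` can be sketched in plain Python as follows. Each in-memory list stands in for one dataset file, and the generator name is hypothetical. The actual script works on TensorFlow datasets, so this only illustrates the idea of repeating every source forever and choosing the source uniformly at random.

```python
import itertools
import random


def repeat_each_and_sample(datasets, rng=None):
    """Yield examples forever, picking the source dataset uniformly at random."""
    rng = rng or random.Random()
    datasets = [list(d) for d in datasets]
    if not datasets or any(len(d) == 0 for d in datasets):
        raise ValueError("every dataset must contain at least one example")
    # cycle() restarts a dataset from its beginning once it is exhausted.
    streams = [itertools.cycle(d) for d in datasets]
    while True:
        yield next(rng.choice(streams))
```

Because the source is chosen uniformly, a small file contributes as many training examples as a large one. Use `itertools.islice(repeat_each_and_sample(files), n)` to take a finite number of examples from the endless stream.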