Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/xuyige/BERT4doc-Classification
Code and source for paper ``How to Fine-Tune BERT for Text Classification?``
- Host: GitHub
- URL: https://github.com/xuyige/BERT4doc-Classification
- Owner: xuyige
- License: apache-2.0
- Created: 2020-02-05T10:54:03.000Z (almost 5 years ago)
- Default Branch: master
- Last Pushed: 2021-10-19T06:11:47.000Z (about 3 years ago)
- Last Synced: 2024-11-08T13:19:52.644Z (about 1 month ago)
- Topics: bert, natural-language-processing, text-classification
- Language: Python
- Size: 795 KB
- Stars: 616
- Watchers: 9
- Forks: 99
- Open Issues: 13
- Metadata Files:
  - Readme: README.md
  - License: LICENSE
Awesome Lists containing this project
- StarryDivineSky - xuyige/BERT4doc-Classification
README
# How to Fine-Tune BERT for Text Classification?
This is the code and source for the paper [How to Fine-Tune BERT for Text Classification?](https://arxiv.org/abs/1905.05583)
In this paper, we conduct exhaustive experiments to investigate different fine-tuning methods of BERT on the text classification task and provide a general solution for BERT fine-tuning.
\*********** **update at Mar 14, 2020** \*************
Our checkpoints can be loaded into ``BertEmbedding`` from the latest [fastNLP](https://github.com/fastnlp/fastNLP) package.

See [fastNLP.embeddings.BertEmbedding](https://github.com/fastnlp/fastNLP/blob/master/fastNLP/embeddings/bert_embedding.py).
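
For example, a minimal loading sketch (assuming a fastNLP version whose ``BertEmbedding`` accepts a local checkpoint directory via ``model_dir_or_name``; the directory path is illustrative):

```python
# Minimal sketch: load a further pre-trained checkpoint through fastNLP's BertEmbedding.
# The directory below is illustrative; point it at a converted PyTorch checkpoint
# (see step 4) containing pytorch_model.bin, the BERT config, and vocab.txt.
from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

vocab = Vocabulary()
vocab.add_word_lst("this movie was surprisingly good".split())

embed = BertEmbedding(vocab,
                      model_dir_or_name="./uncased_L-12_H-768_A-12_IMDB_pretrain",
                      layers="-1")  # use the last layer's hidden states
```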
## Requirements
For further pre-training, we borrow some code from Google BERT. Thus, we need:
+ tensorflow==1.1x
+ spacy
+ pandas
+ numpy

Note that you need Python 3.7 or earlier for compatibility with TensorFlow 1.1x.
For fine-tuning, we borrow some code from the pytorch-pretrained-bert package (now known as ``transformers``). Thus, we need:
+ torch>=0.4.1,<=1.2.0
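
A quick sanity check for these version constraints (a sketch only; it interprets ``tensorflow==1.1x`` as any 1.10–1.15 release):

```python
# Sketch: verify the environment matches the constraints listed above.
import sys
import tensorflow as tf
import torch

assert sys.version_info[:2] <= (3, 7), "TensorFlow 1.1x requires Python 3.7 or earlier"

tf_major, tf_minor = (int(x) for x in tf.__version__.split(".")[:2])
assert tf_major == 1 and 10 <= tf_minor <= 15, "expected tensorflow 1.1x, got %s" % tf.__version__

torch_version = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:3])
assert (0, 4, 1) <= torch_version <= (1, 2, 0), "expected torch>=0.4.1,<=1.2.0, got %s" % torch.__version__
```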
## Run the code
### 1) Prepare the data set:
#### Sogou News
We determine the category of each news article based on its URL, such as “sports” corresponding to “http://sports.sohu.com”. We choose 6 categories: “sports”, “house”, “business”, “entertainment”, “women”, and “technology”. For each class, we select 9,000 training samples and 1,000 test samples.

The data is available [here](https://drive.google.com/drive/folders/1Rbi0tnvsQrsHvT_353pMdIbRwDlLhfwM).
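
The linked data set is already labeled, but for reference the URL-based rule can be sketched as follows (only the “sports” mapping is stated above; the other subdomains are placeholders):

```python
# Illustrative sketch of the URL-based labeling rule for Sogou News; not needed to
# reproduce the experiments, since the released data set already carries labels.
from urllib.parse import urlparse

URL_PREFIX_TO_CATEGORY = {
    "sports": "sports",   # e.g. http://sports.sohu.com
    # ... add the subdomains for "house", "business", "entertainment",
    # "women", and "technology" according to the crawled corpus.
}

def label_by_url(url):
    """Return the category implied by the URL's subdomain, or None if it is not one of the 6 classes."""
    subdomain = urlparse(url).netloc.split(".")[0]
    return URL_PREFIX_TO_CATEGORY.get(subdomain)

print(label_by_url("http://sports.sohu.com/some-article.shtml"))  # -> sports
```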
#### The remaining data sets
The remaining data sets were built by [Zhang et al. (2015)](https://papers.nips.cc/paper/5782-character-level-convolutional-networks-for-text-classification.pdf).
We download them from the [Google Drive folder](https://drive.google.com/drive/u/0/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M) created by Xiang Zhang.

### 2) Prepare Google BERT:
+ [BERT-Base, Uncased](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip)
+ [BERT-Base, Chinese](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)
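
If you prefer scripting the download, here is a standard-library sketch for fetching and unpacking BERT-Base, Uncased (swap the URL for the Chinese model if needed):

```python
# Sketch: download and unpack the BERT-Base, Uncased checkpoint listed above.
import urllib.request
import zipfile

URL = "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip"
ARCHIVE = "uncased_L-12_H-768_A-12.zip"

urllib.request.urlretrieve(URL, ARCHIVE)
with zipfile.ZipFile(ARCHIVE) as zf:
    zf.extractall(".")  # creates ./uncased_L-12_H-768_A-12/
```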
### 3) Further Pre-Training:
#### Generate Further Pre-Training Corpus
Here we use AG's News as an example:
```shell
python generate_corpus_agnews.py
```
The file ``agnews_corpus_test.txt`` can be found in the directory ``./data``.

#### Run Further Pre-Training
```shell
python create_pretraining_data.py \
--input_file=./AGnews_corpus.txt \
--output_file=tmp/tf_AGnews.tfrecord \
--vocab_file=./uncased_L-12_H-768_A-12/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5

python run_pretraining.py \
--input_file=./tmp/tf_AGnews.tfrecord \
--output_dir=./uncased_L-12_H-768_A-12_AGnews_pretrain \
--do_train=True \
--do_eval=True \
--bert_config_file=./uncased_L-12_H-768_A-12/bert_config.json \
--init_checkpoint=./uncased_L-12_H-768_A-12/bert_model.ckpt \
--train_batch_size=32 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--num_train_steps=100000 \
--num_warmup_steps=10000 \
--save_checkpoints_steps=10000 \
--learning_rate=5e-5
```

### 4) Fine-Tuning
#### Convert Tensorflow checkpoint to PyTorch checkpoint
```shell
python convert_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path ./uncased_L-12_H-768_A-12_AGnews_pretrain/model.ckpt-100000 \
--bert_config_file ./uncased_L-12_H-768_A-12_AGnews_pretrain/bert_config.json \
--pytorch_dump_path ./uncased_L-12_H-768_A-12_AGnews_pretrain/pytorch_model.bin
```

#### Fine-Tuning on downstream tasks
While fine-tuning on downstream tasks, we notice that different GPUs (e.g., 1080Ti and Titan Xp) may cause slight differences in experimental results even though we fix the initial random seed. Here we use 4 × 1080Ti GPUs. Taking Exp-I (see Section 5.3) as an example:
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3
python run_classifier_single_layer.py \
--task_name imdb \
--do_train \
--do_eval \
--do_lower_case \
--data_dir ./IMDB_data/ \
--vocab_file ./uncased_L-12_H-768_A-12_IMDB_pretrain/vocab.txt \
--bert_config_file ./uncased_L-12_H-768_A-12_IMDB_pretrain/bert_config.json \
--init_checkpoint ./uncased_L-12_H-768_A-12_IMDB_pretrain/pytorch_model.bin \
--max_seq_length 512 \
--train_batch_size 24 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--output_dir ./imdb \
--seed 42 \
--layers 11 10 \
--trunc_medium -1
```

where ``num_train_epochs`` can be 3.0, 4.0, or 6.0.

``layers`` specifies which layers' hidden states are taken as features for classification: -2 means using the pooled output, -1 means concatenating all layers, and the command above concatenates layer-10 and layer-11 (the last two layers), as illustrated by the sketch below.
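
As a rough illustration (not the script's actual code), concatenating the `[CLS]` hidden states of layers 10 and 11 into a single feature vector:

```python
import torch

# Toy illustration of `--layers 11 10`: `all_hidden_states` stands in for the list of
# per-layer encoder outputs, each of shape (batch, seq_len, hidden).
batch, seq_len, hidden = 8, 512, 768
all_hidden_states = [torch.randn(batch, seq_len, hidden) for _ in range(12)]

chosen_layers = [11, 10]
features = torch.cat([all_hidden_states[i][:, 0, :] for i in chosen_layers], dim=-1)
print(features.shape)  # torch.Size([8, 1536])
```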

``trunc_medium`` controls how texts longer than ``max_seq_length`` are truncated: -2 means head-only, -1 means tail-only, 0 means head-half + tail-half (e.g., head-256 + tail-256), and any other natural number k means head-k + tail-rest (i.e., head-k + tail-(512-k)).
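
To make these options concrete, here is a plain-Python sketch of the truncation rule (a hypothetical ``truncate_tokens`` helper, not the script's code; special tokens such as `[CLS]`/`[SEP]` are ignored for simplicity):

```python
def truncate_tokens(tokens, max_len=512, trunc_medium=-1):
    """Sketch of the trunc_medium strategies for texts longer than max_len."""
    if len(tokens) <= max_len:
        return tokens
    if trunc_medium == -2:            # head-only
        return tokens[:max_len]
    if trunc_medium == -1:            # tail-only
        return tokens[-max_len:]
    if trunc_medium == 0:             # head-half + tail-half (e.g. head-256 + tail-256)
        half = max_len // 2
        return tokens[:half] + tokens[-(max_len - half):]
    k = trunc_medium                  # head-k + tail-(max_len - k)
    return tokens[:k] + tokens[-(max_len - k):]

assert len(truncate_tokens(list(range(1000)), max_len=512, trunc_medium=0)) == 512
```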

There are also other arguments for fine-tuning:

``pooling_type`` indicates which feature is used for classification: `mean` means mean-pooling over the hidden states of the whole sequence, `max` means max-pooling, and the default takes the hidden state of the `[CLS]` token as the feature.

``layer_learning_rate`` and ``layer_learning_rate_decay`` in ``run_classifier_discriminative.py`` apply a layer-wise decreasing learning rate (see Section 5.3.4); a sketch follows.
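
A sketch of the layer-wise decay idea (assuming the usual BERT parameter naming ``encoder.layer.{l}``; this is not the script's exact implementation): encoder layer l gets ``base_lr * decay ** (n_layers - l)``, so lower layers are updated more conservatively.

```python
import torch
from torch import nn

def layerwise_lr_groups(model, base_lr=2e-5, decay=0.95, n_layers=12):
    """Sketch: optimizer parameter groups with a layer-wise decreasing learning rate.
    Encoder layer l gets base_lr * decay ** (n_layers - l); other parameters keep base_lr."""
    groups = []
    for name, param in model.named_parameters():
        lr = base_lr
        for l in range(n_layers):
            if ".layer.%d." % l in name:          # parameter belongs to encoder layer l
                lr = base_lr * decay ** (n_layers - l)
                break
        groups.append({"params": [param], "lr": lr})
    return groups

# Toy usage with a stand-in model; with BERT, pass the loaded model instead.
toy = nn.Sequential(nn.Linear(768, 768))
optimizer = torch.optim.Adam(layerwise_lr_groups(toy), lr=2e-5)
```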

## Further Pre-Trained Checkpoints
We upload the IMDb-based further pre-trained checkpoints [here](https://drive.google.com/drive/folders/1Rbi0tnvsQrsHvT_353pMdIbRwDlLhfwM). For other checkpoints, please contact us by e-mail.
## How to cite our paper
```text
@inproceedings{sun2019fine,
title={How to fine-tune {BERT} for text classification?},
author={Sun, Chi and Qiu, Xipeng and Xu, Yige and Huang, Xuanjing},
booktitle={China National Conference on Chinese Computational Linguistics},
pages={194--206},
year={2019},
organization={Springer}
}
```