Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/liyibo/text-classification-demos
Neural models for Text Classification in Tensorflow, such as cnn, dpcnn, fasttext, bert ...
https://github.com/liyibo/text-classification-demos
bert cnn fasttext tensorflow text-classification
Last synced: about 1 month ago
JSON representation
Neural models for Text Classification in Tensorflow, such as cnn, dpcnn, fasttext, bert ...
- Host: GitHub
- URL: https://github.com/liyibo/text-classification-demos
- Owner: liyibo
- Created: 2019-01-03T08:07:43.000Z (almost 6 years ago)
- Default Branch: master
- Last Pushed: 2019-03-25T05:15:47.000Z (over 5 years ago)
- Last Synced: 2024-08-02T08:09:55.001Z (4 months ago)
- Topics: bert, cnn, fasttext, tensorflow, text-classification
- Language: Python
- Homepage:
- Size: 928 KB
- Stars: 191
- Watchers: 7
- Forks: 44
- Open Issues: 5
-
Metadata Files:
- Readme: readme.md
Awesome Lists containing this project
- awesome-bert - liyibo/text-classification-demos
README
# Text classification demos
Tensorflow 环境下,不同的神经网络模型对中文文本进行分类,本文中的 demo 都是字符级别的文本分类(增加了word-based 的统计结果),简化了文本分类的流程,字符级别的分类在有些任务上的效果可能不好,需要结合实际情况添加自定义的分词模块。
## 数据集
下载地址: https://pan.baidu.com/s/1hugrfRu 密码: qfud
使用 THUCNews 的一个子集进行训练与测试,使用了其中的 10 个分类,每个分类 6500 条数据。
类别如下:
体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐
数据集划分如下:
训练集: 5000 \* 10
验证集: 500 \* 10
测试集: 1000 \* 10具体介绍请参考:[text-classification-cnn-rnn](https://github.com/gaussic/text-classification-cnn-rnn)
## 分类效果
- char-based
| model |fasttext | cnn | rnn | rcnn | han | dpcnn | bert |
|:----- | :-----: | :-----: | :-----: | :-----: | :----- | :-----: | :-----: |
| val_acc | 92.92 | 93.56 | 93.56 | 94.36 | 93.94 | 93.70 | 97.84 |
| test_acc | 93.15 | 94.57 | 94.37 | 95.53 | 93.65 | 94.87 | 96.93 |- word-based
| model |fasttext | cnn | rnn | rcnn | han | dpcnn | bert |
|:----- | :-----: | :-----: | :-----: | :-----: | :----- | :-----: | :-----: |
| val_acc | 95.52 | 95.28 | 93.10 | 95.60 | 95.10 | 95.68 | - |
| test_acc | 95.34 | 95.77 | 94.05 | 96.36 | 95.66 | 95.97 | - |## 模型介绍
### 1、FastText
fasttext_model.py 文件为训练和测试 fasttext 模型的代码
![图1 FastText 模型结构图](images/fasttext.jpg?raw=true)
本代码简化了 fasttext 模型的结构,模型结构非常简单,运行速度简直飞快,模型准确率也不错,可根据实际需要优化模型结构
### 2、TextCNN
cnn_model.py 文件为训练和测试 TextCNN 模型的代码
![图2 TextCNN 模型结构图](images/textcnn.jpg?raw=true)
本代码实现了 TextCNN 模型的结构,通过 3 个不同大小的卷积核,对输入文本进一维卷积,分别 pooling 三个卷积之后的 feature, 拼接到一起,然后进行 dense 操作,最终输出模型结果。可实现速度和精度之间较好的折中。
### 3、RNN
rnn_model.py 文件为训练和测试 TextCNN 模型的代码
![图8 TextRNN 模型结构图](images/textrnn.jpg?raw=true)
本代码实现了 TextRNN 模型的结构,对输入序列进行embedding,然后输入两层的 rnn_cell中学习序列特征,取最后一个 word 的 state 作为进行后续的 fc 操作,最终输出模型结果。
### 4、RCNN
rcnn_model.py 文件为训练和测试 RCNN 模型的代码
![图3 RCNN 模型结构图](images/rcnn.jpg?raw=true)
[Recurrent Convolutional Neural Network for Text Classification](https://scholar.google.com.hk/scholar?q=Recurrent+Convolutional+Neural+Networks+for+Text+Classification&hl=zhCN&as_sdt=0&as_vis=1&oi=scholart&sa=X&ved=0ahUKEwjpx82cvqTUAhWHspQKHUbDBDYQgQMIITAA), 在学习 word representations 时候,同时采用了 rnn 结构来学习 word 的上下文,虽然模型名称为 RCNN,但并没有显式的存在卷积操作。
1、采用双向lstm学习 word 的上下文
```
c_left = tf.concat([tf.zeros(shape), output_fw[:, :-1]], axis=1, name="context_left")
c_right = tf.concat([output_bw[:, 1:], tf.zeros(shape)], axis=1, name="context_right")
word_representation = tf.concat([c_left, embedding_inputs, c_right], axis=2, name="last")
```
2、pooling + softmaxword_representation 的维度是 batch_size \* seq_length \* 2 \* context_dim + embedding_dim
在 seq_length 维度进行 max pooling,然后进行 fc 操作就可以进行分类了,可以将该网络看成是 fasttext 的改进版本
### 5、HAN
han_model.py 文件为训练和测试 HAN 模型的代码
![图4 HAN 模型结构图](images/han.jpg?raw=true)
HAN 为 Hierarchical Attention Networks,将待分类文本,分为一定数量的句子,分别在 word level 和 sentence level 进行 encoder 和 attention 操作,从而实现对较长文本的分类。
本文是按照句子长度将文本分句的,实际操作中可按照标点符号等进行分句,理论上效果能好一点。
- 1、对文本进行分句
对每个句子进行双向lstm编码
batch_size = 64, seq_length = 600,
sent_num = 10, emb_size = 128,
lstm_hid_dim = 256数据维度变化:64 \* 600 \* 128 --- (64\*10) \* 60 \* 128 --- (64\*10) \* 60 \* 512
- 2、word level attention
![图4 attention](images/han_2.jpg?raw=true)
(1) 将输入的lstm编码结果做一次非线性变换,可以看做是输入编码的hidden representation, shape = (64\*10) \* 60 \* 256
(2) 将 hidden representation 与一个学习得到的 word level context vector 的相似性进行 softmax,得到每个单词在句子中的权重
(3) 对输入的lstm 编码进行加权求和,得到句子的向量表示
数据维度变化:(64\*10) \* 60 \* 512 --- (64\*10) \* 512
- 3、得到每个句子的向量表示
- 4、sentence level attention
与 word level attention 过程一样,只是该层是句子级别的attention
数据维度变化:64 \* 10 \* 512 --- 64 \* 512
- 5、得到 document 的向量表示
- 6、dence + softmax
### 6、DPCNN
dpcnn_model.py 文件为训练和测试 DPCNN 模型的代码
![图5 DPCNN 模型结构图](images/dpcnn.jpg?raw=true)
DPCNN 通过卷积和残差连接增加了以往用于文本分类 CNN 网络的深度,可以有效提取文本中的远程关系特征,并且复杂度不高,实验表名,效果比以往的 CNN 结构要好一点。
- region_embedding: word_embedding 之后进行的 ngram 卷积结果
### 7、BERT
bert_model.py 文件为训练和测试 BERT 模型的代码
google官方提供用于文本分类的demo写的比较抽象,所以本文基于 google 提供的代码和初始化模型,重写了文本分类模型的训练和测试代码,bert 分类模型在小数据集下效果很好,通过较少的迭代次数就能得到很好的效果,但是训练和测试速度较慢,这点不如基于 CNN 的网络结构。
bert_model.py 将训练数据和验证数据存储为 tfrecord 文件,然后进行训练
由于 bert 提供的预训练模型较大,需要自己去 [google-research/bert](https://github.com/google-research/bert) 中下载预训练好的模型,本实验采用的是 "BERT-Base, Chinese" 模型。
![图6 BERT 输入数据格式](images/bert_1.jpeg?raw=true)
![图7 BERT 下游任务介绍](images/bert_2.jpeg?raw=true)
## 参考
- 1 [text-classification-cnn-rnn](https://github.com/gaussic/text-classification-cnn-rnn)
- 2 [text_classification](https://github.com/brightmart/text_classification)
- 3 [bert](https://github.com/google-research/bert)