An open API service indexing awesome lists of open source software.

https://github.com/ssbuild/bert_pretty

bert text encode or decode
https://github.com/ssbuild/bert_pretty

bert-text

Last synced: about 1 year ago
JSON representation

bert text encode or decode

Awesome Lists containing this project

README

          

### demo

```python

import numpy as np
#FullTokenizer is official and you can use your tokenization .
from bert_pretty import FullTokenizer,\
text_feature, \
text_feature_char_level,\
text_feature_word_level,\
text_feature_char_level_input_ids_mask, \
text_feature_word_level_input_ids_mask, \
text_feature_char_level_input_ids_segment, \
text_feature_word_level_input_ids_segment, \
seqs_padding,rematch

from bert_pretty.ner import load_label_bioes,load_label_bio,load_labels as ner_load_labels
from bert_pretty.ner import ner_crf_decoding,\
ner_pointer_decoding,\
ner_pointer_decoding_with_mapping,\
ner_pointer_double_decoding,ner_pointer_double_decoding_with_mapping

from bert_pretty.cls import cls_softmax_decoding,cls_sigmoid_decoding,load_labels as cls_load_labels

tokenizer = FullTokenizer(vocab_file=r'F:\pretrain\chinese_L-12_H-768_A-12\vocab.txt',do_lower_case=True)
text_list = ["你是谁123aa\ta嘂a","嘂adasd"]

def test():
do_lower_case = tokenizer.basic_tokenizer.do_lower_case
inputs = [tokenizer.tokenize(text) for text in text_list]

mapping = [rematch(text,tokens, do_lower_case) for text,tokens in zip(text_list,inputs)]

inputs = [ tokenizer.convert_tokens_to_ids(input) for input in inputs]

input_mask = [[1] * len(input) for input in inputs]
input_segment = [[0] * len(input) for input in inputs]
input_ids = seqs_padding(inputs)
input_mask = seqs_padding(input_mask)
input_segment = seqs_padding(input_segment)

print('input_ids\n', input_ids)
print('mapping\n',mapping)
print('input_mask\n',input_mask)
print('input_segment\n',input_segment)
print('\n\n')

def test_charlevel():
do_lower_case = tokenizer.basic_tokenizer.do_lower_case
vocab = tokenizer.vocab
unk_id = vocab.get('[UNK]')

inputs = []
for text in text_list:
if do_lower_case:
inputs.append([vocab.get(char,unk_id) for char in text])
else:
inputs.append([vocab.get(char.lower(), unk_id) for char in text])

inputs = [ tokenizer.convert_tokens_to_ids(input) for input in inputs]

input_mask = [[1] * len(input) for input in inputs]
input_segment = [[0] * len(input) for input in inputs]
input_ids = seqs_padding(inputs)
input_mask = seqs_padding(input_mask)
input_segment = seqs_padding(input_segment)

print('input_ids\n', input_ids)
print('input_mask\n',input_mask)
print('input_segment\n',input_segment)
print('\n\n')

# labels = ['标签1','标签2']
# print(cls.load_labels(labels))
#
# print(ner.load_label_bio(labels))

'''
# def ner_crf_decoding(batch_text, id2label, batch_logits, trans=None,batch_mapping=None,with_dict=True):
ner crf decode 解析crf序列 or 解析 已经解析过的crf序列

batch_text input_instance list ,
id2label 标签 list or dict
batch_logits 为bert 预测结果 logits_all (batch,seq_len,num_tags) or (batch,seq_len)
trans 是否启用trans预测 , 2D
batch_mapping 映射序列
'''

'''
def ner_pointer_decoding(batch_text, id2label, batch_logits, threshold=1e-8,coordinates_minus=False,with_dict=True)

batch_text text list ,
id2label 标签 list or dict
batch_logits (batch,num_labels,seq_len,seq_len)
threshold 阈值
coordinates_minus
'''

'''
def ner_pointer_decoding_with_mapping(batch_text, id2label, batch_logits, batch_mapping,threshold=1e-8,coordinates_minus=False,with_dict=True)

batch_text text list ,
id2label 标签 list or dict
batch_logits (batch,num_labels,seq_len,seq_len)
threshold 阈值
coordinates_minus
'''

'''
cls_softmax_decoding(batch_text, id2label, batch_logits,threshold=None)
batch_text 文本list ,
id2label 标签 list or dict
batch_logits (batch,num_classes)
threshold 阈值
'''

'''
cls_sigmoid_decoding(batch_text, id2label, batch_logits,threshold=0.5)

batch_text 文本list ,
id2label 标签 list or dict
batch_logits (batch,num_classes)
threshold 阈值
'''

def test_cls_decode():
num_label =3
np.random.seed(123)
batch_logits = np.random.rand(2,num_label)
result = cls_softmax_decoding(text_list,['标签1','标签2','标签3'],batch_logits,threshold=None)
print(result)

batch_logits = np.random.rand(2,num_label)
print(batch_logits)
result = cls_sigmoid_decoding(text_list,['标签1','标签2','标签3'],batch_logits,threshold=0.5)
print(result)

if __name__ == '__main__':
test()
test_charlevel()

test_cls_decode()

```