使用Huggingface中预训练的BERT模型进行文本分类。
本文使用的是RoBERTa-wwm-ext,模型导入方式参见https://github.com/ymcui/Chinese-BERT-wwm。由于做了全词遮罩(Whole Word Masking),效果相较于裸的BERT会有所提升。
数据集使用THUCNews中的train.txt:https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch/tree/master/THUCNews/data,十分类问题,示例如下:
|
|
代码如下:
|
|
在2080Ti上,train一个epoch差不多三分钟,train一个epoch后,准确率已经有94%以上了。
_, pooled = self.bert(context, token_type_ids=types, attention_mask=mask)
这行代码中有几个需要注意的点:
context
形如:[101, …, 102, 0, 0, 0, …, 0]token_type_ids
形如:[0, 0, 0, …, 1, 1, 1, …, 1]attention_mask
形如:[1, 1, 1, …, 0, 0, 0, …, 0]- 函数返回的两个结果size分别为[batch_size, max_seq_len, hidden_size=768]和[batch_size, hiddensize=768],前者是最后一层所有的hidden向量,后者是CLS的hidden向量经过一层dense和activation后得到的,所以特别注意:[:, 0, :]和pooled[:, :]是不一样的。这部分源码如下:
|
|
References: