微调transformers BERT文本分类预训练模型
本节我们为Hugging Face的transformers包中提供的文本分类预训练模型做一个Fine-tune,即微调,令其更适应于我们使用的数据集,并做一个分类任务。
主要参考了transformers
官方文档中的fine-tune一节,使用的数据集是Hugging Face提供的封装好的数据集。
本文使用PyTorch
深度学习框架,在GTX 1080ti
显卡上运行。
数据集载入与tokenize
首先使用同样是Hugging Face提供的datasets包载入数据集IMDB,这是一个影评数据集:
from datasets import load_dataset
raw_datasets = load_dataset("imdb")
该数据集已经划分好了train
,test
, unsupervised
集,直接使用即可。
然后使用预训练模型进行分词与向量化tokenize:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
这里的tokenizer
可以把文本映射为512维的向量。然后把raw_datasets
中的所有文本进行一下映射:
def tokenize_function(examples):
return tokenizer(examples["text"], padding="max_length", truncation=True)
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
Fine-tune
先从数据集中筛选一个较小的部分:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
full_train_dataset = tokenized_datasets["train"]
full_eval_dataset = tokenized_datasets["test"]
然后定义我们的文本分类模型:
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
然后就可以进行微调啦:
from transformers import TrainingArguments
from transformers import Trainer
training_args = TrainingArguments("test_trainer")
trainer = Trainer(
model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset
)
trainer.train()
训练完毕后是验证结果:
import numpy as np
from datasets import load_metric
metric = load_metric("accuracy")
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return metric.compute(predictions=predictions, references=labels)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=small_train_dataset,
eval_dataset=small_eval_dataset,
compute_metrics=compute_metrics,
)
trainer.evaluate()
我在本地跑下来的结果是:
'eval_loss': 0.5007675886154175,
'eval_accuracy': 0.887,
'eval_runtime': 18.8074,
'eval_samples_per_second': 53.171,
'eval_steps_per_second': 6.646
分类准确率为88.7%.
真的很棒呢,不用花费大量时间训练就可以得到这样的效果。后面进一步研究,应该还会有提高😆😆😆
一起加油吧,小仙女~🤜🤛