微调transformers BERT文本分类预训练模型


微调transformers BERT文本分类预训练模型

本节我们为Hugging Face的transformers包中提供的文本分类预训练模型做一个Fine-tune,即微调,令其更适应于我们使用的数据集,并做一个分类任务。

主要参考了transformers官方文档中的fine-tune一节,使用的数据集是Hugging Face提供的封装好的数据集。

本文使用PyTorch深度学习框架,在GTX 1080ti显卡上运行。

数据集载入与tokenize

首先使用同样是Hugging Face提供的datasets包载入数据集IMDB,这是一个影评数据集:

from datasets import load_dataset

raw_datasets = load_dataset("imdb")

该数据集已经划分好了train,test, unsupervised集,直接使用即可。

然后使用预训练模型进行分词与向量化tokenize:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

这里的tokenizer可以把文本映射为512维的向量。然后把raw_datasets中的所有文本进行一下映射:

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

Fine-tune

先从数据集中筛选一个较小的部分:

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
full_train_dataset = tokenized_datasets["train"]
full_eval_dataset = tokenized_datasets["test"]

然后定义我们的文本分类模型:

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

然后就可以进行微调啦:

from transformers import TrainingArguments
from transformers import Trainer

training_args = TrainingArguments("test_trainer")
trainer = Trainer(
    model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset
)

trainer.train()

训练完毕后是验证结果:

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)
trainer.evaluate()

我在本地跑下来的结果是:

'eval_loss': 0.5007675886154175,
 'eval_accuracy': 0.887,
 'eval_runtime': 18.8074,
 'eval_samples_per_second': 53.171,
 'eval_steps_per_second': 6.646

分类准确率为88.7%.

真的很棒呢,不用花费大量时间训练就可以得到这样的效果。后面进一步研究,应该还会有提高😆😆😆

一起加油吧,小仙女~🤜🤛


评论
  目录