用huggingface trainer在yelp review数据集上通过Bert模型进行微调的代码
时间: 2024-05-02 12:19:36 浏览: 159
以下是通过 Hugging Face Trainer 在 Yelp 数据集上微调 BERT 模型的代码:
``` python
from transformers import BertTokenizerFast, BertForSequenceClassification
from transformers import Trainer, TrainingArguments
import torch
from datasets import load_dataset
# 加载 Yelp 数据集
dataset = load_dataset('yelp_polarity', split='train')
# 初始化 BERT tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
# 定义数据预处理函数
def preprocess_function(examples):
return tokenizer(examples['text'], truncation=True, padding=True)
# 对数据集进行预处理
train_dataset = dataset.map(preprocess_function, batched=True)
# 初始化 BERT 分类模型
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# 定义训练参数
training_args = TrainingArguments(
output_dir='./results', # 模型输出文件夹
num_train_epochs=5, # 总训练轮数
per_device_train_batch_size=16, # 每个 GPU 训练时的 batch size
save_total_limit=1, # 最多保存的 checkpoint 数量
learning_rate=2e-5, # AdamW 优化器的初始学习率
)
# 定义 Trainer 对象
trainer = Trainer(
model=model, # 模型
args=training_args, # 训练参数
train_dataset=train_dataset, # 训练数据集
)
# 进行微调
trainer.train()
```
希望这个代码能够帮到你!
阅读全文