BERT Chinese Text Classification Fine-Tuning Code Example
BERT (Bidirectional Encoder Representations from Transformers) is a pretrained language model widely used across natural language processing tasks, including text classification. For Chinese text classification, fine-tuning BERT typically involves loading a pretrained Chinese checkpoint and its tokenizer, encoding the texts, defining an optimizer, running a training loop, and then running inference on the test set:
```python
# Import the required libraries
from transformers import BertTokenizer, BertForSequenceClassification
import torch

# Load the pretrained Chinese BERT model and tokenizer
model_name = "bert-base-chinese"  # Chinese pretrained checkpoint
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)  # set num_labels to your number of classes

# Assume we already have a dataset of raw texts and their class labels
train_inputs = ...  # list of raw text strings (encoded by the tokenizer below)
train_labels = ...  # list of integer class labels

# Encode the data into the format the model expects
inputs = tokenizer(train_inputs, padding=True, truncation=True, max_length=512, return_tensors="pt")
labels = torch.tensor(train_labels)

# Define the optimizer; the model computes the cross-entropy loss internally when labels are passed
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Fine-tuning loop
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    outputs = model(**inputs, labels=labels)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# At test time, run a forward pass only and take the argmax of the logits as the prediction
model.eval()
test_inputs = ...  # list of raw test text strings
test_encoded = tokenizer(test_inputs, padding=True, truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    predictions = torch.argmax(model(**test_encoded).logits, dim=1)

# Display or save the results
print(predictions)
```
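The loop above pushes the entire training set through the model as a single batch, which only works for toy-sized data. In practice the encoded tensors are usually wrapped in a DataLoader and processed in shuffled mini-batches. Below is a minimal sketch of that variant; it reuses the `inputs`, `labels`, `model`, and `optimizer` defined above, and the batch size and epoch count are illustrative values rather than part of the original example.
```python
from torch.utils.data import TensorDataset, DataLoader

# Wrap the encoded tensors in a dataset and draw shuffled mini-batches from it
dataset = TensorDataset(inputs["input_ids"], inputs["attention_mask"], labels)
loader = DataLoader(dataset, batch_size=16, shuffle=True)  # batch size is illustrative

# Move the model to GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(3):  # epoch count is illustrative
    for input_ids, attention_mask, batch_labels in loader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        batch_labels = batch_labels.to(device)

        # Forward pass; the model returns the cross-entropy loss when labels are given
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=batch_labels)
        outputs.loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```
The same pattern applies at evaluation time: iterate over a test DataLoader, collect the argmax of the logits for each batch, and compare against the gold labels to compute accuracy.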