Give me some code for fine-tuning a pre-trained BERT model for text classification
Sure, here is an example of fine-tuning a pre-trained BERT model for text classification, implemented in PyTorch:
```python
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertForSequenceClassification

# Load the pre-trained BERT model and tokenizer.
# num_labels=2 attaches a classification head with two output classes.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Load the data: each loader should yield (inputs, labels) batches, where
# inputs is a dict containing 'input_ids' and 'attention_mask' tensors
train_loader = ...
val_loader = ...

# Define the optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()
num_epochs = 3

for epoch in range(num_epochs):
    # Training
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    with torch.no_grad():
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        for inputs, labels in val_loader:
            outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
            total_loss += criterion(outputs.logits, labels).item()
            total_correct += (outputs.logits.argmax(-1) == labels).sum().item()
            total_samples += labels.size(0)
        val_loss = total_loss / len(val_loader)   # average loss per batch
        val_acc = total_correct / total_samples   # accuracy over all samples
    print(f"Epoch {epoch}: Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
```
In this code, we first load the pre-trained BERT tokenizer and a BertForSequenceClassification model with a two-class classification head. We then set up the training and validation data loaders, define the optimizer and loss function, and run the training loop. After each epoch, the model is evaluated on the validation set, and the validation loss and accuracy are printed.
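The two data loaders above are left as placeholders. As a minimal sketch of how they might be built (the TextClassificationDataset class, the toy texts/labels, and max_length=128 are illustrative assumptions, not part of the original answer), one option is to tokenize each example inside a small Dataset wrapper and let DataLoader batch the results:
```python
import torch
from torch.utils.data import Dataset, DataLoader

class TextClassificationDataset(Dataset):
    """Tokenizes raw texts on access and pairs them with integer labels."""
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Encode one text into fixed-length input_ids / attention_mask
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        inputs = {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
        }
        return inputs, torch.tensor(self.labels[idx])

# Toy data for illustration only -- replace with your own corpus
train_texts = ["great movie", "terrible plot"]
train_labels = [1, 0]
train_loader = DataLoader(
    TextClassificationDataset(train_texts, train_labels, tokenizer),
    batch_size=16, shuffle=True,
)
```
The default collate function stacks the per-example dicts into dicts of batched tensors, which matches the (inputs, labels) shape the training loop expects.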
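After training, the fine-tuned model can be used to classify new text. A brief sketch, assuming the model and tokenizer from above (the example sentence is made up):
```python
# Classify a single new sentence with the fine-tuned model
model.eval()
text = "this was a surprisingly good film"  # hypothetical input
encoding = tokenizer(text, return_tensors='pt', truncation=True)
with torch.no_grad():
    logits = model(**encoding).logits
print(f"Predicted class: {logits.argmax(-1).item()}")
```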