Bert实现文本分类的代码
时间: 2023-05-30 10:04:49 浏览: 226
以下是使用Bert进行文本分类的示例代码:
```python
import torch
from transformers import BertTokenizer, BertModel
# 载入Bert模型和tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
# 定义分类器模型
class BertClassifier(torch.nn.Module):
def __init__(self, bert_model, num_classes):
super().__init__()
self.bert_model = bert_model
self.dropout = torch.nn.Dropout(0.1)
self.classifier = torch.nn.Linear(bert_model.config.hidden_size, num_classes)
def forward(self, input_ids, attention_mask):
bert_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = bert_output.pooler_output
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits
# 实例化分类器模型
num_classes = 2 # 分类数
bert_classifier = BertClassifier(bert_model, num_classes)
# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(bert_classifier.parameters(), lr=2e-5)
# 训练模型
for epoch in range(num_epochs):
for input_ids, attention_mask, labels in train_loader:
# 将数据移动到GPU上
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)
# 计算模型输出
logits = bert_classifier(input_ids, attention_mask)
# 计算损失
loss = criterion(logits, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 在验证集上评估模型
with torch.no_grad():
correct = 0
total = 0
for input_ids, attention_mask, labels in val_loader:
# 将数据移动到GPU上
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)
# 计算模型输出
logits = bert_classifier(input_ids, attention_mask)
# 计算预测结果
predictions = torch.argmax(logits, dim=1)
# 更新正确预测的数量
total += labels.size(0)
correct += (predictions == labels).sum().item()
accuracy = correct / total
print(f'Epoch {epoch + 1}, Validation Accuracy: {accuracy:.3f}')
```
在这个示例中,我们使用了Bert模型进行文本分类。我们首先通过`BertTokenizer`载入预训练的Bert tokenizer,然后通过`BertModel`载入预训练的Bert模型。接着,我们定义了一个简单的分类器模型`BertClassifier`,该模型使用Bert模型的最后一个隐藏层的输出来进行分类。我们使用交叉熵损失函数和Adam优化器进行训练。在每个epoch结束时,我们使用验证集评估模型的性能。
阅读全文