用bert二分类模型实现情感四分类的代码
时间: 2024-04-29 15:21:32 浏览: 5
这里给出一个基于BERT的情感四分类模型的代码示例:
```python
import torch
import torch.nn as nn
import transformers
class SentimentClassifier(nn.Module):
def __init__(self, pretrained_model_name='bert-base-uncased', num_classes=4):
super().__init__()
self.bert = transformers.BertModel.from_pretrained(pretrained_model_name)
self.dropout = nn.Dropout(p=0.2)
self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits
```
上述代码中,我们首先定义了一个`SentimentClassifier`类,继承自`nn.Module`类。在类的构造函数中,我们首先调用`transformers.BertModel.from_pretrained()`方法加载预训练的BERT模型。
在`forward()`方法中,我们将输入的`input_ids`和`attention_mask`传入BERT模型,得到模型的输出结果。然后,我们通过一个线性层将BERT模型的输出映射到情感分类的标签空间,即四个类别。最后,我们将分类器的输出结果`logits`返回。
接下来,我们可以使用该模型进行训练和预测。在训练过程中,我们需要定义损失函数和优化器,并迭代若干次更新模型参数。在预测过程中,我们可以将模型的输出结果通过softmax函数转换为概率分布,然后选择概率最大的类别作为预测结果。