请给出使用bert模型进行文本分类的代码
时间: 2023-05-30 07:06:29 浏览: 127
这里提供一个使用bert模型进行文本分类的代码示例,该代码使用了pytorch和transformers库:
```python
import torch
import torch.nn as nn
from transformers import BertModel
class BertClassifier(nn.Module):
def __init__(self, num_classes):
super(BertClassifier, self).__init__()
self.bert = BertModel.from_pretrained('bert-base-uncased')
self.dropout = nn.Dropout(p=0.1)
self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)
self.softmax = nn.Softmax(dim=1)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.fc(pooled_output)
probs = self.softmax(logits)
return probs
```
使用上述代码可以实例化一个BERT分类器模型,其中`num_classes`为分类器需要预测的类别数目。在模型训练时,可以使用该模型对输入文本进行编码,然后将编码结果送入全连接层,最后输出每个类别的概率分布。
阅读全文