BERT文本分类任务Python代码
时间: 2023-06-10 08:03:41 浏览: 146
好的,请见下面的代码:
``` python
import torch
import torch.nn as nn
from transformers import BertModel
class BertClassifier(nn.Module):
def __init__(self, freeze_bert=False):
super(BertClassifier, self).__init__()
# Loading pre-trained BERT model
self.bert = BertModel.from_pretrained('bert-base-uncased')
# Freeze bert layers
if freeze_bert:
for param in self.bert.parameters():
param.requires_grad = False
# Classification layer
self.fc = nn.Linear(768, 2)
def forward(self, input_ids, attention_mask):
# Feed input to BERT
outputs = self.bert(input_ids=input_ids,
attention_mask=attention_mask)
# Extract last hidden state of the token `[CLS]` for classification task
last_hidden_state_cls = outputs[0][:, 0, :]
logits = self.fc(last_hidden_state_cls)
return logits
```
这是一个使用PyTorch和Transformers库中的BERT模型进行文本分类的代码。在构建模型时,我们首先初始化一个BertClassifier类,其构造函数中包含一个BERT模型、一个是否冻结BERT层的参数和一个全连接层(用于分类)。
在前向传递中,我们将输入传递给BERT,然后提取经过BERT处理的输入中的 `[CLS]` 标志处的最后一个隐藏状态。我们使用这个隐藏状态来进行文本分类。输出是一个包含两个元素的张量,用于二元分类任务。
阅读全文