使用bert做文本分类的例子
时间: 2023-09-01 19:06:35 浏览: 193
下面是一个使用BERT进行文本分类的简单示例:
```python
import torch
from transformers import BertTokenizer, BertForSequenceClassification
# 加载预训练的BERT模型和tokenizer
model_name = 'bert-base-uncased' # 或者其他可用的预训练模型
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2) # 假设二分类任务
# 输入文本
text = "这是一段待分类的文本"
# 对文本进行编码
encoded_input = tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=128,
padding='max_length',
truncation=True,
return_tensors='pt'
)
# 将编码后的输入传入模型进行推理
outputs = model(**encoded_input)
# 获取模型的输出结果
logits = outputs.logits # 输出的logits是一个包含两个值的张量,对应于两个类别的分数
predicted_labels = torch.argmax(logits, dim=1) # 预测的类别标签
# 输出预测结果
labels = ['类别A', '类别B'] # 类别标签列表
predicted_label = labels[predicted_labels.item()]
print("预测的类别为:", predicted_label)
```
在这个例子中,我们首先使用`BertTokenizer`加载预训练的BERT模型的tokenizer。然后,我们使用`BertForSequenceClassification`加载预训练的BERT模型,并指定了分类任务的类别数。接下来,我们将待分类的文本进行编码,使用`tokenizer.encode_plus`方法对文本进行编码,将其转换为模型可接受的输入格式。
然后,我们将编码后的输入传入BERT模型进行推理,得到模型的输出结果。输出结果中的logits是一个包含两个值的张量,对应于两个类别的分数。我们可以使用`torch.argmax`方法获取预测的类别标签。最后,我们根据类别标签列表,输出预测的类别结果。
请注意,这只是一个简单的示例,实际应用中可能需要根据具体任务进行适当的调整和扩展。
阅读全文