BERT进行文本分类的代码
时间: 2024-09-12 19:04:12 浏览: 52
bert文本分类 代码+数据
BERT(Bidirectional Encoder Representations from Transformers)是一种预训练语言表示的方法,它通过在大量文本数据上使用掩码语言模型(Masked Language Model)和下一个句子预测(Next Sentence Prediction)任务进行预训练,以获得深度双向的文本表示。随后,BERT可以被微调用于各种自然语言处理任务,包括文本分类。
以下是一个使用BERT进行文本分类的基本代码示例,这里使用的是Hugging Face的Transformers库,这是一个非常流行的NLP库,提供了BERT及其他预训练模型的接口。
```python
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载预训练的BERT分词器和模型
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# 准备数据集
# 假设我们有以下的文本和标签
texts = ["Hello, my dog is cute", "I love coffee"]
labels = [1, 0] # 假设0表示负面情感,1表示正面情感
# 使用分词器编码文本,添加特殊标记('[CLS]'和'[SEP]'),并调整长度为最大长度
inputs = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors='pt')
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
# 将标签转换为张量
labels = torch.tensor(labels)
# 创建数据集和数据加载器
dataset = TensorDataset(input_ids, attention_mask, labels)
train_dataloader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=2)
# 设置优化器和学习率调度器
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
total_steps = len(train_dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
# 训练模型(这里仅为示例,实际训练需要更多的循环和步骤)
model.train()
for epoch in range(num_epochs):
for batch in train_dataloader:
batch_input_ids, batch_attention_mask, batch_labels = tuple(t.to(device) for t in batch)
model.zero_grad()
# 正向传播
outputs = model(batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)
# 计算损失并进行反向传播
loss = outputs.loss
loss.backward()
# 更新参数
optimizer.step()
scheduler.step()
# 评估模型等后续步骤...
```
请记住,这只是一个非常基础的代码示例。在实际应用中,你需要进行数据预处理,定义合适的批处理大小,执行完整的训练循环,并包括验证和测试过程。同时,还需要考虑超参数的调整和模型的保存与加载等问题。
阅读全文