Give me a code example for the BERT fine-tuning stage
Below is a code example for fine-tuning BERT with Hugging Face's Transformers library:
```python
import torch
from torch.optim import AdamW  # transformers' own AdamW is deprecated; use the PyTorch implementation
from torch.utils.data import DataLoader, RandomSampler
from transformers import BertTokenizer, BertForSequenceClassification

# Load the pre-trained BERT model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Prepare the dataset (placeholders; substitute your own texts and labels)
train_texts = ['Example sentence 1', 'Example sentence 2', ...]
train_labels = [0, 1, ...]

# Tokenize and encode the texts
train_encodings = tokenizer(train_texts, truncation=True, padding=True)

# Wrap the encodings in a Dataset object
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = CustomDataset(train_encodings, train_labels)

# Training hyperparameters
batch_size = 16
epochs = 3
learning_rate = 2e-5

# Create the data loader
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)

# Define the optimizer; no separate loss function is needed, because
# BertForSequenceClassification computes the cross-entropy loss internally
# whenever `labels` are passed to the forward call
optimizer = AdamW(model.parameters(), lr=learning_rate)

# Start fine-tuning
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.train()

for epoch in range(epochs):
    total_loss = 0
    for batch in train_dataloader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    average_loss = total_loss / len(train_dataloader)
    print(f'Epoch {epoch + 1}: Average Loss = {average_loss:.4f}')

# Save the fine-tuned model and tokenizer
model.save_pretrained('path/to/save/model')
tokenizer.save_pretrained('path/to/save/tokenizer')
```
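Once saved, the model can be reloaded with `from_pretrained` for inference. A minimal sketch (the `path/to/save/...` directories are just the placeholder paths used above):

```python
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Reload the fine-tuned model and tokenizer from the placeholder save paths
tokenizer = BertTokenizer.from_pretrained('path/to/save/tokenizer')
model = BertForSequenceClassification.from_pretrained('path/to/save/model')
model.eval()

# Classify a single sentence
inputs = tokenizer('Example sentence', return_tensors='pt')
with torch.no_grad():
    logits = model(**inputs).logits
predicted_label = logits.argmax(dim=-1).item()
print(predicted_label)
```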
Please note that this code is only an example; you will likely need to adapt it to your own dataset and task. Also, the example covers only the training loop; if you need to evaluate on a validation or test set, add the corresponding code yourself — a minimal sketch is given below.
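For reference, here is a minimal evaluation sketch that continues the training script above. The `val_texts` and `val_labels` variables are hypothetical placeholders for your held-out data:

```python
from torch.utils.data import DataLoader, SequentialSampler

# Hypothetical validation split, encoded the same way as the training data
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
val_dataset = CustomDataset(val_encodings, val_labels)
val_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=batch_size)

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch in val_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        # No labels passed here: we only need the logits for prediction
        logits = model(input_ids, attention_mask=attention_mask).logits
        predictions = logits.argmax(dim=-1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
print(f'Validation accuracy: {correct / total:.4f}')
```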