Training code for the BERT+CRF model written earlier
Sure, here is the training code for the BERT+CRF model:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForTokenClassification, get_linear_schedule_with_warmup
from seqeval.metrics import f1_score
from dataset import MyDataset
from utils import get_tags, get_entities, get_chunks, get_labels
# Fix the random seed so results are reproducible
seed = 2022
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Training hyperparameters
epochs = 10
batch_size = 32
learning_rate = 2e-5
warmup_ratio = 0.1
max_grad_norm = 1.0
log_interval = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the datasets (build the tokenizer once and share it)
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
train_dataset = MyDataset('train.txt', tokenizer)
dev_dataset = MyDataset('dev.txt', tokenizer)
# Build the model with the correct number of output labels
model = BertForTokenClassification.from_pretrained('bert-base-chinese', num_labels=len(get_labels()))
model.to(device)
# Optimizer and learning-rate schedule with linear warmup and decay
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_dataset) * epochs // batch_size
warmup_steps = int(warmup_ratio * total_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
# Train the model
model.zero_grad()
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_dataloader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False)
for epoch in range(1, epochs+1):
    model.train()
    for batch, (inputs, labels) in enumerate(train_dataloader, 1):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Passing labels lets BertForTokenClassification compute the token-level loss
        outputs = model(input_ids=inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()
        model.zero_grad()
        if batch % log_interval == 0:
            print(f'Epoch: {epoch}, batch: {batch}, loss: {loss.item():.4f}')
    # Evaluate on the dev set after each epoch
    model.eval()
    preds = []
    targets = []
    for inputs, labels in dev_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            logits = model(input_ids=inputs).logits
        # Take the highest-scoring label per token, then convert label ids to tag sequences
        pred_ids = logits.argmax(dim=-1)
        preds.extend(get_chunks(pred_ids))
        targets.extend(get_chunks(labels))
    f1 = f1_score(targets, preds)
    print(f'Epoch {epoch}: f1 score = {f1:.4f}')
# Save the trained model
torch.save(model.state_dict(), 'model.bin')
```
Here, `MyDataset` is a custom dataset class that you can adapt to your own data format. `get_tags` returns all tags, `get_entities` returns all entity types, `get_chunks` converts the label sequences output by the model into entity/tag sequences for evaluation, and `get_labels` returns the complete label set. These helpers live in `utils.py` and can be implemented yourself; a minimal sketch follows below.
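Since `dataset.py` and `utils.py` are not shown, here is one possible sketch of `MyDataset`, `get_labels`, and `get_chunks`, assuming CoNLL-style files (one `token tag` pair per line, blank lines between sentences). The label set, the `max_len` value, and padding label positions with `O` are illustrative assumptions, not part of the original code.
```python
import torch
from torch.utils.data import Dataset

# Hypothetical BIO label set -- replace with the labels used in your own data
LABELS = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG']

def get_labels():
    """Return the full list of tag labels."""
    return LABELS

def get_chunks(label_ids):
    """Map a batch of label-id sequences (tensor of shape [batch, seq_len])
    to lists of tag strings, as expected by seqeval."""
    id2label = dict(enumerate(LABELS))
    return [[id2label[i] for i in row] for row in label_ids.tolist()]

class MyDataset(Dataset):
    """Reads a CoNLL-style file: one 'token tag' pair per line, blank line between sentences."""
    def __init__(self, path, tokenizer, max_len=128):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label2id = {label: i for i, label in enumerate(get_labels())}
        self.sentences = []
        tokens, tags = [], []
        with open(path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    if tokens:
                        self.sentences.append((tokens, tags))
                        tokens, tags = [], []
                    continue
                token, tag = line.split()
                tokens.append(token)
                tags.append(tag)
        if tokens:
            self.sentences.append((tokens, tags))

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        tokens, tags = self.sentences[idx]
        # With bert-base-chinese each character is typically one token, so ids and tags stay aligned
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)[:self.max_len]
        label_ids = [self.label2id[t] for t in tags][:self.max_len]
        # Pad to a fixed length so the default DataLoader collate can stack the tensors;
        # padding positions are labelled 'O' here for simplicity (an attention mask would normally exclude them)
        pad_len = self.max_len - len(input_ids)
        input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
        label_ids = label_ids + [self.label2id['O']] * pad_len
        return torch.tensor(input_ids), torch.tensor(label_ids)
```
With this sketch, each batch yields `(input_ids, label_ids)` tensors of shape `[batch_size, max_len]`, which matches how the training loop above unpacks them and moves them to the device.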