BERT+CRF实现命名实体识别任务的代码
时间: 2023-12-09 16:03:44 浏览: 82
下面是一个使用BERT+CRF实现命名实体识别任务的基本代码，你可以根据自己的需求进行修改和完善：
```python
import torch
import torch.nn as nn
from transformers import BertModel
class BertCRF(nn.Module):
    """BERT encoder + linear emission layer + CRF for sequence labelling (e.g. NER).

    Args:
        num_tags: Number of distinct tags the CRF scores.
        pretrained_model_name: Name or path passed to ``BertModel.from_pretrained``.
            Defaults to ``'bert-base-chinese'``.
        dropout: Dropout probability applied to BERT's token representations.
    """

    def __init__(self, num_tags, pretrained_model_name='bert-base-chinese', dropout=0.1):
        super(BertCRF, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(dropout)
        # Size the projection from the loaded model's config instead of hard-coding
        # 768, so BERT variants with other hidden sizes also work.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_tags)
        self.crf = CRF(num_tags)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the model.

        Args:
            input_ids: Token ids fed to BERT.
            attention_mask: BERT attention mask, also used as the CRF mask.
            labels: Gold tag ids; when omitted the model decodes instead.

        Returns:
            If ``labels`` is None, the result of ``self.crf.decode`` (predicted tag
            sequences). Otherwise the value of ``self.crf(...)``, i.e. the CRF's
            ``partition - gold_score`` loss, which is meant to be minimized.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs[0])
        emissions = self.fc(sequence_output)
        if labels is None:
            return self.crf.decode(emissions, attention_mask)
        # The CRF already returns log Z - score(gold), the negative log-likelihood.
        # Returning it unchanged (not negated) keeps gradient descent minimizing it.
        return self.crf(emissions, labels, mask=attention_mask)
class CRF(nn.Module):
    """Linear-chain conditional random field layer.

    The only learned parameter is ``transitions``, where ``transitions[i, j]`` is
    the score of moving from tag ``i`` to tag ``j``. There are no separate
    start/end transition scores.

    Expected inputs: ``emissions`` of shape (batch, seq_len, num_tags) and, when
    given, ``mask``/``tags`` of shape (batch, seq_len). The first timestep of every
    sequence is treated as valid. Masked-out positions are assumed to form a
    trailing (right-padded) suffix — TODO confirm for your data pipeline.

    Args:
        num_tags: Number of distinct tags.
    """

    def __init__(self, num_tags):
        super(CRF, self).__init__()
        self.num_tags = num_tags
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))

    @staticmethod
    def _prepare_mask(emissions, mask):
        """Return ``mask`` as bool, defaulting to all-True over (batch, seq_len)."""
        if mask is None:
            # bool (not uint8) so that ``~mask`` is a logical NOT usable by masked_fill.
            return torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)
        return mask.bool()

    def forward(self, emissions, tags, mask=None):
        """Return the batch-mean negative log-likelihood of ``tags``.

        Args:
            emissions: Per-position tag scores, (batch, seq_len, num_tags).
            tags: Gold tag ids (integer tensor), (batch, seq_len). Values at masked
                positions are ignored and may be e.g. -100.
            mask: Optional (batch, seq_len) mask of valid positions, any dtype.

        Returns:
            A 0-dim tensor: mean over the batch of ``log Z - score(tags)``.
        """
        mask = self._prepare_mask(emissions, mask)
        scores = self._compute_scores(emissions, tags, mask)
        partition = self._compute_log_partition(emissions, mask)
        return (partition - scores).mean()

    def decode(self, emissions, mask=None):
        """Return the highest-scoring tag sequence for each batch element.

        Returns:
            A list (one entry per batch element) of Python lists of tag ids, each
            as long as that element's number of valid (unmasked) positions.
        """
        mask = self._prepare_mask(emissions, mask)
        _, sequences = self._viterbi_decode(emissions, mask)
        return sequences

    def _compute_scores(self, emissions, tags, mask):
        """Unnormalized score of the given tag paths, shape (batch,)."""
        # Padded positions may carry ignore labels such as -100. Replace them with a
        # valid index so the indexing below is legal; their contribution is masked out.
        tags = tags.masked_fill(~mask, 0)
        trans = self.transitions[tags[:, :-1], tags[:, 1:]]
        trans = trans.masked_fill(~mask[:, 1:], 0).sum(dim=1)
        emit = emissions.gather(dim=2, index=tags.unsqueeze(-1)).squeeze(-1)
        emit = emit.masked_fill(~mask, 0).sum(dim=1)
        return trans + emit

    def _compute_log_partition(self, emissions, mask):
        """Log partition function log Z per sequence (forward algorithm), shape (batch,)."""
        alpha = emissions[:, 0]  # (batch, num_tags)
        for t in range(1, emissions.shape[1]):
            # step[b, i, j] = alpha[b, i] + transitions[i, j] + emissions[b, t, j]
            step = (alpha.unsqueeze(2)
                    + self.transitions.unsqueeze(0)
                    + emissions[:, t].unsqueeze(1))
            next_alpha = torch.logsumexp(step, dim=1)
            # Padded steps carry the previous alpha forward unchanged.
            alpha = torch.where(mask[:, t].unsqueeze(1), next_alpha, alpha)
        return torch.logsumexp(alpha, dim=1)

    def _viterbi_decode(self, emissions, mask):
        """Viterbi search.

        Returns:
            ``(best_scores, sequences)``: the best path score per batch element,
            shape (batch,), and the best paths as lists of tag ids.
        """
        batch_size, seq_len, _ = emissions.shape
        score = emissions[:, 0]  # (batch, num_tags)
        history = []  # history[t - 1][b, j]: best previous tag given tag j at step t
        for t in range(1, seq_len):
            step = (score.unsqueeze(2)
                    + self.transitions.unsqueeze(0)
                    + emissions[:, t].unsqueeze(1))
            best, best_prev = step.max(dim=1)
            history.append(best_prev)
            score = torch.where(mask[:, t].unsqueeze(1), best, score)

        best_scores, last_tags = score.max(dim=1)
        lengths = mask.long().sum(dim=1)
        sequences = []
        for b in range(batch_size):
            length = int(lengths[b])
            if length == 0:
                sequences.append([])
                continue
            tag = int(last_tags[b])
            path = [tag]
            # Walk back from the last valid position to position 0.
            for t in range(length - 1, 0, -1):
                tag = int(history[t - 1][b, tag])
                path.append(tag)
            path.reverse()
            sequences.append(path)
        return best_scores, sequences
def log_sum_exp(x, dim=None):
    """Numerically stable ``log(sum(exp(x)))``.

    Args:
        x: Input tensor.
        dim: Dimension to reduce over. ``None`` reduces over all elements.

    Returns:
        ``x`` reduced along ``dim`` (that dimension removed), or a 0-dim tensor
        when ``dim`` is None.
    """
    if dim is None:
        # torch.Tensor.max does not accept dim=None, so flatten and reduce
        # over the single remaining dimension instead.
        return torch.logsumexp(x.reshape(-1), dim=0)
    # torch.logsumexp performs the max-shift internally and returns -inf (not NaN)
    # for slices that are entirely -inf.
    return torch.logsumexp(x, dim=dim)
```
其中，BertCRF类封装了BERT编码器、将BERT输出映射为各标签发射分数的线性层，以及CRF层；CRF类实现了CRF层的损失计算和解码（预测）方法；log_sum_exp函数以数值稳定的方式计算 log(∑exp(x))（即log-sum-exp）。
这段代码可以用于中文命名实体识别任务。使用时需要准备训练数据集和测试数据集，并借助torch.utils.data.DataLoader进行数据加载和批处理。你可以根据自己的需求修改数据处理方式和模型配置。
阅读全文