PLMEE Chinese Event Extraction Code
Below is a simple example of Chinese event extraction based on a BERT+BiLSTM+CRF model, using PyTorch and the transformers library:
```python
import torch
from transformers import BertTokenizer, BertModel
import torch.nn as nn
class BertBiLSTMCRF(nn.Module):
    """BERT encoder -> BiLSTM -> linear emission layer -> CRF decoding."""
    def __init__(self, bert_model, num_tags, hidden_size=768, dropout_prob=0.1):
        super(BertBiLSTMCRF, self).__init__()
        self.bert = bert_model
        # Bidirectional LSTM: each direction gets hidden_size // 2 units so the
        # concatenated output keeps the BERT hidden size.
        self.bilstm = nn.LSTM(hidden_size, hidden_size // 2, num_layers=1,
                              bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)  # CRF class is defined below

    def forward(self, input_ids, attention_mask):
        # Inference path: encode, project to per-tag emission scores, then
        # Viterbi-decode the best tag sequence. Training instead calls
        # self.crf(emissions, tags, mask) to get the CRF negative log-likelihood.
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs[0]  # last hidden states, (batch, seq, hidden)
        sequence_output = self.dropout(sequence_output)
        lstm_output, _ = self.bilstm(sequence_output)
        lstm_output = self.dropout(lstm_output)
        emissions = self.fc(lstm_output)  # (batch, seq, num_tags)
        mask = attention_mask.bool()
        crf_output = self.crf.decode(emissions, mask)
        return crf_output
class CRF(nn.Module):
    """A minimal linear-chain CRF with forward-algorithm NLL and Viterbi decoding."""
    def __init__(self, num_tags, batch_first=False):
        super(CRF, self).__init__()
        self.num_tags = num_tags
        # Learned transition scores between tags, plus start/end transitions.
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.end_transitions = nn.Parameter(torch.randn(num_tags))
        self.batch_first = batch_first

    def forward(self, emissions, tags, mask):
        # Returns the mean negative log-likelihood over the batch:
        # log Z (forward algorithm) minus the score of the gold tag sequence.
        mask = mask.bool()
        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)
        sequence_length = emissions.shape[0]
        batch_size = emissions.shape[1]
        # Forward algorithm: accumulate log-sum-exp scores over all paths.
        score = self.start_transitions.view(1, -1) + emissions[0]
        for i in range(1, sequence_length):
            broadcast_score = score.unsqueeze(2)             # (batch, tags, 1)
            broadcast_emissions = emissions[i].unsqueeze(1)  # (batch, 1, tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions
            next_score = torch.logsumexp(next_score, dim=1)
            # Only advance positions that are not padding.
            mask_idx = mask[i].unsqueeze(1).expand(batch_size, self.num_tags)
            score = torch.where(mask_idx, next_score, score)
        score = score + self.end_transitions.view(1, -1)
        score = torch.logsumexp(score, dim=1)  # log Z, shape (batch,)
        # Tensors are already time-major here, so pass them straight through.
        gold_score = self._score_sentence(emissions, tags, mask)
        return (score - gold_score).sum() / batch_size
    def decode(self, emissions, mask):
        # Viterbi decoding: track the best score and backpointers per tag.
        mask = mask.bool()
        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)
        sequence_length = emissions.shape[0]
        batch_size = emissions.shape[1]
        score = self.start_transitions.view(1, -1) + emissions[0]
        history = []
        for i in range(1, sequence_length):
            broadcast_score = score.unsqueeze(2)
            broadcast_emissions = emissions[i].unsqueeze(1)
            next_score = broadcast_score + self.transitions + broadcast_emissions
            # Best previous tag for each current tag, kept as backpointers.
            next_score, indices = torch.max(next_score, dim=1)
            history.append(indices)
            mask_idx = mask[i].unsqueeze(1).expand(batch_size, self.num_tags)
            score = torch.where(mask_idx, next_score, score)
        score = score + self.end_transitions.view(1, -1)
        _, best_tag = torch.max(score, dim=1)
        # Backtrack through the stored pointers to recover the best path.
        # (squeeze(1) rather than squeeze() so a batch of 1 keeps its batch dim;
        # for simplicity this backtrack does not special-case padded positions.)
        best_path = [best_tag]
        for h in reversed(history):
            best_tag = torch.gather(h, 1, best_tag.unsqueeze(1)).squeeze(1)
            best_path.insert(0, best_tag)
        best_path = torch.stack(best_path).transpose(0, 1)  # (batch, seq)
        return best_path
    def _score_sentence(self, emissions, tags, mask):
        # Score of the gold tag sequence. Expects time-major tensors,
        # i.e. (seq_len, batch, ...); forward() has already transposed them.
        sequence_length = emissions.shape[0]
        batch_size = emissions.shape[1]
        # Start transition plus the emission score of the first gold tag.
        score = self.start_transitions[tags[0]] + emissions[0, torch.arange(batch_size), tags[0]]
        for i in range(1, sequence_length):
            current_tags = tags[i]
            previous_tags = tags[i - 1]
            transition_score = self.transitions[previous_tags, current_tags]
            emission_score = emissions[i, torch.arange(batch_size), current_tags]
            mask_idx = mask[i].float()
            score = score + (transition_score + emission_score) * mask_idx
        # Add the end transition from each sequence's last real (unpadded) tag.
        last_tag_indexes = mask.sum(dim=0) - 1
        last_tags = tags[last_tag_indexes, torch.arange(batch_size)]
        return score + self.end_transitions[last_tags]
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertModel.from_pretrained('bert-base-chinese')
num_tags = 7  # number of BIO tags; must match len(tag2idx) defined below
class EventExtractor:
    def __init__(self, model_path='event_extractor.pt'):
        # Assumes a checkpoint trained elsewhere and saved to model_path.
        self.model = BertBiLSTMCRF(model, num_tags)
        self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        self.model.eval()

    def extract(self, text):
        input_ids = tokenizer.encode(text, add_special_tokens=True)
        attention_mask = [1] * len(input_ids)
        with torch.no_grad():
            pred_tags = self.model(torch.tensor(input_ids).unsqueeze(0),
                                   torch.tensor(attention_mask).unsqueeze(0))
        pred_tags = pred_tags[0]  # single sequence in the batch
        tags = [idx2tag[int(i)] for i in pred_tags]
        # Merge BIO tags into entity spans.
        entities = []
        for i, tag in enumerate(tags):
            if tag.startswith('B-'):
                entity = {'start': i, 'type': tag[2:]}
                j = i + 1
                while j < len(tags) and tags[j] == 'I-' + tag[2:]:
                    j += 1
                entity['end'] = j - 1
                entity['word'] = tokenizer.decode(
                    input_ids[entity['start']:entity['end'] + 1]).replace(' ', '')
                entities.append(entity)
        return entities
# Tag inventory for this demo. Note these are BIO entity tags (LOC/PER/ORG);
# an actual event-extraction setup would use trigger/argument tags instead.
tag2idx = {'O': 0, 'B-LOC': 1, 'B-PER': 2, 'B-ORG': 3, 'I-LOC': 4, 'I-PER': 5, 'I-ORG': 6}
idx2tag = {idx: tag for tag, idx in tag2idx.items()}
extractor = EventExtractor()  # requires event_extractor.pt to exist
```
In the code above, BERT serves as the input feature extractor, its output is fed into a bidirectional LSTM for sequence labeling, and a CRF layer decodes the full sequence to produce the final extraction result. Implementation details are noted in the code comments. Two short sketches follow: a hypothetical usage call, and a minimal training step.
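As a quick illustration, here is how the extractor might be called once a trained checkpoint exists; the sentence and the printed spans are purely hypothetical and not part of the original code:
```python
# Hypothetical usage; assumes event_extractor.pt was produced by a training run.
text = "李明在北京参加了会议。"
entities = extractor.extract(text)
for ent in entities:
    # Each entity is a dict: {'start': ..., 'end': ..., 'type': ..., 'word': ...}
    print(ent['type'], ent['word'])
```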
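The model's `forward` method only covers inference (Viterbi decoding); training needs the CRF negative log-likelihood from `CRF.forward`. Below is a minimal sketch of one training step, assuming batched `input_ids`, `attention_mask`, and gold `tags` tensors are already prepared (the function and variable names here are placeholders, not part of the original code):
```python
import torch.optim as optim

# A fresh model for training; the learning rate is a typical BERT fine-tuning
# value, not something prescribed by the original post.
train_model = BertBiLSTMCRF(model, num_tags)
optimizer = optim.AdamW(train_model.parameters(), lr=2e-5)

def training_step(input_ids, attention_mask, gold_tags):
    # Recompute the emission scores (same path as BertBiLSTMCRF.forward),
    # then score them against the gold tags with the CRF NLL.
    bert_outputs = train_model.bert(input_ids=input_ids, attention_mask=attention_mask)
    sequence_output = train_model.dropout(bert_outputs[0])
    lstm_output, _ = train_model.bilstm(sequence_output)
    emissions = train_model.fc(train_model.dropout(lstm_output))
    loss = train_model.crf(emissions, gold_tags, attention_mask.bool())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```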