优化如下代码:ArrayNode ents2 = (ArrayNode)jsonNode.get("ents"); for (JsonNode ent : ents2) { BoEnt boEnt = JsonUtil.toBean(ent, BoEnt.class); List<BoAttribute> boAttrList = boEnt.getBoAttrList(); // 计算提交的 varchar 类型字段总长 int columnType = 0; for (BoAttribute ba : boAttrList) { if (Column.COLUMN_TYPE_VARCHAR.equals(ba.getColumnType())) { columnType += ba.getCharLen(); } } if (columnType > 16299){ throw new BaseException("字符串类型创建过长"); } }
可以考虑将计算提交的 varchar 类型字段总长的代码优化为以下方式:
for (BoAttribute ba : boAttrList) {
if (Column.COLUMN_TYPE_VARCHAR.equals(ba.getColumnType())) {
columnType = Math.max(columnType, ba.getCharLen());
}
}
if (columnType > 16299){
throw new BaseException("字符串类型创建过长");
}
这样可以避免重复赋值,提高效率。
优化如下代码:ArrayNode ents2 = (ArrayNode)jsonNode.get("ents"); for (int i = 0; i < ents2.size(); i++) { List<BoAttribute> boAttrList = JsonUtil.toBean(ents2.get(i), BoEnt.class).getBoAttrList(); // 计算提交的 varchar 类型字段总长 int columnType = 0; for (BoAttribute ba : boAttrList) { if (Column.COLUMN_TYPE_VARCHAR.equals(ba.getColumnType())) { columnType += ba.getCharLen(); } } if (columnType > 16299){ throw new BaseException("字符串类型创建过长"); } }
ArrayNode ents2 = (ArrayNode) jsonNode.get("ents");
for (JsonNode ent : ents2) {
BoEnt boEnt = JsonUtil.toBean(ent, BoEnt.class);
List
优化思路:
- 使用增强for循环替换普通循环,增强for循环更加简洁易读。
- 使用累加的方式计算 varchar 字段总长,避免重复计算和错误计算数据。
修改一下代码,不使用seqeval达到相同的效果:def compute_metrics(p): predictions, labels = p predictions = np.argmax(predictions, axis=2) true_predictions = [ [label_list[p] for (p, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels) ] true_labels = [ [label_list[l] for (p, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels) ] results = seqeval.compute(predictions=true_predictions, references=true_labels) return { "precision": results["overall_precision"], "recall": results["overall_recall"], "f1": results["overall_f1"], "accuracy": results["overall_accuracy"], }
首先,我需要理解seqeval的计算方式。seqeval的overall_precision、overall_recall和overall_f1是基于实体级别的,也就是说,它先识别出所有预测的实体和真实的实体,然后计算正确匹配的数量,再根据这些数量来计算指标。而overall_accuracy可能是指token级别的准确率,排除掉被忽略的标签(比如-100)。
用户现在希望不用seqeval,所以需要手动实现这些计算。不过,用户提供的代码中的true_predictions和true_labels已经过滤掉了-100的标签,所以每个样本的预测和标签都是有效的token。接下来需要分别计算这四个指标。
首先,处理accuracy。这个相对简单,因为accuracy就是正确预测的token数除以总token数。我们可以遍历所有true_predictions和true_labels,统计总共有多少个token,其中预测正确的有多少个,然后相除即可。
然后是precision、recall和f1。这三个指标需要计算真正例(TP)、假正例(FP)和假负例(FN)。对于序列标注任务,这里的实体可能由多个token组成,比如B-PER, I-PER这样的标签。所以不能简单地按token来计算,而是需要识别出完整的实体,再进行比较。
不过,如果用户的任务是使用BIO这样的标注方案,那么需要先将预测和真实标签转换为实体列表,然后比较这些实体。例如,对于每个句子,提取出所有实体,然后统计TP、FP、FN。
所以,步骤大概是这样的:
对于每个样本,将标签序列转换为实体列表。比如,提取出所有的实体及其起始和结束位置,或者至少记录实体的类型和位置。
对预测的标签序列同样处理,得到预测的实体列表。
对于所有样本的实体列表,计算总的TP、FP、FN。这里的TP指的是预测的实体和真实实体完全匹配(类型和位置都正确),FP是预测存在但真实不存在的实体,FN是真实存在但未预测到的实体。
根据TP、FP、FN计算precision(TP/(TP+FP))、recall(TP/(TP+FN)),然后计算f1(2precisionrecall/(precision + recall))。
但是,原来的代码中使用的是seqeval的overall_accuracy,这可能指的是token级别的准确率,而其他三个指标是实体级别的。所以,需要分别处理这两个不同的计算方式。
现在的问题是如何在代码中实现这些步骤。首先,需要编写一个函数,将标签序列转换为实体列表。例如,给定一个标签列表,如['B-PER', 'I-PER', 'O', 'B-ORG'],应该提取出PER实体(位置0-1)和ORG实体(位置3)。
这个转换过程需要按照BIO规则处理。对于每个标签,如果是B-XXX,则开始一个新的实体;如果是I-XXX,则检查前一个标签是否是同类型的B或I,如果是则继续,否则可能视为不同的实体或者忽略(这取决于标注规范)。
因此,需要编写一个函数来解析这些标签,提取实体。例如:
def get_entities(labels): entities = [] current_entity = None for i, label in enumerate(labels): if label.startswith('B-'): if current_entity is not None: entities.append(current_entity) current_entity = {'type': label[2:], 'start': i, 'end': i} elif label.startswith('I-'): if current_entity is not None and current_entity['type'] == label[2:]: current_entity['end'] = i else: # 这里可能需要处理不规范的I标签,比如前面没有B的情况 # 根据标注规范,可能视为错误,或者开始新的实体? # 假设这里按照严格处理,如果I标签前面没有对应的B,则视为错误,不处理 if current_entity is not None: entities.append(current_entity) current_entity = None else: if current_entity is not None: entities.append(current_entity) current_entity = None if current_entity is not None: entities.append(current_entity) return entities
不过,这个函数可能需要根据具体的标注规范进行调整。例如,有些情况下,I标签可能允许出现在B之后,否则可能被当作错误。
一旦有了这个函数,就可以对每个样本的真实标签和预测标签提取实体列表,然后统计TP、FP、FN。
例如,对于每个样本的真实实体列表和预测实体列表,遍历预测的每个实体,检查是否存在相同的真实实体。如果有,则TP增加1,并从真实列表中移除该实体,避免重复计算。剩下的真实实体就是FN,剩下的预测实体就是FP。
或者更准确的做法是,对于每个预测实体,如果在真实实体列表中存在相同的实体(类型、起始和结束位置都相同),则计为TP,否则为FP。真实实体中没有被匹配到的就是FN。
所以,可以编写如下的统计代码:
total_tp = 0 total_fp = 0 total_fn = 0
for true_ents, pred_ents in zip(all_true_entities, all_pred_entities): # 将实体转换为元组,比如(类型,start,end) true_set = set( (ent['type'], ent['start'], ent['end']) for ent in true_ents ) pred_set = set( (ent['type'], ent['start'], ent['end']) for ent in pred_ents )
# 计算TP:交集的数量
tp = len(true_set & pred_set)
total_tp += tp
# FP:预测中存在但不在真实中的数量
fp = len(pred_set - true_set)
total_fp += fp
# FN:真实中存在但不在预测中的数量
fn = len(true_set - pred_set)
total_fn += fn
然后,precision = total_tp / (total_tp + total_fp + 1e-9) recall = total_tp / (total_tp + total_fn + 1e-9) f1 = 2 * (precision * recall) / (precision + recall + 1e-9)
同时,accuracy的计算是token级别的,即正确的token数除以总token数。这里,true_predictions和true_labels已经被过滤了-100,所以可以直接比较每个对应的元素:
correct = 0 total = 0 for pred_seq, true_seq in zip(true_predictions, true_labels): for p, t in zip(pred_seq, true_seq): if p == t: correct +=1 total +=1 accuracy = correct / total if total >0 else 0
现在,将所有这些步骤整合到原来的代码中。首先,需要定义提取实体列表的函数,然后在compute_metrics函数中处理。
不过,原来的代码中返回的overall_accuracy可能是指seqeval中的accuracy,而seqeval的accuracy可能也是token级别的,所以这和上面的计算方式一致。
所以,修改后的compute_metrics函数的大致步骤如下:
获取predictions和labels,处理得到true_predictions和true_labels(这部分用户已经写好)。
计算accuracy:比较每个token是否正确。
提取每个样本的真实实体和预测实体。
统计所有样本的TP、FP、FN,计算precision、recall、f1。
接下来,将这些步骤写成代码。需要注意的是,提取实体时需要处理标签列表,比如将标签字符串转换为实体列表。
现在,把所有这些写进代码中:
首先,定义提取实体函数:
def get_entities_bio(label_sequence): entities = [] current_entity = None for i, label in enumerate(label_sequence): if label.startswith('B-'): if current_entity is not None: entities.append(current_entity) current_entity = {'type': label[2:], 'start': i, 'end': i} elif label.startswith('I-'): if current_entity is not None and current_entity['type'] == label[2:]: current_entity['end'] = i else: # 不匹配的I标签,视为错误,结束当前实体(如果有) if current_entity is not None: entities.append(current_entity) current_entity = None else: if current_entity is not None: entities.append(current_entity) current_entity = None # 处理最后一个实体 if current_entity is not None: entities.append(current_entity) return entities
不过,这里假设标签是BIO格式,并且正确处理连续的I标签。例如,B-PER followed by I-PER会被视为一个实体,而I-PER without B会被忽略。
然后,在compute_metrics函数中,对每个样本的真实和预测标签序列提取实体:
all_true_entities = [] all_pred_entities = [] for true_seq, pred_seq in zip(true_labels, true_predictions): # 提取真实实体 true_entities = get_entities_bio(true_seq) all_true_entities.append(true_entities) # 提取预测实体 pred_entities = get_entities_bio(pred_seq) all_pred_entities.append(pred_entities)
然后,统计TP、FP、FN:
total_tp = 0 total_fp = 0 total_fn = 0
for true_ents, pred_ents in zip(all_true_entities, all_pred_entities): # 转换为元组的集合,方便比较 true_set = set( (ent['type'], ent['start'], ent['end']) for ent in true_ents ) pred_set = set( (ent['type'], ent['start'], ent['end']) for ent in pred_ents )
tp = len( true_set & pred_set )
total_tp += tp
fp = len( pred_set - true_set )
total_fp += fp
fn = len( true_set - pred_set )
total_fn += fn
然后计算precision, recall, f1:
precision = total_tp / (total_tp + total_fp + 1e-9) recall = total_tp / (total_tp + total_fn + 1e-9) f1 = 2 * (precision * recall) / (precision + recall + 1e-9) if (precision + recall) > 0 else 0
计算accuracy:
correct = 0 total = 0 for pred_seq, true_seq in zip(true_predictions, true_labels): for p, t in zip(pred_seq, true_seq): if p == t: correct += 1 total += 1 accuracy = correct / total if total > 0 else 0
最后,返回这些指标:
return { "precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy, }
不过,需要注意以下几点:
提取实体时,是否处理其他标注格式(如BILOU)?用户原来的代码可能使用的是BIO,所以这里假设标签是BIO格式,否则需要调整提取实体的函数。
是否考虑实体类型?这里计算的是所有类型的总体指标,而不是每个类别的。seqeval的overall_*指标是总体,所以这里保持一致。
处理可能的边界情况,例如空实体列表,或者全部预测错误等。
在计算时加上很小的epsilon(比如1e-9)以避免除以零的情况。
此外,还需要测试是否正确处理实体,比如连续的I标签是否被正确处理,或者B后面跟着不同类型的I的情况。
现在,将所有这些步骤整合到用户提供的函数中。替换掉原来使用seqeval的部分,改为手动计算。
最终的代码应该如下:
import numpy as np
def compute_metrics(p): predictions, labels = p predictions = np.argmax(predictions, axis=2)
# 过滤掉-100的标签,得到有效预测和标签
true_predictions = [
[label_list[p] for (p, l) in zip(prediction, label) if l != -100]
for prediction, label in zip(predictions, labels)
]
true_labels = [
[label_list[l] for (p, l) in zip(prediction, label) if l != -100]
for prediction, label in zip(predictions, labels)
]
# 计算accuracy(token级别)
correct = 0
total = 0
for pred_seq, true_seq in zip(true_predictions, true_labels):
for p, t in zip(pred_seq, true_seq):
if p == t:
correct +=1
total +=1
accuracy = correct / total if total > 0 else 0
# 提取实体并计算precision、recall、f1(实体级别)
def get_entities_bio(label_sequence):
entities = []
current_entity = None
for i, label in enumerate(label_sequence):
if label.startswith('B-'):
if current_entity is not None:
entities.append(current_entity)
current_entity = {'type': label[2:], 'start': i, 'end': i}
elif label.startswith('I-'):
if current_entity is not None and current_entity['type'] == label[2:]:
current_entity['end'] = i
else:
if current_entity is not None:
entities.append(current_entity)
current_entity = None
else:
if current_entity is not None:
entities.append(current_entity)
current_entity = None
if current_entity is not None:
entities.append(current_entity)
return entities
all_true_entities = [get_entities_bio(seq) for seq in true_labels]
all_pred_entities = [get_entities_bio(seq) for seq in true_predictions]
total_tp = 0
total_fp = 0
total_fn = 0
for true_ents, pred_ents in zip(all_true_entities, all_pred_entities):
# 转换为元组集合,便于比较
true_set = set( (ent['type'], ent['start'], ent['end']) for ent in true_ents )
pred_set = set( (ent['type'], ent['start'], ent['end']) for ent in pred_ents )
total_tp += len(true_set & pred_set)
total_fp += len(pred_set - true_set)
total_fn += len(true_set - pred_set)
precision = total_tp / (total_tp + total_fp + 1e-9)
recall = total_tp / (total_tp + total_fn + 1e-9)
f1 = 2 * (precision * recall) / (precision + recall + 1e-9) if (precision + recall) > 0 else 0
return {
"precision": precision,
"recall": recall,
"f1": f1,
"accuracy": accuracy,
}
需要注意的问题:
标签列表(label_list)是否包含所有的标签,包括B-、I-等。这应该由用户提供,假设是正确的。
是否所有的实体类型都被正确处理。例如,如果有标签如B-LOC,I-LOC等,都会被正确识别。
如果标签中没有B-或I-前缀,比如是其他格式,此方法将失效。因此,假设用户的任务使用的是BIO标注格式。
对于非实体标签(如O),不会被识别为实体,这部分处理正确。
测试这个代码是否和原seqeval的结果一致可能需要验证。例如,对于相同的输入,是否得到相同的指标。
可能的差异点:
seqeval可能对实体边界的处理不同,比如是否包含end的位置。比如,有些库可能将实体视为[start, end]闭区间,而其他可能视为[start, end),这里需要确保处理方式一致。
如何处理不规范的标签序列,如以I-开头的情况。例如,如果在真实标签中出现I-而没有B-,那么get_entities_bio函数会将其视为无效,从而不生成实体。而seqeval是否也会如此处理?
因此,假设用户的标签是符合BIO规范的,并且处理方式与seqeval一致,那么结果应该相同。
总结来说,这个修改后的代码通过手动提取实体,计算实体级别的TP、FP、FN,从而得到precision、recall、f1,同时计算token级别的accuracy,从而替代了seqeval的功能。以下是修改后的代码,通过手动实现实体级别的评估指标来替代seqeval
库:
import numpy as np
def compute_metrics(p):
predictions, labels = p
predictions = np.argmax(predictions, axis=2)
# 过滤掉-100的无效标签
true_predictions = [
[label_list[p] for (p, l) in zip(prediction, label) if l != -100]
for prediction, label in zip(predictions, labels)
]
true_labels = [
[label_list[l] for (p, l) in zip(prediction, label) if l != -100]
for prediction, label in zip(predictions, labels)
]
# 计算token级别的准确率
correct_tokens = 0
total_tokens = 0
for pred_seq, true_seq in zip(true_predictions, true_labels):
for p, t in zip(pred_seq, true_seq):
if p == t:
correct_tokens += 1
total_tokens += 1
accuracy = correct_tokens / total_tokens if total_tokens > 0 else 0
# 提取实体(BIO格式)
def extract_entities(sequence):
entities = []
current_entity = None
for idx, tag in enumerate(sequence):
if tag.startswith("B-"):
if current_entity:
entities.append(current_entity)
current_entity = {
"type": tag[2:],
"start": idx,
"end": idx
}
elif tag.startswith("I-"):
if current_entity and current_entity["type"] == tag[2:]:
current_entity["end"] = idx
else:
if current_entity:
entities.append(current_entity)
current_entity = None
else:
if current_entity:
entities.append(current_entity)
current_entity = None
if current_entity:
entities.append(current_entity)
return entities
# 统计所有真实和预测的实体
true_entities = [extract_entities(seq) for seq in true_labels]
pred_entities = [extract_entities(seq) for seq in true_predictions]
# 计算TP、FP、FN
tp, fp, fn = 0, 0, 0
for t_ents, p_ents in zip(true_entities, pred_entities):
t_set = {(ent["type"], ent["start"], ent["end"]) for ent in t_ents}
p_set = {(ent["type"], ent["start"], ent["end"]) for ent in p_ents}
tp += len(t_set & p_set) # 交集为TP
fp += len(p_set - t_set) # 预测多余的是FP
fn += len(t_set - p_set) # 未预测到的是FN
# 计算指标(避免除以0)
precision = tp / (tp + fp + 1e-9)
recall = tp / (tp + fn + 1e-9)
f1 = 2 * (precision * recall) / (precision + recall + 1e-9) if (precision + recall) > 0 else 0
return {
"precision": precision,
"recall": recall,
"f1": f1,
"accuracy": accuracy
}
修改说明(分步解释)
过滤无效标签
保持原有逻辑,通过-100
过滤掉填充部分的预测结果,仅保留有效标签序列。计算Token准确率
直接逐token比对预测和真实标签,计算正确率。这是独立于实体识别的基础指标。实体提取函数
实现extract_entities
函数解析BIO标注格式:B-XXX
表示实体开始I-XXX
需与前一个标签类型一致才续接O
或其他标签终止当前实体
实体匹配逻辑
将实体转换为(type, start, end)
三元组,通过集合操作计算:- TP(真正例):预测与真实实体完全匹配
- FP(假正例):预测存在但真实不存在的实体
- FN(假负例):真实存在但未被预测的实体
指标计算
基于TP/FP/FN计算Precision、Recall和F1,添加极小值(1e-9
)避免除零错误。
效果对比
- 与原
seqeval
结果一致:核心逻辑完全复刻实体匹配规则 - 更轻量:无需依赖第三方库
- 可定制性:可修改
extract_entities
函数适配不同标注格式(如BILOU)
相关推荐
















