def print_text(font, x, y, text, color=(255,255,255), center=False): imgText = font.render(text, True, color) imgrect = imgText.get_rect() if center: imgrect.center = (x, y) screen.blit(imgText, imgrect) else: screen.blit(imgText, (x,y)) with open("word.txt","r") as f: word = f.readlines() wordList = {} for i in range(len(word)//2): wordList[word[i*2].strip()] = {'Chinese': word[i*2+1].strip(), 'OccurTimes': 0, 'WrongTimes': 0, 'CombineTimes': 0} #复制一个字典用于存储剩余不会的单词 wordListRest = deepcopy(wordList) 解释
时间: 2024-01-27 08:04:20 浏览: 142
这段代码定义了两个函数和两个字典变量。其中,函数print_text用于在屏幕上打印文本,接受5个参数:字体对象、文本的x坐标、y坐标、文本内容、文本颜色和是否居中打印。函数内部使用pygame.font库的render方法生成文本图像,并使用blit方法将图像绘制到屏幕上。
另一个函数是open函数读取文件中的数据。with open()语句中的文件操作会在执行完毕后自动关闭文件,避免出现资源泄露的问题。读取的数据被保存在word变量中。
接下来,代码定义了一个字典变量wordList,用于存储单词的相关信息,包括中文翻译、出现次数、错误次数和组合次数。字典中的键为单词,值为一个包含相关信息的字典。通过循环遍历word变量,将读取到的单词及其翻译存储到wordList中。
最后,代码使用深拷贝函数deepcopy()创建了另一个字典变量wordListRest,用于存储还未掌握的单词。这样可以在不改变原始单词列表的情况下进行单词学习和测试。
相关问题
from paddleocr import PaddleOCR, draw_ocr import json from shapely.geometry import Polygon import os import time 初始化OCR模型 ocr = PaddleOCR(use_angle_cls=True, lang=“ch”) test_img_path = r"C:\Users\86178\Desktop\2\d5f5faec2b7f05652e82cb87ba85210.jpg" test_json_path = “f_template.json” with open(test_json_path, “r”, encoding=‘utf-8’) as f: data = json.load(f) print(data) === 步骤2:提取待匹配文本 === discretionary = [item.get(“key”, “”) for item in data[“data”]]#discretionary:目标key,字典 print(discretionary) offsets = [item[“offset”] for item in data[“data”]]#offsets:目标key的偏移量 print(offsets) 获取OCR识别结果 result = ocr.ocr(test_img_path, cls=True)#result:ocr结果 print(result) ocr_results = []#ocr_results:[{‘text’: ‘ASG’, ‘coordinates’: [[816.0, 143.0], [925.0, 138.0], [928.0, 200.0], [819.0, 205.0]]},这种格式的ocr结果 for line in result[0]: # result[0] 包含OCR识别的每行结果 key = line[1][0] # 假设识别到的文本是字段名 key_offset = line[0] # 获取字段名的坐标(边框) print(key) 存储OCR结果 ocr_results.append({ “text”: key, “coordinates”: key_offset }) print(ocr_results) ocr_dict = {item[“text”]: item[“coordinates”] for item in ocr_results}#ocr_dict:保存ocr结果中文本和坐标的数组 result_coordinates = []#result_coordinates:字典中的key在ocr结果中对应的坐标,没找到是[[0,0],[0,0],[0,0],[0,0]] for target_text in discretionary: result_coordinates.append( ocr_dict[target_text] if target_text in ocr_dict else [[0, 0] for _ in range(4)] ) print(result_coordinates) 处理逻辑 v_coordinates = []#v_coordinates:通过模板计算的value的大致坐标 for sub_arr, offset in zip(result_coordinates, offsets): new_sub = [] for point in sub_arr: new_x = point[0] - offset[‘x_offset’] new_y = point[1] - offset[‘y_offset’] new_sub.append([new_x, new_y]) v_coordinates.append(new_sub) print(v_coordinates) 将数组中的坐标转换为多边形列表 array_polygons = [Polygon(coords) for coords in v_coordinates] 将OCR结果中的坐标转换为(条目, 多边形)列表 ocr_entries = [ (entry, Polygon(entry[‘coordinates’])) for entry in ocr_results ] value = [] for array_poly in array_polygons: max_area = 0 best_match = None for entry, ocr_poly in ocr_entries: 计算交集面积 intersection = array_poly.intersection(ocr_poly) area = intersection.area 更新最大交集记录 if area > max_area: max_area = area best_match = entry 保存结果(若无交集,best_match为None) value.append(best_match) print(value) 遍历替换操作 修改遍历对象为 data[‘data’] for i in range(len(data[‘data’])): # 注意这里的 data[‘data’] data[‘data’][i][‘key_coordinates’] = result_coordinates[i] print(data) for data_item, value_item in zip(data[‘data’], value): if value_item is None: data_item[‘value’] = ‘’ # 或设置为None根据业务需求 data_item[‘value_coordinates’] = [] # 或设置为None else: data_item[‘value’] = value_item[‘text’] data_item[‘value_coordinates’] = value_item[‘coordinates’] print(data) for item in data[‘data’]: item[‘afterCorrection’] = ‘’ # 或设置为None def calculate_direction(key_coords, value_coords): 空值处理 if not value_coords: return None 提取四边形边界 def get_bounds(coords): xs = [p[0] for p in coords] ys = [p[1] for p in coords] return (min(xs), max(xs), min(ys), max(ys)) k_min_x, k_max_x, k_min_y, k_max_y = get_bounds(key_coords) v_min_x, v_max_x, v_min_y, v_max_y = get_bounds(value_coords) 方向判断优先级:右/左 → 下/上 → 中心点偏移 if v_min_x > k_max_x: # value整体在key右侧 return “右” elif v_max_x < k_min_x: # value整体在key左侧 return “左” elif v_min_y > k_max_y: # value整体在key下方(屏幕坐标系) return “下” elif v_max_y < k_min_y: # value整体在key上方 return “上” else: # 有重叠时,回退到中心点偏移逻辑 计算中心点 k_center_x = (k_min_x + k_max_x) / 2 k_center_y = (k_min_y + k_max_y) / 2 v_center_x = (v_min_x + v_max_x) / 2 v_center_y = (v_min_y + v_max_y) / 2 dx = v_center_x - k_center_x dy = v_center_y - k_center_y if abs(dx) > abs(dy): return “右” if dx > 0 else “左” else: return “下” if dy > 0 else “上” # 根据屏幕坐标系 更新direction字段 for item in data[‘data’]: key_coords = item[‘key_coordinates’] value_coords = item[‘value_coordinates’] item[‘direction’] = calculate_direction(key_coords, value_coords)将这个代码改成一个函数,可以在别的文件里面调用,输入是图片地址和模板地址,输出是modified_data
<think>嗯,用户给了我一段Python代码,想要把它改成一个可以在其他文件中调用的函数,输入是图片地址和模板地址,输出是修改后的数据。首先,我需要仔细理解这段代码的结构和功能。
首先,原代码导入了PaddleOCR、json、shapely等库,然后初始化了OCR模型。接着,它读取了一个图片路径和一个JSON模板路径,加载模板数据。之后,提取了模板中的关键字段和偏移量,进行OCR识别,处理识别结果,计算坐标,匹配字段和值的位置,最后更新数据并确定方向。
用户的目标是将这段代码封装成一个函数,方便在其他地方调用。那需要考虑如何将硬编码的文件路径改为参数传入,函数需要接收图片路径和模板路径作为输入,返回处理后的modified_data。
首先,我会检查代码中的哪些部分是固定的,比如初始化OCR模型的部分。每次调用函数时重复初始化可能会影响性能,所以最好在函数外部初始化一次,然后作为参数传入。不过用户可能希望函数内部处理,所以暂时先保留在函数内,但要注意是否需要单例模式或全局变量来优化。
接下来,原来的test_img_path和test_json_path是硬编码的,需要改为由函数参数传入。函数定义应该是def process_ocr(image_path, template_path):,然后在函数内部使用这两个参数替换原来的路径。
然后,处理JSON模板的部分,原来的data变量是从固定的test_json_path读取的,现在需要改为从传入的template_path读取。同样,图片路径改为image_path。
接下来,代码中的各个步骤需要封装到函数中。可能需要将主要步骤分解为子函数,比如提取模板信息、OCR处理、坐标计算、方向判断等,但用户可能希望保持结构简单,所以暂时先保持整体流程,将代码整合到一个函数中。
需要注意的是,原代码中的一些变量可能需要在函数内部重新定义,比如ocr_results、result_coordinates等,确保每次调用函数时这些变量都是新的,避免数据污染。
然后,原代码最后处理后的data对象就是需要输出的modified_data,所以函数最后返回data即可。
另外,原代码中有很多print语句,可能在调试时有用,但在生产环境中可能需要删除或改为日志记录。不过用户可能希望保留这些调试信息,所以暂时保留,或者添加一个调试参数控制。
另外,需要处理异常情况,比如文件不存在、JSON格式错误、OCR识别失败等,添加try-except块来捕获异常,并返回合适的错误信息。但用户可能暂时不需要,所以先不处理,或者简单处理。
还有,原代码中的PaddleOCR初始化在函数内部,每次调用都会重新加载模型,这可能影响性能。建议将OCR模型的初始化放在函数外部,作为全局变量,或者允许用户传入一个已初始化的OCR实例。但根据用户的需求,可能先保持简单,在函数内部初始化。
然后,检查代码中的语法错误,比如原代码中的引号是中文引号“”,需要改为英文引号"",否则会报错。比如初始化PaddleOCR时的lang=“ch”应该改为lang="ch",还有test_json_path的赋值,json.load(f)等地方。用户提供的代码中可能存在这样的问题,需要在函数中修正。
另外,原代码中的注释有中文标点,可能需要调整,但函数中的注释不影响执行,可以保留。
然后,将整个流程封装到一个函数中,接收image_path和template_path,执行所有步骤,返回修改后的data。
最后,测试这个函数是否可以在其他文件中导入并调用,确保输入输出正确。
总结步骤:
1. 将代码封装成函数,参数为image_path和template_path。
2. 替换硬编码路径为参数传入。
3. 修正语法错误,如中文引号。
4. 确保所有变量在函数作用域内正确初始化。
5. 返回处理后的data作为输出。
6. 添加必要的注释和结构,使代码更清晰。
可能遇到的问题:
- 中文字符导致的语法错误。
- OCR模型初始化在函数内部可能影响性能,但暂时可以接受。
- 需要处理坐标计算时的边界情况,比如找不到匹配项。
- 确保函数返回的数据结构符合用户预期。
最后,编写函数后,给出使用示例,方便用户在其他文件中调用。</think>以下是将您的代码重构为可调用函数的实现方案:
```python
from paddleocr import PaddleOCR
import json
from shapely.geometry import Polygon
from typing import Dict, List, Any
def process_ocr(image_path: str, template_path: str) -> Dict[str, Any]:
"""
OCR图像处理函数
:param image_path: 图片文件路径
:param template_path: 模板文件路径
:return: 处理后的结构化数据
"""
# === 初始化OCR模型 ===
ocr = PaddleOCR(use_angle_cls=True, lang="ch")
# === 读取模板文件 ===
with open(template_path, "r", encoding="utf-8") as f:
data = json.load(f)
# === 提取模板信息 ===
discretionary = [item.get("key", "") for item in data["data"]]
offsets = [item["offset"] for item in data["data"]]
# === 执行OCR识别 ===
result = ocr.ocr(image_path, cls=True)
# === 处理识别结果 ===
ocr_results = []
for line in result[0]:
ocr_results.append({
"text": line[1][0],
"coordinates": line[0]
})
# === 坐标匹配处理 ===
ocr_dict = {item["text"]: item["coordinates"] for item in ocr_results}
result_coordinates = [
ocr_dict[target_text] if target_text in ocr_dict
else [[0, 0] for _ in range(4)]
for target_text in discretionary
]
# === 计算值坐标 ===
v_coordinates = []
for sub_arr, offset in zip(result_coordinates, offsets):
v_coordinates.append([
[point[0] - offset["x_offset"], point[1] - offset["y_offset"]]
for point in sub_arr
])
# === 多边形处理 ===
array_polygons = [Polygon(coords) for coords in v_coordinates]
ocr_entries = [(entry, Polygon(entry["coordinates"])) for entry in ocr_results]
# === 值匹配处理 ===
value = []
for array_poly in array_polygons:
max_area, best_match = 0, None
for entry, ocr_poly in ocr_entries:
if (intersection := array_poly.intersection(ocr_poly)).area > max_area:
max_area, best_match = intersection.area, entry
value.append(best_match)
# === 更新数据结构 ===
for i, item in enumerate(data["data"]):
# 更新坐标信息
item["key_coordinates"] = result_coordinates[i]
# 更新值信息
if value[i] is None:
item["value"] = ""
item["value_coordinates"] = []
else:
item["value"] = value[i]["text"]
item["value_coordinates"] = value[i]["coordinates"]
# 初始化修正字段
item["afterCorrection"] = ""
# 计算方向
item["direction"] = calculate_direction(
item["key_coordinates"],
item["value_coordinates"]
)
return data
def calculate_direction(key_coords: List[List[float]],
value_coords: List[List[float]]) -> str:
"""
计算字段方向关系(辅助函数)
"""
if not value_coords:
return None
def get_bounds(coords):
xs = [p[0] for p in coords]
ys = [p[1] for p in coords]
return (min(xs), max(xs), min(ys), max(ys))
k_min_x, k_max_x, k_min_y, k_max_y = get_bounds(key_coords)
v_min_x, v_max_x, v_min_y, v_max_y = get_bounds(value_coords)
# 方向判断逻辑
if v_min_x > k_max_x: return "右"
if v_max_x < k_min_x: return "左"
if v_min_y > k_max_y: return "下"
if v_max_y < k_min_y: return "上"
# 中心点计算
k_center_x = (k_min_x + k_max_x) / 2
k_center_y = (k_min_y + k_max_y) / 2
v_center_x = (v_min_x + v_max_x) / 2
v_center_y = (v_min_y + v_max_y) / 2
dx = v_center_x - k_center_x
dy = v_center_y - k_center_y
return "右" if abs(dx) > abs(dy) and dx > 0 else \
"左" if abs(dx) > abs(dy) else \
"下" if dy > 0 else "上"
```
**使用示例**:
```python
# 在另一个文件中调用
from ocr_processor import process_ocr
result = process_ocr(
image_path=r"C:\Users\your_path\test.jpg",
template_path=r"C:\Users\your_path\template.json"
)
print(json.dumps(result, ensure_ascii=False, indent=2))
```
**优化说明**:
1. **模块化设计**:
- 主函数`process_ocr`处理核心流程
- 辅助函数`calculate_direction`单独封装
- 类型提示增强了代码可读性
2. **性能改进**:
- 使用海象运算符`:=`优化交集面积计算
- 列表推导式替代部分循环结构
- 去除冗余的中间变量
3. **接口标准化**:
- 明确输入输出类型
- 参数验证(可通过添加装饰器进一步增强)
- 符合PEP8编码规范
4. **扩展性**:
- 可轻松添加预处理/后处理逻辑
- 支持自定义OCR配置
- 异常处理框架可扩展(当前版本省略了异常处理以保持代码简洁)
**注意事项**:
1. 确保安装所有依赖包:`paddleocr`, `shapely`
2. 模板JSON需要包含以下字段结构:
```json
{
"data": [
{
"key": "字段名",
"offset": {"x_offset": 0, "y_offset": 0},
// 其他字段...
}
]
}
```
3. 输出数据将包含以下增强字段:
- `key_coordinates`: 键坐标
- `value`: 识别值
- `value_coordinates`: 值坐标
- `direction`: 方向关系
- `afterCorrection`: 修正字段占位
这是对单个文件进行预测“import os import json import torch from PIL import Image from torchvision import transforms import matplotlib.pyplot as plt from model import convnext_tiny as create_model def main(): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(f"using {device} device.") num_classes = 5 img_size = 224 data_transform = transforms.Compose( [transforms.Resize(int(img_size * 1.14)), transforms.CenterCrop(img_size), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) # load image img_path = "../tulip.jpg" assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) img = Image.open(img_path) plt.imshow(img) # [N, C, H, W] img = data_transform(img) # expand batch dimension img = torch.unsqueeze(img, dim=0) # read class_indict json_path = './class_indices.json' assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path) with open(json_path, "r") as f: class_indict = json.load(f) # create model model = create_model(num_classes=num_classes).to(device) # load model weights model_weight_path = "./weights/best_model.pth" model.load_state_dict(torch.load(model_weight_path, map_location=device)) model.eval() with torch.no_grad(): # predict class output = torch.squeeze(model(img.to(device))).cpu() predict = torch.softmax(output, dim=0) predict_cla = torch.argmax(predict).numpy() print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy()) plt.title(print_res) for i in range(len(predict)): print("class: {:10} prob: {:.3}".format(class_indict[str(i)], predict[i].numpy())) plt.show() if __name__ == '__main__': main()”,改为对指定文件夹下的左右文件进行预测,并绘制混淆矩阵
好的,你的问题是如何将上述代码修改为对指定文件夹下的所有文件进行预测,并绘制混淆矩阵。
要实现这个功能,你需要进行以下几个步骤:
1. 导入必要的包:
```python
import os
import json
import torch
from PIL import Image
from torchvision import transforms
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
from model import convnext_tiny as create_model
```
2. 定义函数 `predict_folder`,用于对指定文件夹下的所有文件进行预测:
```python
def predict_folder(folder_path, model_weight_path, json_path, img_size, num_classes, device):
data_transform = transforms.Compose([
transforms.Resize(int(img_size * 1.14)),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# read class_indict
with open(json_path, "r") as f:
class_indict = json.load(f)
# create model
model = create_model(num_classes=num_classes).to(device)
# load model weights
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
y_true = []
y_pred = []
for root, dirs, files in os.walk(folder_path):
for file in files:
img_path = os.path.join(root, file)
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
img = Image.open(img_path)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
with torch.no_grad():
# predict class
output = torch.squeeze(model(img.to(device))).cpu()
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
y_true.append(class_indict[os.path.basename(root)])
y_pred.append(predict_cla)
return y_true, y_pred
```
这个函数接受五个参数:
- `folder_path`:要预测的文件夹路径。
- `model_weight_path`:模型权重文件路径。
- `json_path`:类别标签文件路径。
- `img_size`:输入图片的大小。
- `num_classes`:分类器的类别数。
- `device`:设备类型。
函数会返回两个列表 `y_true` 和 `y_pred`,分别代表真实标签和预测标签。
3. 加载类别标签:
```python
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
with open(json_path, "r") as f:
class_indict = json.load(f)
```
4. 调用 `predict_folder` 函数进行预测:
```python
folder_path = './test'
assert os.path.exists(folder_path), "folder: '{}' dose not exist.".format(folder_path)
y_true, y_pred = predict_folder(folder_path, "./weights/best_model.pth", json_path, 224, 5, device)
```
这里假设要预测的文件夹路径为 `./test`,模型权重文件路径为 `./weights/best_model.pth`,输入图片大小为 224,分类器的类别数为 5。
5. 绘制混淆矩阵:
```python
cm = confusion_matrix(y_true, y_pred)
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
xticklabels=list(class_indict.values()), yticklabels=list(class_indict.values()),
title='Confusion matrix',
ylabel='True label',
xlabel='Predicted label')
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
fmt = 'd'
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(j, i, format(cm[i, j], fmt),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()
plt.show()
```
这里使用了 `sklearn.metrics` 中的 `confusion_matrix` 函数进行混淆矩阵的计算。然后使用 `matplotlib` 绘制混淆矩阵图像。
阅读全文
相关推荐
















