OCR技术,ctpn+crnn代码实现
时间: 2023-10-15 20:04:26 浏览: 117
OCR技术是一种能够将图像中的文本内容转化为可编辑文本的技术,其中ctpn和crnn是OCR技术中的两个重要组成部分。
ctpn(Connectionist Text Proposal Network)是一种基于深度学习的文本检测算法,其主要任务是检测图像中的文本行和单个字符,并将其转换为一组矩形边界框。这些边界框可以用于后续的文本识别操作。
crnn(Convolutional Recurrent Neural Network)是一种基于深度学习的文本识别算法,其主要任务是根据文本检测阶段生成的文本行或单个字符图像,识别其中的文本内容。crnn算法通常由卷积神经网络(CNN)和循环神经网络(RNN)两个部分组成,其中CNN用于提取图像特征,RNN用于对特征序列进行建模。
以下是一个基于ctpn和crnn的OCR代码实现示例(Python):
```python
import cv2
import numpy as np
import tensorflow as tf
# 加载ctpn模型
ctpn_model = cv2.dnn.readNet('ctpn.pb')
# 加载crnn模型
crnn_model = tf.keras.models.load_model('crnn.h5')
# 定义字符集
charset = '0123456789abcdefghijklmnopqrstuvwxyz'
# 定义字符到索引的映射表
char_to_index = {char: index for index, char in enumerate(charset)}
# 定义CTPN参数
ctpn_params = {
'model': 'ctpn',
'scale': 600,
'max_scale': 1200,
'text_proposals': 2000,
'min_size': 16,
'line_min_score': 0.9,
'text_proposal_min_score': 0.7,
'text_proposal_nms_threshold': 0.3,
'min_num_proposals': 2,
'max_num_proposals': 10
}
# 定义CRNN参数
crnn_params = {
'model': 'crnn',
'img_w': 100,
'img_h': 32,
'num_classes': len(charset),
'rnn_units': 128,
'rnn_dropout': 0.25,
'rnn_recurrent_dropout': 0.25,
'rnn_activation': 'relu',
'rnn_type': 'lstm',
'rnn_direction': 'bidirectional',
'rnn_merge_mode': 'concat',
'cnn_filters': 32,
'cnn_kernel_size': (3, 3),
'cnn_activation': 'relu',
'cnn_pool_size': (2, 2)
}
# 定义文本检测函数
def detect_text(image):
# 将图像转换为灰度图
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# 缩放图像
scale = ctpn_params['scale']
max_scale = ctpn_params['max_scale']
if np.max(gray) > 1:
gray = gray / 255
rows, cols = gray.shape
if rows > max_scale:
scale = max_scale / rows
gray = cv2.resize(gray, (0, 0), fx=scale, fy=scale)
rows, cols = gray.shape
elif rows < scale:
scale = scale / rows
gray = cv2.resize(gray, (0, 0), fx=scale, fy=scale)
rows, cols = gray.shape
# 文本检测
ctpn_model.setInput(cv2.dnn.blobFromImage(gray))
output = ctpn_model.forward()
boxes = []
for i in range(output.shape[2]):
score = output[0, 0, i, 2]
if score > ctpn_params['text_proposal_min_score']:
x1 = int(output[0, 0, i, 3] * cols / scale)
y1 = int(output[0, 0, i, 4] * rows / scale)
x2 = int(output[0, 0, i, 5] * cols / scale)
y2 = int(output[0, 0, i, 6] * rows / scale)
boxes.append([x1, y1, x2, y2])
# 合并重叠的文本框
boxes = cv2.dnn.NMSBoxes(boxes, output[:, :, :, 2], ctpn_params['text_proposal_min_score'], ctpn_params['text_proposal_nms_threshold'])
# 提取文本行图像
lines = []
for i in boxes:
i = i[0]
x1, y1, x2, y2 = boxes[i]
line = gray[y1:y2, x1:x2]
lines.append(line)
return lines
# 定义文本识别函数
def recognize_text(image):
# 缩放图像
img_w, img_h = crnn_params['img_w'], crnn_params['img_h']
image = cv2.resize(image, (img_w, img_h))
# 归一化图像
if np.max(image) > 1:
image = image / 255
# 转换图像格式
image = image.transpose([1, 0, 2])
image = np.expand_dims(image, axis=0)
# 预测文本
y_pred = crnn_model.predict(image)
y_pred = np.argmax(y_pred, axis=2)[0]
# 将预测结果转换为文本
text = ''
for i in y_pred:
if i != len(charset) - 1 and (not (len(text) > 0 and text[-1] == charset[i])):
text += charset[i]
return text
# 读取图像
image = cv2.imread('test.png')
# 检测文本行
lines = detect_text(image)
# 识别文本
texts = []
for line in lines:
text = recognize_text(line)
texts.append(text)
# 输出识别结果
print(texts)
```
上述代码实现了一个基于ctpn和crnn的OCR系统,其中ctpn用于检测文本行,crnn用于识别文本内容。在使用代码时,需要将ctpn和crnn的模型文件替换为自己训练的模型文件,并根据实际情况调整参数。