加载已径训练好的模型(使用.pb格式),实现手写数字识别界面的设计,多一些功能的界面比如可以识别单张图片单个数字和单张图片多个数字的识别,显示出完整的代码
时间: 2023-12-11 22:04:59 浏览: 200
基于matlab平台实现的数字识别,增加的GUI界面可以直接手写数字进行识别.zip
首先,我们需要加载已经训练好的模型文件,这里我们选择使用 TensorFlow 的模型加载方式。在加载模型之前,我们需要先定义输入和输出节点的名称,这个可以在训练模型时指定。
```python
import tensorflow as tf
# 定义输入和输出节点的名称
input_tensor_name = 'input:0'
output_tensor_name = 'output:0'
# 加载模型文件
model_path = 'model.pb'
with tf.gfile.GFile(model_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 导入模型
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='')
# 获取输入和输出节点
input_tensor = graph.get_tensor_by_name(input_tensor_name)
output_tensor = graph.get_tensor_by_name(output_tensor_name)
```
接下来,我们设计一个简单的图形界面,使用 tkinter 库实现。在界面中,我们可以选择“手写数字输入”或者“图片输入”两种方式进行识别。对于“手写数字输入”,我们提供了一个小画板,可以让用户手写数字。对于“图片输入”,我们可以从本地文件中选择一张图片进行识别。
```python
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
# 创建主窗口
root = tk.Tk()
root.title('手写数字识别')
# 创建画板
canvas = tk.Canvas(root, width=200, height=200, bg='white')
canvas.pack()
# 创建工具栏
toolbar = tk.Frame(root)
toolbar.pack(side=tk.TOP)
# 创建“手写数字输入”按钮
button_draw = tk.Button(toolbar, text='手写数字输入')
button_draw.pack(side=tk.LEFT)
# 创建“图片输入”按钮
button_image = tk.Button(toolbar, text='图片输入')
button_image.pack(side=tk.LEFT)
# 创建状态栏
statusbar = tk.Label(root, text='请在画板中手写数字', bd=1, relief=tk.SUNKEN, anchor=tk.W)
statusbar.pack(side=tk.BOTTOM, fill=tk.X)
# 显示主窗口
root.mainloop()
```
接下来,我们需要实现“手写数字输入”和“图片输入”两种方式的识别功能。对于“手写数字输入”,我们可以使用鼠标事件来实现用户在画板上手写数字的过程,并将得到的图像数据传入模型进行识别。对于“图片输入”,我们可以使用文件对话框来让用户选择一张图片文件,并将其转换成模型可以处理的格式后进行识别。
```python
import numpy as np
# 定义手写数字输入的事件处理函数
def draw(event):
x, y = event.x, event.y
r = 8
canvas.create_oval(x-r, y-r, x+r, y+r, fill='black')
image = Image.new('L', (200, 200), 255)
draw = ImageDraw.Draw(image)
draw.ellipse((x-r, y-r, x+r, y+r), fill=0)
data = np.array(image).reshape(-1, 28, 28, 1)
result = session.run(output_tensor, feed_dict={input_tensor: data})
statusbar.config(text='识别结果为:%d' % result[0])
# 定义图片输入的事件处理函数
def choose_image():
# 打开文件对话框选择图片文件
file_path = filedialog.askopenfilename(filetypes=[('Image Files', '*.png;*.jpg;*.jpeg')])
if file_path:
# 加载图片并进行预处理
image = Image.open(file_path).convert('L').resize((28, 28), Image.ANTIALIAS)
data = np.array(image).reshape(-1, 28, 28, 1)
# 进行识别并显示结果
result = session.run(output_tensor, feed_dict={input_tensor: data})
statusbar.config(text='识别结果为:%d' % result[0])
# 显示图片
image = ImageTk.PhotoImage(image)
canvas.create_image(100, 100, image=image)
canvas.image = image
# 绑定鼠标事件
canvas.bind('<B1-Motion>', draw)
# 绑定按钮事件
button_draw.config(command=lambda: statusbar.config(text='请在画板中手写数字'))
button_image.config(command=choose_image)
```
最后,我们需要在界面中添加更多的功能,比如支持单张图片多个数字的识别。具体实现可以根据需要进行扩展。
完整代码如下:
```python
import tensorflow as tf
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageDraw, ImageTk
import numpy as np
# 定义输入和输出节点的名称
input_tensor_name = 'input:0'
output_tensor_name = 'output:0'
# 加载模型文件
model_path = 'model.pb'
with tf.gfile.GFile(model_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 导入模型
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='')
# 获取输入和输出节点
input_tensor = graph.get_tensor_by_name(input_tensor_name)
output_tensor = graph.get_tensor_by_name(output_tensor_name)
# 创建主窗口
root = tk.Tk()
root.title('手写数字识别')
# 创建画板
canvas = tk.Canvas(root, width=200, height=200, bg='white')
canvas.pack()
# 创建工具栏
toolbar = tk.Frame(root)
toolbar.pack(side=tk.TOP)
# 创建“手写数字输入”按钮
button_draw = tk.Button(toolbar, text='手写数字输入')
button_draw.pack(side=tk.LEFT)
# 创建“图片输入”按钮
button_image = tk.Button(toolbar, text='图片输入')
button_image.pack(side=tk.LEFT)
# 创建状态栏
statusbar = tk.Label(root, text='请在画板中手写数字', bd=1, relief=tk.SUNKEN, anchor=tk.W)
statusbar.pack(side=tk.BOTTOM, fill=tk.X)
# 定义手写数字输入的事件处理函数
def draw(event):
x, y = event.x, event.y
r = 8
canvas.create_oval(x-r, y-r, x+r, y+r, fill='black')
image = Image.new('L', (200, 200), 255)
draw = ImageDraw.Draw(image)
draw.ellipse((x-r, y-r, x+r, y+r), fill=0)
data = np.array(image).reshape(-1, 28, 28, 1)
result = session.run(output_tensor, feed_dict={input_tensor: data})
statusbar.config(text='识别结果为:%d' % result[0])
# 定义图片输入的事件处理函数
def choose_image():
# 打开文件对话框选择图片文件
file_path = filedialog.askopenfilename(filetypes=[('Image Files', '*.png;*.jpg;*.jpeg')])
if file_path:
# 加载图片并进行预处理
image = Image.open(file_path).convert('L').resize((28, 28), Image.ANTIALIAS)
data = np.array(image).reshape(-1, 28, 28, 1)
# 进行识别并显示结果
result = session.run(output_tensor, feed_dict={input_tensor: data})
statusbar.config(text='识别结果为:%d' % result[0])
# 显示图片
image = ImageTk.PhotoImage(image)
canvas.create_image(100, 100, image=image)
canvas.image = image
# 绑定鼠标事件
canvas.bind('<B1-Motion>', draw)
# 绑定按钮事件
button_draw.config(command=lambda: statusbar.config(text='请在画板中手写数字'))
button_image.config(command=choose_image)
# 创建会话
session = tf.Session(graph=graph)
# 运行主窗口
root.mainloop()
# 关闭会话
session.close()
```
阅读全文