用jupyter写一个简单的图片识别界面
时间: 2023-06-16 15:03:40 浏览: 151
首先需要安装所需的库:`tensorflow`, `keras`, `numpy`, `matplotlib` 和 `ipywidgets`。
可以使用以下命令来安装这些库:
```
!pip install tensorflow
!pip install keras
!pip install numpy
!pip install matplotlib
!pip install ipywidgets
```
接下来,我们可以创建一个Jupyter Notebook,并按以下步骤进行操作:
1. 导入必要的库
```python
import tensorflow as tf
from keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import display
```
2. 加载模型
```python
model = tf.keras.models.load_model('model.h5')
```
此处的 `model.h5` 是我们训练好的模型文件。如果你没有模型文件,可以使用以下代码来训练一个简单的模型:
```python
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28,28,1)))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(Flatten())
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# load data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1) / 255.0
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1) / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
# train model
model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=5, batch_size=64)
# save model
model.save('model.h5')
```
3. 创建GUI界面
```python
# create GUI
image_upload = widgets.FileUpload()
output_image = widgets.Output()
output_label = widgets.Output()
button_predict = widgets.Button(description='Predict')
def on_button_predict_clicked(b):
# get uploaded image
image_string = image_upload.data[-1]
image_bytes = image_string['content']
img = image.img_to_array(image.load_img(io.BytesIO(image_bytes), target_size=(28, 28), color_mode='grayscale')) / 255.0
# predict label
with output_label:
output_label.clear_output()
predictions = model.predict(np.array([img]))
label = np.argmax(predictions)
plt.imshow(img.reshape(28,28), cmap='gray')
plt.title('Predicted Label: {}'.format(label))
plt.axis('off')
plt.show()
button_predict.on_click(on_button_predict_clicked)
display(image_upload, button_predict, output_image, output_label)
```
这将创建一个界面,其中包含一个文件上传按钮和一个预测按钮。当用户上传一张图片并点击预测按钮时,该图片将被加载并传递给模型进行预测,预测结果将显示在界面上。
阅读全文