用jupyter的tensorflow做一个准确度高的数字图像识别自己的数字图片代码
时间: 2024-02-21 13:57:01 浏览: 63
好的,下面是一个基于TensorFlow和CNN的数字图像识别代码,可以在Jupyter Notebook中执行。这是一个示例代码,需要你根据自己的数据集进行修改和调整。
```python
# 导入必要的库
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
# 设置参数
batch_size = 32
img_height = 28
img_width = 28
# 加载数据集
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
'data/train',
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
'data/train',
validation_split=0.2,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
class_names = train_ds.class_names
num_classes = len(class_names)
# 数据预处理
data_augmentation = keras.Sequential(
[
layers.experimental.preprocessing.RandomFlip("horizontal"),
layers.experimental.preprocessing.RandomRotation(0.1),
layers.experimental.preprocessing.RandomZoom(0.1),
]
)
# 构建模型
model = keras.Sequential([
data_augmentation,
layers.experimental.preprocessing.Rescaling(1./255),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(128, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(num_classes)
])
# 编译模型
model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 训练模型
epochs=10
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs
)
# 评估模型
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
'data/test',
image_size=(img_height, img_width),
batch_size=batch_size)
model.evaluate(test_ds)
# 预测数字
def predict_digit(image_path):
img = Image.open(image_path)
img = img.resize((img_height, img_width))
img = np.array(img)
img = img / 255.0
img = np.expand_dims(img, axis=0)
prediction = model.predict(img)
return class_names[np.argmax(prediction)]
# 测试预测函数
print(predict_digit('data/test/9/1.png')) # 预测结果应该是 "9"
```
在执行代码之前,你需要将你的手写数字图片数据集放在以下目录结构中(假设你的数据集名称为`my_digits`):
```
my_digits/
├── train/
│ ├── 0/
│ ├── 1/
│ ├── 2/
│ ├── ...
├── test/
│ ├── 0/
│ ├── 1/
│ ├── 2/
│ ├── ...
```
其中,`train`目录包含用于训练模型的手写数字图片,`test`目录包含用于评估模型和预测数字的手写数字图片。每个数字应该有一个独立的子目录,子目录名为数字的标签(例如`0`,`1`,`2`等等),每个子目录应该包含对应的手写数字图片。
你需要修改代码中的`train_ds`、`val_ds`和`test_ds`路径,以及`predict_digit()`函数中的图片路径,来适应你自己的数据集。
阅读全文