给我写一份基于cnn卷积神经网络的水果识别python文件
时间: 2023-12-27 07:04:53 浏览: 105
好的,下面是一个简单的基于CNN的水果识别Python文件示例:
```python
# 导入必要的库
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, MaxPooling2D, Dropout
import cv2
import os
# 定义图像大小和类别数
IMG_HEIGHT = 100
IMG_WIDTH = 100
NUM_CLASSES = 5
# 加载数据集
def load_data():
data = []
labels = []
for fruit_class in os.listdir("fruits/"):
for image_name in os.listdir("fruits/" + fruit_class):
image = cv2.imread("fruits/" + fruit_class + "/" + image_name)
image = cv2.resize(image, (IMG_HEIGHT, IMG_WIDTH))
data.append(image)
labels.append(fruit_class)
data = np.array(data)
labels = pd.get_dummies(labels).values
return data, labels
# 数据预处理
def preprocess_data(data, labels):
data = data.astype('float32') / 255.0
x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)
y_train = to_categorical(y_train, NUM_CLASSES)
y_test = to_categorical(y_test, NUM_CLASSES)
return x_train, x_test, y_train, y_test
# 定义CNN模型
def create_model():
model = Sequential([
Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
MaxPooling2D(pool_size=(2, 2)),
Conv2D(64, kernel_size=(3, 3), activation='relu'),
MaxPooling2D(pool_size=(2, 2)),
Flatten(),
Dense(128, activation='relu'),
Dropout(0.5),
Dense(NUM_CLASSES, activation='softmax')
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
return model
# 加载数据集
data, labels = load_data()
# 数据预处理
x_train, x_test, y_train, y_test = preprocess_data(data, labels)
# 定义模型并进行训练
model = create_model()
model.fit(x_train, y_train, batch_size=32, epochs=10, validation_data=(x_test, y_test))
# 保存模型
model.save("fruit_classification_cnn.h5")
```
这份代码实现了基于CNN的水果分类任务。具体来说,它首先加载数据集,然后对图像进行预处理,接着定义CNN模型并进行训练。最后,将训练好的模型保存到文件中。您可以根据自己的需求进行修改和扩展。
阅读全文