使用keras简单的猫狗识别
时间: 2023-11-01 20:07:11 浏览: 32
好的,以下是使用Keras进行猫狗识别的简单代码。
首先,我们需要准备数据集。可以从Kaggle上下载猫狗分类的数据集,地址为 https://www.kaggle.com/c/dogs-vs-cats/data。
下载完后,我们需要解压缩数据集,并将训练集和测试集分别放到两个文件夹中,用于训练和测试模型。
接下来,我们使用Keras搭建模型。这里我们使用卷积神经网络(CNN)进行训练和测试。代码如下:
```python
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# 定义模型
model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(150, 150, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 编译模型
model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
# 训练模型
from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
'train',
target_size=(150, 150),
batch_size=20,
class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
'test',
target_size=(150, 150),
batch_size=20,
class_mode='binary')
model.fit_generator(
train_generator,
steps_per_epoch=100,
epochs=30,
validation_data=validation_generator,
validation_steps=50)
# 保存模型
model.save('cat_dog_cnn.h5')
```
上面的代码中,我们使用了四个卷积层和四个池化层,最后使用一个全连接层和一个输出层。我们使用了ImageDataGenerator对图像进行预处理和增强,然后使用fit_generator方法对模型进行训练。最后,我们将训练好的模型保存到cat_dog_cnn.h5文件中。
接下来,我们可以使用训练好的模型进行猫狗识别。代码如下:
```python
from keras.models import load_model
from keras.preprocessing.image import load_img, img_to_array
import numpy as np
# 加载模型
model = load_model('cat_dog_cnn.h5')
# 加载图像
img = load_img('test/cat.1.jpg', target_size=(150, 150))
x = img_to_array(img)
x = np.expand_dims(x, axis=0)
x = x / 255.0
# 预测图像
prediction = model.predict(x)
if prediction[0][0] < 0.5:
print('猫')
else:
print('狗')
```
上面的代码中,我们首先加载训练好的模型,然后加载测试图像,并进行预处理。最后使用predict方法对图像进行预测,输出预测结果。