CNN算法如何实现猫狗分类(Python代码实现)
时间: 2024-05-04 19:21:25 浏览: 260
下面是使用CNN实现猫狗分类的Python代码:
```python
# 导入必要的库
import os
import cv2
import numpy as np
from random import shuffle
from tqdm import tqdm
# Directory holding the training images.
TRAIN_DIR = 'train'
# Directory holding the test images.
TEST_DIR = 'test'
# Images are resized to IMG_SIZE x IMG_SIZE pixels.
IMG_SIZE = 50
# Learning rate handed to the optimizer.
LR = 1e-3
# File name under which the trained model is saved; it embeds the learning
# rate and a short architecture tag.
MODEL_NAME = 'dogsvscats-{}-{}.model'.format(
    LR,
    '2conv-basic',
)
def label_img(img):
    """Map a training image file name to its one-hot class label.

    File names are expected to look like ``<class>.<index>.<ext>`` (for
    example ``cat.0.jpg``); the class is the third dot-separated field
    counted from the end of the name.

    Args:
        img: File name (not the full path) of a training image.

    Returns:
        ``[1, 0]`` for a cat, ``[0, 1]`` for a dog, or ``None`` when the
        name does not follow the expected pattern or names another class.
    """
    parts = img.split('.')
    # Names with fewer than three fields (e.g. ".DS_Store") used to raise
    # IndexError on parts[-3]; report them as unlabelled instead.
    if len(parts) < 3:
        return None
    # Built per call so every caller receives its own fresh list.
    one_hot = {'cat': [1, 0], 'dog': [0, 1]}
    return one_hot.get(parts[-3])
def create_train_data():
    """Build the shuffled training set and cache it in ``train_data.npy``.

    Every file in ``TRAIN_DIR`` is read as grayscale, resized to
    ``IMG_SIZE`` x ``IMG_SIZE`` and paired with the one-hot label returned
    by ``label_img``. Files that carry no recognised label, or that OpenCV
    cannot decode, are skipped.

    Returns:
        The shuffled list of ``[image_array, label_array]`` pairs.
    """
    training_data = []
    for file_name in tqdm(os.listdir(TRAIN_DIR)):
        # Stray files (e.g. ".DS_Store") cannot carry a
        # "<class>.<index>.<ext>" label; skip them before asking label_img.
        if file_name.count('.') < 2:
            continue
        label = label_img(file_name)
        if label is None:
            continue
        path = os.path.join(TRAIN_DIR, file_name)
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals an unreadable file by returning None, which
        # would make cv2.resize fail.
        if image is None:
            continue
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
        training_data.append([np.array(image), np.array(label)])
    shuffle(training_data)

    # The pairs are ragged (image vs. label shapes differ), so np.save cannot
    # turn the plain list into an array on recent NumPy releases. Pack the
    # pairs into an explicit (N, 2) object array instead.
    packed = np.empty((len(training_data), 2), dtype=object)
    for row, (image, label) in enumerate(training_data):
        packed[row, 0] = image
        packed[row, 1] = label
    np.save('train_data.npy', packed)
    return training_data
def process_test_data():
    """Build the test set and cache it in ``test_data.npy``.

    Every file in ``TEST_DIR`` is read as grayscale, resized to
    ``IMG_SIZE`` x ``IMG_SIZE`` and paired with the part of its file name
    before the first dot. Files that OpenCV cannot decode are skipped.

    Returns:
        The list of ``[image_array, img_num]`` pairs in directory order.
    """
    testing_data = []
    for file_name in tqdm(os.listdir(TEST_DIR)):
        path = os.path.join(TEST_DIR, file_name)
        img_num = file_name.split('.')[0]
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals an unreadable file by returning None, which
        # would make cv2.resize fail.
        if image is None:
            continue
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
        testing_data.append([np.array(image), img_num])

    # Each pair mixes an image array with a string, so np.save cannot turn
    # the plain list into an array on recent NumPy releases. Pack the pairs
    # into an explicit (N, 2) object array instead.
    packed = np.empty((len(testing_data), 2), dtype=object)
    for row, (image, img_num) in enumerate(testing_data):
        packed[row, 0] = image
        packed[row, 1] = img_num
    np.save('test_data.npy', packed)
    return testing_data
def cnn_model():
    """Build and compile the two-convolution classifier.

    The network takes ``IMG_SIZE`` x ``IMG_SIZE`` single-channel images and
    ends in a two-unit softmax layer. It is compiled with categorical
    cross-entropy, an Adam optimizer using ``LR`` and the accuracy metric.

    Returns:
        The compiled model.
    """
    # NOTE(review): Sequential, Conv2D, Activation, MaxPooling2D, Flatten,
    # Dense and Adam are not imported in the visible part of this file;
    # presumably they come from Keras - confirm and add the imports.
    model = Sequential()

    # First convolution stage.
    model.add(Conv2D(32, (3, 3), input_shape=(IMG_SIZE, IMG_SIZE, 1)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    # Second convolution stage.
    model.add(Conv2D(64, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    # Dense head ending in the two-class softmax.
    model.add(Flatten())
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dense(2))
    model.add(Activation('softmax'))

    # `lr` is the deprecated spelling of this argument and is rejected by
    # current Keras releases; `learning_rate` is the supported name.
    model.compile(loss='categorical_crossentropy',
                  optimizer=Adam(learning_rate=LR),
                  metrics=['accuracy'])
    return model
def train_model():
    """Train the CNN on the cached training data and save it.

    Loads ``train_data.npy``, holds out the last 500 samples for
    validation, fits the model from ``cnn_model`` for 10 epochs with a
    batch size of 32 and saves it under ``MODEL_NAME``.

    Returns:
        The trained model.

    Raises:
        ValueError: If the cache holds 500 samples or fewer, which would
            leave nothing to train on after the validation split.
    """
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load cache
    # files that this script wrote itself.
    train_data = np.load('train_data.npy', allow_pickle=True)

    holdout = 500
    if len(train_data) <= holdout:
        raise ValueError(
            'train_data.npy holds {} samples; more than {} are needed to '
            'leave a non-empty training split'.format(len(train_data), holdout)
        )
    train = train_data[:-holdout]
    test = train_data[-holdout:]

    X = np.array([pair[0] for pair in train]).reshape(-1, IMG_SIZE, IMG_SIZE, 1)
    # Labels must be a single array: Keras does not treat a Python list of
    # per-sample arrays as one target tensor.
    y = np.array([pair[1] for pair in train])
    test_x = np.array([pair[0] for pair in test]).reshape(-1, IMG_SIZE, IMG_SIZE, 1)
    test_y = np.array([pair[1] for pair in test])

    model = cnn_model()
    model.fit(X, y, batch_size=32, epochs=10, validation_data=(test_x, test_y))
    # NOTE(review): newer Keras releases may require a ".keras" or ".h5"
    # suffix for save(); MODEL_NAME ends in ".model" - confirm against the
    # installed version.
    model.save(MODEL_NAME)
    return model
def predict(model):
    """Show the first 12 cached test images with their predicted class.

    Loads ``test_data.npy``, runs each of the first 12 images through
    ``model`` and draws them in a 3 x 4 grid titled ``'Dog'`` (class index
    1) or ``'Cat'`` (any other index).

    Args:
        model: Trained model exposing ``predict``.
    """
    # NOTE(review): `plt` is not imported in the visible part of this file;
    # presumably matplotlib.pyplot - confirm and add the import.
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load cache
    # files that this script wrote itself.
    test_data = np.load('test_data.npy', allow_pickle=True)
    fig = plt.figure()
    for num, data in enumerate(test_data[:12]):
        img_data = data[0]
        ax = fig.add_subplot(3, 4, num + 1)

        # The model expects a leading batch axis, so feed a batch of one
        # image rather than a list wrapping a single (H, W, 1) array.
        batch = img_data.reshape(1, IMG_SIZE, IMG_SIZE, 1)
        model_out = model.predict(batch)[0]
        str_label = 'Dog' if np.argmax(model_out) == 1 else 'Cat'

        ax.imshow(img_data, cmap='gray')
        ax.set_title(str_label)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
def main():
    """Run the pipeline: build missing caches, train, show predictions."""
    # Only (re)build a cache when its file is missing. The arrays are not
    # loaded here because train_model() and predict() read the .npy files
    # themselves.
    if not os.path.exists('train_data.npy'):
        create_train_data()
    if not os.path.exists('test_data.npy'):
        process_test_data()
    model = train_model()
    predict(model)


# Guard the entry point so importing this module does not start training.
if __name__ == '__main__':
    main()
```
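这段代码实现了一个简单的CNN模型,用于对猫和狗的图片进行分类。注意:代码中使用的 Keras 组件(Sequential、Conv2D、Activation、MaxPooling2D、Flatten、Dense、Adam)以及 matplotlib.pyplot(plt)并未在示例中导入,运行前需要先补充相应的 import 语句。你可以根据自己的需要进行修改和扩展。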
阅读全文