ROWS = 150 COLS = 150 # # ROWS = 128 # COLS = 128 CHANNELS = 3 def read_image(file_path): img = cv2.imread(file_path, cv2.IMREAD_COLOR) return cv2.resize(img, (ROWS, COLS), interpolation=cv2.INTER_CUBIC) def predict(): TEST_DIR = 'D:/final/CatVsDog-master/media/img/' result = [] # model = load_model('my_model.h5') model = load_model('D:/final/CatVsDog-master/venv/Include/VGG/model.h5') test_images = [TEST_DIR + i for i in os.listdir(TEST_DIR)] count = len(test_images) # data = np.ndarray((count, CHANNELS, ROWS, COLS), dtype=np.uint8) data = np.ndarray((count, ROWS, COLS, CHANNELS), dtype=np.uint8) # print("图片网维度:") print(data.shape) for i, image_file in enumerate(test_images): image = read_image(image_file) # print() data[i] = image # data[i] = image.T if i % 250 == 0: print('处理 {} of {}'.format(i, count)) test = data predictions = model.predict(test, verbose=0) dict = {} urls = [] for i in test_images: ss = i.split('/') url = '/' + ss[3] + '/' + ss[4] + '/' + ss[5] urls.append(url) for i in range(0, len(predictions)): if predictions[i, 0] >= 0.5: print('I am {:.2%} sure this is a Dog'.format(predictions[i][0])) dict[urls[i]] = "图片预测为:关!" else: print('I am {:.2%} sure this is a Cat'.format(1 - predictions[i][0])) dict[urls[i]] = "图片预测为:开!" plt.imshow(test[i]) # plt.imshow(test[i].T) plt.show() # time.sleep(2) # print(dict) # for key,value in dict.items(): # print(key + ':' + value) return dict if __name__ == '__main__': result = predict() for i in result: print(i)
时间: 2023-06-14 13:05:04 浏览: 52
这段代码是一个使用 VGG 模型进行图片分类的程序。代码中首先定义了图片的尺寸和通道数,然后定义了一个读取图片并进行缩放的函数 read_image。接着定义了一个预测函数 predict,其中加载了之前训练好的 VGG 模型并对测试图片进行预测。最后将预测结果存储在字典中并返回。在主函数中调用了 predict 函数并输出预测结果。
相关问题
rows, cols = img_noise.shape
这行代码是将变量 `img_noise` 的形状(shape)元组(tuple)中的两个值分别赋值给变量 `rows` 和 `cols`。在这个代码中,`img_noise` 可能是一个带有噪声的图像,`img_noise.shape` 返回一个形状元组,包含了图像的行数、列数以及通道数等信息。由于这里只对灰度图像进行处理,因此只需要取前两个值,分别赋值给 `rows` 和 `cols`。这样在后续的代码中就可以使用这两个变量来处理图像了。
roi = img1[0:rows,(cols1-cols):cols1]
这段代码使用了Python中的切片操作,用于提取图像img1中的一个感兴趣区域(Region of Interest,ROI)。具体来说,它提取了img1的第一行到最后一行,以及从(cols1-cols)列到cols1列的区域。
这样做的目的可能是为了获取img1中右侧一部分的感兴趣区域。提取后的ROI将存储在名为roi的变量中。