def predict(): TEST_DIR = 'D:/final/CatVsDog-master/media/img/' result = [] # model = load_model('my_model.h5') model = load_model('D:/final/CatVsDog-master/venv/Include/VGG/model.h5') test_images = [TEST_DIR + i for i in os.listdir(TEST_DIR)] count = len(test_images) # data = np.ndarray((count, CHANNELS, ROWS, COLS), dtype=np.uint8) data = np.ndarray((count, ROWS, COLS, CHANNELS), dtype=np.uint8) # print("图片网维度:") print(data.shape) for i, image_file in enumerate(test_images): image = read_image(image_file) # print() data[i] = image # data[i] = image.T if i % 250 == 0: print('处理 {} of {}'.format(i, count)) test = data predictions = model.predict(test, verbose=0) dict = {} urls = [] for i in test_images: ss = i.split('/') url = '/' + ss[3] + '/' + ss[4] + '/' + ss[5] urls.append(url) for i in range(0, len(predictions)): if predictions[i, 0] >= 0.5: print('I am {:.2%} sure this is a Dog'.format(predictions[i][0])) dict[urls[i]] = "图片预测为:关!" else: print('I am {:.2%} sure this is a Cat'.format(1 - predictions[i][0])) dict[urls[i]] = "图片预测为:开!" plt.imshow(test[i]) # plt.imshow(test[i].T) plt.show() # time.sleep(2) # print(dict) # for key,value in dict.items(): # print(key + ':' + value) return dict
时间: 2023-06-14 09:05:01 浏览: 126
这是一个使用VGG模型进行图片分类的函数。它首先加载了预训练的VGG模型,然后遍历指定文件夹中的所有图片,读入图片并将其转换为numpy数组。然后,它使用模型对每个图像进行预测,并将结果存储在一个字典中,其中键是图片路径,值是预测结果(是狗还是猫)。最后,它将每个图像和其预测结果显示在屏幕上,并返回一个字典,其中包含每个图像的预测结果。
相关问题
ROWS = 150 COLS = 150 # # ROWS = 128 # COLS = 128 CHANNELS = 3 def read_image(file_path): img = cv2.imread(file_path, cv2.IMREAD_COLOR) return cv2.resize(img, (ROWS, COLS), interpolation=cv2.INTER_CUBIC) def predict(): TEST_DIR = 'D:/final/CatVsDog-master/media/img/' result = [] # model = load_model('my_model.h5') model = load_model('D:/final/CatVsDog-master/venv/Include/VGG/model.h5') test_images = [TEST_DIR + i for i in os.listdir(TEST_DIR)] count = len(test_images) # data = np.ndarray((count, CHANNELS, ROWS, COLS), dtype=np.uint8) data = np.ndarray((count, ROWS, COLS, CHANNELS), dtype=np.uint8) # print("图片网维度:") print(data.shape) for i, image_file in enumerate(test_images): image = read_image(image_file) # print() data[i] = image # data[i] = image.T if i % 250 == 0: print('处理 {} of {}'.format(i, count)) test = data predictions = model.predict(test, verbose=0) dict = {} urls = [] for i in test_images: ss = i.split('/') url = '/' + ss[3] + '/' + ss[4] + '/' + ss[5] urls.append(url) for i in range(0, len(predictions)): if predictions[i, 0] >= 0.5: print('I am {:.2%} sure this is a Dog'.format(predictions[i][0])) dict[urls[i]] = "图片预测为:关!" else: print('I am {:.2%} sure this is a Cat'.format(1 - predictions[i][0])) dict[urls[i]] = "图片预测为:开!" plt.imshow(test[i]) # plt.imshow(test[i].T) plt.show() # time.sleep(2) # print(dict) # for key,value in dict.items(): # print(key + ':' + value) return dict if __name__ == '__main__': result = predict() for i in result: print(i)
这段代码是一个使用 VGG 模型进行图片分类的程序。代码中首先定义了图片的尺寸和通道数,然后定义了一个读取图片并进行缩放的函数 read_image。接着定义了一个预测函数 predict,其中加载了之前训练好的 VGG 模型并对测试图片进行预测。最后将预测结果存储在字典中并返回。在主函数中调用了 predict 函数并输出预测结果。
阅读全文