def plot_image(i, predictions_array, true_label, img): predictions_array, true_label, img = predictions_array, true_label[i], img[i] plt.grid(False) plt.xticks([]) plt.yticks([]) plt.imshow(img, cmap=plt.cm.binary) predicted_label = np.argmax(predictions_array) if predicted_label == true_label: color = 'blue' else: color = 'red' plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label], 100 * np.max(predictions_array), class_names[true_label]), color=color) def plot_value_array(i, predictions_array, true_label): predictions_array, true_label = predictions_array, true_label[i] plt.grid(False) plt.xticks(range(10)) plt.yticks([]) thisplot = plt.bar(range(10), predictions_array, color="#777777") plt.ylim([0, 1]) predicted_label = np.argmax(predictions_array) thisplot[predicted_label].set_color('red') thisplot[true_label].set_color('blue') print("验证预测结果:") i = 12 plt.figure(figsize=(6, 3)) plt.subplot(1, 2, 1) plot_image(i, predictions[i], test_labels, test_images) plt.subplot(1, 2, 2) plot_value_array(i, predictions[i], test_labels) plt.show() num_rows = 5 num_cols = 3 num_images = num_rows * num_cols plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows)) for i in range(num_images): plt.subplot(num_rows, 2 * num_cols, 2 * i + 1) plot_image(i, predictions[i], test_labels, test_images) plt.subplot(num_rows, 2 * num_cols, 2 * i + 2) plot_value_array(i, predictions[i], test_labels) plt.tight_layout() plt.show() # 使用训练好的模型对单个图像进行预测 img = test_images[1] print(img.shape) # tf.keras 模型经过了优化,可同时对一个批或一组样本进行预测 img = (np.expand_dims(img, 0)) print(img.shape) # 增加相应标签 predictions_single = probability_model.predict(img) print(predictions_single) plot_value_array(1, predictions_single[0], test_labels) _ = plt.xticks(range(10), class_names, rotation=45)
时间: 2024-03-28 21:39:33 浏览: 98
这段代码的作用是用来展示模型对测试数据集的预测结果。首先定义了两个函数plot_image和plot_value_array,用于可视化展示预测结果。然后调用这两个函数来展示第12张测试图片的预测结果和对应的概率分布情况。接着定义了num_rows、num_cols和num_images三个变量,分别表示展示预测结果的行数、列数和总图片数。最后使用for循环展示num_images张测试图片的预测结果和对应的概率分布情况。其中,使用了tight_layout函数来调整子图之间的间距,使得展示效果更加美观。最后使用训练好的模型对单个图片进行了预测,并展示了预测结果和对应的概率分布情况。
相关问题
def plot_image(i, predictions_array, true_label, img): predictions_array, true_label, img = predictions_array, true_label[i], img[i] plt.grid(False) plt.xticks([]) plt.yticks([])
A:这段代码是一个函数,用于显示图像及其对应的标签和预测结果。参数包括:图像的索引 i,预测数组 predictions_array,真实标签 true_label 和图像 img。函数会将 predictions_array、true_label 和 img 中与 i 相对应的值取出,然后使用 matplotlib 中的函数显示图像,并取消图像的网格和坐标轴标签。
plt.rcParams['font.sans-serif']=['Arial Unicode MS'] #显示中文字体,这段代码我可是找了好长时间 plt.rcParams['axes.unicode_minus']=False def plot_image(i, predictions_array, true_labels, images): predictions_array, true_label, img = predictions_array[i], true_labels[i], images[i] plt.grid(False) plt.xticks([]) plt.yticks([]) plt.grid(False) # 显示照片,以cm 为单位。 plt.imshow(images[i], cmap=plt.cm.binary) # 预测的图片是否正确,黑色底表示预测正确,红色底表示预测失败 predicted_label = np.argmax(prediction[i]) true_label = y_test[i][0] if predicted_label == true_label: color = 'black' else: color = 'red' # plt.xlabel("{} ({})".format(class_names[predicted_label], # class_names[true_label]), # color=color) plt.xlabel("预测{:2.0f}%是{}(实际{})".format(100*np.max(predictions_array), class_names[predicted_label], class_names[true_label]), color=color) def plot_value_array(i, predictions_array, true_label): predictions_array, true_label = predictions_array[i], true_label[i][0] plt.grid(False) plt.xticks(range(10)) plt.yticks([]) thisplot = plt.bar(range(10), predictions_array, color="#777777") plt.ylim([0, 1]) predicted_label = np.argmax(predictions_array) thisplot[predicted_label].set_color('red') thisplot[true_label].set_color('blue') num_rows = 5 num_cols = 3 num_images = num_rows*num_cols plt.figure(figsize=(2*2*num_cols, 2*num_rows)) for i in range(num_images): plt.subplot(num_rows, 2*num_cols, 2*i+1) plot_image(i, prediction, y_test, x_test) plt.subplot(num_rows, 2*num_cols, 2*i+2) plot_value_array(i, prediction, y_test)解释如下代码
这段代码是一个用于绘制图像和数值数组的函数。首先,通过设置字体和解决中文显示问题的代码来准备绘图环境。然后,定义了两个函数:`plot_image()` 和 `plot_value_array()`。`plot_image()`函数用于绘制图像,并根据预测结果的正确性将其标记为黑色或红色。`plot_value_array()`函数用于绘制数值数组,并将预测结果的最大值标记为红色,将真实标签标记为蓝色。接下来,通过循环调用这两个函数来绘制一组图像和相应的数值数组。最终,使用`plt.figure()`函数定义了一个绘图窗口,并在其中调用这两个函数来绘制图像和数值数组。
阅读全文
相关推荐

















