# 创建插值数据点 interp_x = np.linspace(0, len(train_losses) - 1, 1000) # 训练损失曲线 interp_train_losses = np.interp(interp_x, range(len(train_losses)), train_losses) plt.plot(interp_x, interp_train_losses, label='Train Loss') # 测试损失曲线 interp_test_losses = np.interp(interp_x, range(len(test_losses)), test_losses) plt.plot(interp_x, interp_test_losses, label='Test Loss') # 训练准确率曲线 interp_train_accs = np.interp(interp_x, range(len(train_accs)), train_accs) plt.plot(interp_x, interp_train_accs, label='Train Acc') # 测试准确率曲线 interp_test_accs = np.interp(interp_x, range(len(test_accs)), test_accs) plt.plot(interp_x, interp_test_accs, label='Test Acc') # 设置y轴范围为0~1 plt.ylim([0, 1]) # 显示图例 plt.legend() # 显示图像 plt.show() 怎样使训练集的两条曲线在一张图上,测试集的两条曲线在另一张图上?
时间: 2023-12-11 11:03:51 浏览: 151
要将训练集的两条曲线绘制在同一张图上,测试集的两条曲线绘制在另一张图上,可以使用 `matplotlib.pyplot.subplots` 函数创建两个子图,并在每个子图上绘制相应的曲线。
以下是一个示例代码,演示如何绘制训练集和测试集的损失函数和准确率曲线:
```
import numpy as np
import matplotlib.pyplot as plt
# 假设有原始的训练损失、测试损失、训练准确率、测试准确率数据
train_losses = [0.5, 0.4, 0.3, 0.25, 0.2, 0.15, 0.1, 0.05]
test_losses = [0.6, 0.5, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15]
train_accs = [0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.98, 0.99]
test_accs = [0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
# 创建插值数据点
interp_x = np.linspace(0, len(train_losses) - 1, 1000)
# 创建第一个子图,绘制训练集损失和准确率曲线
fig, ax1 = plt.subplots()
ax1.plot(train_losses, label='Train Loss')
ax1.plot(train_accs, label='Train Acc')
ax1.set_ylim([0, 1])
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss/Accuracy')
ax1.legend()
# 创建第二个子图,绘制测试集损失和准确率曲线
fig, ax2 = plt.subplots()
ax2.plot(test_losses, label='Test Loss')
ax2.plot(test_accs, label='Test Acc')
ax2.set_ylim([0, 1])
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss/Accuracy')
ax2.legend()
# 显示图像
plt.show()
```
在这个示例中,我们首先使用 `numpy.linspace` 函数创建一个包含1000个插值数据点的数组 `interp_x`,然后使用 `matplotlib.pyplot.subplots` 函数创建两个子图。在第一个子图上,我们使用 `matplotlib.pyplot.plot` 函数绘制训练集的损失和准确率曲线,并使用 `matplotlib.axes.Axes.set_xlabel` 和 `matplotlib.axes.Axes.set_ylabel` 函数设置坐标轴标签。在第二个子图上,我们绘制测试集的损失和准确率曲线,并设置坐标轴标签。最后,我们使用 `matplotlib.pyplot.show` 函数显示图像。
阅读全文