def plot_learning_curve(loss_record, title=''): ''' Plot learning curve of your DNN (train & dev loss) ''' total_steps = len(loss_record['train']) x_1 = range(total_steps) x_2 = x_1[::len(loss_record['train']) // len(loss_record['dev'])] figure(figsize=(6, 4)) plt.plot(x_1, loss_record['train'], c='tab:red', label='train') plt.plot(x_2, loss_record['dev'], c='tab:cyan', label='dev') plt.ylim(0.0, 5.) plt.xlabel('Training steps') plt.ylabel('MSE loss') plt.title('Learning curve of {}'.format(title)) plt.legend() plt.show()
时间: 2023-09-11 12:08:48 浏览: 119
plot_cv_predict.zip_cross_val_predict_plot(cv._predict-python_预测
5星 · 资源好评率100%
这段代码用于绘制学习曲线,展示训练集和开发集的损失情况。函数中的输入参数为loss_record,它是一个字典,包含了训练集和开发集的损失记录。
首先,通过获取训练集的总步数,计算出用于绘制x轴的坐标点x_1。然后,根据开发集的记录数量和训练集的总步数,计算出用于绘制开发集损失的x轴坐标点x_2。
接下来,创建一个图形窗口,并使用plt.plot()函数分别绘制训练集和开发集的损失曲线。训练集的曲线使用红色表示(c='tab:red'),开发集的曲线使用青色表示(c='tab:cyan')。同时,使用plt.ylim()函数设置y轴的范围为0.0到5.0,使用plt.xlabel()和plt.ylabel()函数设置x轴和y轴的标签,使用plt.title()函数设置图形的标题。
最后,使用plt.legend()函数显示图例,并使用plt.show()函数显示绘制的学习曲线图形。
阅读全文