Example: plotting loss, accuracy and similar curves for papers with plt (matplotlib) under TensorFlow

Straight to the code:
```python
import time

import matplotlib.pyplot as plt
import numpy as np

# Assumed to be defined earlier (not shown in the post): sess, train_op, loss, acc,
# the placeholders x and y_, merged_summary_op, summary_writer, minibatches(),
# x_train, y_train, x_val, y_val, n_epoch and batch_size.

fig_loss = np.zeros([n_epoch])  # mean training loss per epoch
fig_acc1 = np.zeros([n_epoch])  # mean training accuracy per epoch
fig_acc2 = np.zeros([n_epoch])  # mean validation accuracy per epoch
for epoch in range(n_epoch):
    start_time = time.time()
    # training
    train_loss, train_acc, n_batch = 0, 0, 0
    for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
        _, err, ac = sess.run([train_op, loss, acc], feed_dict={x: x_train_a, y_: y_train_a})
        train_loss += err; train_acc += ac; n_batch += 1
    # One TensorBoard summary per epoch, evaluated on the last training batch
    summary_str = sess.run(merged_summary_op, feed_dict={x: x_train_a, y_: y_train_a})
    summary_writer.add_summary(summary_str, epoch)
    print(" train loss: %f" % (train_loss / n_batch))
    print(" train acc: %f" % (train_acc / n_batch))
    fig_loss[epoch] = train_loss / n_batch
    fig_acc1[epoch] = train_acc / n_batch
    # validation
    val_loss, val_acc, n_batch = 0, 0, 0
    for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):
        err, ac = sess.run([loss, acc], feed_dict={x: x_val_a, y_: y_val_a})
        val_loss += err; val_acc += ac; n_batch += 1
    print(" validation loss: %f" % (val_loss / n_batch))
    print(" validation acc: %f" % (val_acc / n_batch))
    fig_acc2[epoch] = val_acc / n_batch

# Training loss curve
fig, ax1 = plt.subplots()
lns1 = ax1.plot(np.arange(n_epoch), fig_loss, label="train loss")
ax1.set_xlabel('epoch')
ax1.set_ylabel('training loss')

# Training and validation accuracy curves in one figure
fig2, ax2 = plt.subplots()
ax3 = ax2.twinx()  # ax3 shares ax2's x-axis and adds a second y-axis on the right
lns2 = ax2.plot(np.arange(n_epoch), fig_acc1, label="train acc")
# Each axes has its own color cycle, so give the second line a distinct color
lns3 = ax3.plot(np.arange(n_epoch), fig_acc2, color="tab:orange", label="val acc")
ax2.set_xlabel('epoch')
ax2.set_ylabel('training acc')
ax3.set_ylabel('val acc')
# Merge the legends; labels are read from the lines, so they cannot get swapped
lns = lns2 + lns3
labels = [ln.get_label() for ln in lns]
ax3.legend(lns, labels, loc=7)  # loc=7 is 'center right'
plt.show()
```
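The loop calls a `minibatches` helper that the post does not define. Below is a minimal sketch of one possible version, assuming `inputs` and `targets` are NumPy arrays with the same first dimension (it relies on fancy indexing). The `drop_last` and `rng` parameters are hypothetical additions, not something the author is known to use. By default the last, smaller batch is kept, so `n_batch` is never zero for a non-empty dataset.

```python
import numpy as np


def minibatches(inputs, targets, batch_size, shuffle=False, drop_last=False, rng=None):
    """Yield (inputs, targets) batches of at most batch_size samples.

    Hypothetical helper: the post uses minibatches() without showing it.
    """
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    indices = np.arange(len(inputs))
    if shuffle:
        # Shuffle once per call, i.e. once per epoch in the training loop
        rng = np.random.default_rng() if rng is None else rng
        rng.shuffle(indices)
    for start in range(0, len(inputs), batch_size):
        excerpt = indices[start:start + batch_size]
        if drop_last and len(excerpt) < batch_size:
            break  # skip the final, incomplete batch
        # Indexing with an index array returns a copy of the selected rows
        yield inputs[excerpt], targets[excerpt]


# Usage: 10 samples in batches of 4 give batch sizes 4, 4 and 2
x_demo = np.arange(10).reshape(10, 1)
y_demo = np.arange(10)
print([len(xb) for xb, _ in minibatches(x_demo, y_demo, batch_size=4)])  # [4, 4, 2]
```

Because the loop averages per-batch values, a smaller final batch counts as much as a full one; passing `drop_last=True` avoids that at the cost of skipping a few samples each epoch.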
Result:

[Figures: training loss per epoch; training and validation accuracy per epoch on twin y-axes]
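To reuse the plotting part for other runs, it can be wrapped in a function. The sketch below is a hypothetical helper (`plot_curves` is not from the post) that takes the three per-epoch arrays filled by the training loop and returns the two figures. It gives the twin-axis lines explicit colors, since `ax2` and its twin each start their own color cycle and would otherwise both use the first default color, and it builds the merged legend from the lines' own labels so the entries cannot end up swapped.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_curves(train_loss, train_acc, val_acc):
    """Draw the loss figure and the twin-axis accuracy figure; return both.

    Hypothetical helper wrapping the plotting code from the post.
    """
    epochs = np.arange(1, len(train_loss) + 1)  # 1-based epoch numbers on the x-axis

    # Figure 1: training loss per epoch
    loss_fig, ax1 = plt.subplots()
    ax1.plot(epochs, train_loss, label="train loss")
    ax1.set_xlabel("epoch")
    ax1.set_ylabel("training loss")
    ax1.legend(loc="upper right")

    # Figure 2: training and validation accuracy, each on its own y-axis
    acc_fig, ax2 = plt.subplots()
    ax3 = ax2.twinx()  # shares the x-axis with ax2, own y-axis on the right
    lns2 = ax2.plot(epochs, train_acc, color="tab:blue", label="train acc")
    lns3 = ax3.plot(epochs, val_acc, color="tab:orange", label="val acc")
    ax2.set_xlabel("epoch")
    ax2.set_ylabel("training acc")
    ax3.set_ylabel("val acc")

    # One legend for both axes, labels taken from the line objects themselves
    lns = lns2 + lns3
    ax3.legend(lns, [ln.get_label() for ln in lns], loc="center right")
    return loss_fig, acc_fig
```

After training, `plot_curves(fig_loss, fig_acc1, fig_acc2)` followed by `plt.show()` displays both figures. Since training and validation accuracy share the same 0 to 1 scale, drawing both on a single y-axis (two `ax2.plot` calls, no `twinx`) is often easier to compare; the twin axis is kept here only to mirror the post.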