#build live matplotlib fig fig = plt.figure() ax = fig.add_subplot(111) plt.ion() fig.show() fig.canvas.draw() for e in range(epochs): out = [] for b in range(-10- enc_seq_len, 10 - enc_seq_len): optimizer.zero_grad() X, Y = get_data(batch_size, enc_seq_len, output_sequence_length) #Forward pass and calculate loss net_out = t(X) #print(net_out.shape,Y.shape) loss = torch.mean((net_out - Y) ** 2) #backwards pass loss.backward() optimizer.step() #Track losses and draw rgaph out.append([net_out.detach().numpy(), Y]) losses.append(loss) ax.clear() ax.plot(losses) ax.set_title("Mean Squared Error") fig.canvas.draw()
时间: 2024-04-20 09:23:27 浏览: 154
这段代码是一个训练过程的循环,用于训练Transformer模型并可视化损失。
首先,创建了一个`matplotlib`的图形对象`fig`,并添加一个子图`ax`。然后,使用`plt.ion()`函数将图形设置为交互模式,以便在训练过程中实时更新图形。接下来,使用`fig.show()`显示图形,并使用`fig.canvas.draw()`绘制图形。
然后,通过一个外循环进行多个训练周期(epochs)的训练。在每个循环内部,又有一个内循环用于遍历训练数据的批次。
在内循环中,首先使用`optimizer.zero_grad()`将模型参数的梯度归零。然后,通过调用`get_data()`函数获取输入数据和目标数据。接下来,通过前向传播计算模型输出`net_out`,并计算损失值`loss`(使用均方误差)。然后,通过调用`loss.backward()`进行反向传播并更新模型参数(使用`optimizer.step()`)。最后,将模型输出和目标数据添加到列表`out`中,并将损失值添加到列表`losses`中。
在内循环的末尾,使用`ax.clear()`清除子图中的内容,并使用`ax.plot()`绘制损失值的曲线。然后,设置子图的标题为“Mean Squared Error”,并通过`fig.canvas.draw()`更新图形。
这段代码的目的是实时可视化训练过程中的损失值曲线。如果你有关于代码的其他问题,请继续提问。
相关问题
fig = plt.figure(1, figsize=(20, 8)) ax1 = plt.subplot(121) ax2 = plt.subplot(122)如何设计让两个子图贴的在一起
这段代码首先创建了一个名为`fig`的新图形,设置了其大小为20英寸宽,8英寸高,并将其分配给编号1。然后,通过`plt.subplot(121)`和`plt.subplot(122)`,分别在该图中创建了两个子图,每个子图占据整个图的一半,分成了1行2列布局中的第一个和第二个区域。
如果你想要这两个子图紧密相邻,而不是默认地水平并排,可以使用`gridspec`模块,它提供了更精细的网格控制。例如:
```python
from matplotlib.gridspec import GridSpec
fig = plt.figure(1, figsize=(20, 8))
gs = GridSpec(1, 2, width_ratios=[1, 1]) # 创建一个1行2列的网格,等宽
ax1 = fig.add_subplot(gs[0])
ax1.set_title('Subplot 1') # 设置子图标题
ax2 = fig.add_subplot(gs[1])
ax2.set_title('Subplot 2')
# 如果你想让两个子图顶部对齐,可以使用下面的方式堆叠
# ax2.set_position([0.05, 0.1, 1, 0.85]) # 将ax2的位置设置在ax1下方
plt.tight_layout() # 自动调整子图间距以避免重叠
import matplotlib.pyplot as plt fig=plt.figure() ax1=fig.add_subplot(2,2,1) ax2=fig.add_subplot(2,2,2) ax3=fig.add_subplot(2,2,3) ax1.plot([1.5,2,3.5,-1,1.6])
这段代码的作用是创建一个包含3个子图的画布,并在第一个子图中绘制一个折线图。其中,fig.add_subplot(2,2,1)表示在2行2列的画布中创建第1个子图,ax1.plot([1.5,2,3.5,-1,1.6])表示在第1个子图中绘制一条由给定点构成的折线。需要注意的是,由于没有给定x轴坐标,所以默认使用数据点的下标作为x轴坐标。如果需要自定义x轴坐标,可以将x轴坐标作为第一个参数传给plot()函数。
阅读全文