def train(net, train_iter, val_iter, num_epochs, lr, wd, devices, lr_period, lr_decay): global val_acc, metric net.to(devices[0]) trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd) scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay) num_batches, timer = len(train_iter), d2l.Timer() train_loss, train_accs, val_accs = [], [], [] fig, ax = plt.subplots(figsize=(8, 6)) for epoch in range(num_epochs): net.train() metric = d2l.Accumulator(3) for i, (features, labels) in enumerate(train_iter): trainer.zero_grad() features, labels = features.to(devices[0]), labels.to(devices[0]) l, acc = d2l.train_batch_ch13(net, features, labels, loss, trainer, devices) metric.add(l, acc, labels.shape[0]) train_loss.append(metric[0] / metric[2]) train_accs.append(metric[1] / metric[2]) if val_iter is not None: val_acc = d2l.evaluate_accuracy_gpu(net, val_iter) val_accs.append(val_acc) ax.plot(range(1, epoch + 2), [train_loss, train_accs, val_accs]) ax.set_xlabel('epoch') ax.set_ylabel('value') ax.legend(['train loss', 'train acc', 'val acc']) plt.show() scheduler.step() measures = ( f'train loss {metric[0] / metric[2]:.3f},'f'train acc {metric[1] / metric[2]:.3f}') if val_iter is not None: measures += f', val acc {val_acc :.3f}' print(以上代码能否实现时事图像展示?请优化成不等epoch结束就时事显示图像的代码
时间: 2024-03-09 07:46:28 浏览: 110
可以实现时事图像展示,且可以实现不等到 epoch 结束就实时显示图像。你可以在每训练一定批次后,调用 `plt.plot` 和 `plt.pause` 函数来显示当前的图像。例如,可以在每训练 100 个 batch 后,调用 `plt.plot` 和 `plt.pause(0.1)` 函数来显示当前的图像。同时,要在循环结束后再调用 `plt.show` 函数来显示最终的图像。
下面是代码实现示例:
```python
def train(net, train_iter, val_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):
global val_acc, metric
net.to(devices[0])
trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
num_batches, timer = len(train_iter), d2l.Timer()
train_loss, train_accs, val_accs = [], [], []
fig, ax = plt.subplots(figsize=(8, 6))
plt.ion() # 开启交互模式
for epoch in range(num_epochs):
net.train()
metric = d2l.Accumulator(3)
for i, (features, labels) in enumerate(train_iter):
trainer.zero_grad()
features, labels = features.to(devices[0]), labels.to(devices[0])
l, acc = d2l.train_batch_ch13(net, features, labels, loss, trainer, devices)
metric.add(l, acc, labels.shape[0])
if (i + 1) % 100 == 0: # 每训练 100 个 batch,实时显示图像
train_loss.append(metric[0] / metric[2])
train_accs.append(metric[1] / metric[2])
if val_iter is not None:
val_acc = d2l.evaluate_accuracy_gpu(net, val_iter)
val_accs.append(val_acc)
ax.clear()
ax.plot(range(1, len(train_loss) + 1), train_loss, label='train loss')
ax.plot(range(1, len(train_accs) + 1), train_accs, label='train acc')
if val_iter is not None:
ax.plot(range(1, len(val_accs) + 1), val_accs, label='val acc')
ax.set_xlabel('batch')
ax.set_ylabel('value')
ax.legend()
plt.pause(0.1)
scheduler.step()
measures = (f'train loss {metric[0] / metric[2]:.3f},'
f'train acc {metric[1] / metric[2]:.3f}')
if val_iter is not None:
measures += f', val acc {val_acc:.3f}'
print(measures)
plt.ioff() # 关闭交互模式
plt.show() # 显示最终的图像
```
在上述代码中,我们使用了 `plt.ion()` 开启交互模式,使用 `plt.pause(0.1)` 实时显示图像,使用 `plt.ioff()` 关闭交互模式,最后再使用 `plt.show()` 显示最终的图像。同时,我们在每训练 100 个 batch 后,清空画布并重新绘制图像。
阅读全文