为什么这段代码会报错:AttributeError: 'Mnist_NN' object has no attribute 'tain'
时间: 2024-12-16 14:05:17 浏览: 8
这个错误是因为在`fit`函数中的`model.train()`行,代码试图调用`Mnist_NN`对象的`tain`属性,但实际上这个对象并没有这样的属性。应该是误拼写了`train`方法,而不是`tain`。正确的语法应为`model.train()`,这表示设置模型处于训练模式,以便在接下来的迭代中启用正向传播和反向传播。
修复这个问题后,代码应该能正常运行。完整的修正后的`fit`函数应该是这样的:
```python
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
for step in range(steps):
model.train() # 将这里修改为 'model.train()'
for xb, yb in train_dl:
loss_bach(model, loss_func, xb, yb, opt)
model.eval()
with torch.no_grad():
losses, nums = zip(*[loss_bach(model, loss_func, xb, yb) for xb, yb in valid_dl])
val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
print('当前step:' + str(step), '验证集损失:' + str(val_loss))
pass
```
现在,当模型进入训练模式时,将会执行期望的训练流程。
阅读全文