for i, (inputs, labels) in enumerate(train_loader): optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item()改写
时间: 2024-01-28 17:05:13 浏览: 123
Python库 | torchvision-0.11.1-cp36-cp36m-macosx_10_9_x86_64.whl
你可以将这段代码改写为:
```
for i, data in enumerate(train_loader):
inputs, labels = data
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
```
这样就可以避免出现"ValueError: too many values to unpack"的错误了。在这个版本中,我们直接从dataloader中获取数据,并且使用inputs和labels直接引用data元组中的元素。
阅读全文