报错Cell In[2], line 120, in train(dataloader, model, loss_fn, optimizer) 118 cur_loss=F.cross_entropy(preds,labels) 119 pred1=torch.max(preds,axis=1) --> 120 cur_acc=torch.sum((labels==pred1).to(torch.float32))/preds.shape[0] 121 optimizer.zero_grad() 122 cur_loss.backward() AttributeError: 'bool' object has no attribute 'to'
时间: 2023-11-27 18:54:51 浏览: 42
这个错误可能是因为 `(labels==pred1)` 返回了一个布尔类型的张量,而 `to()` 方法只能用于张量类型的转换。可以尝试将 `(labels==pred1)` 改为 `torch.eq(labels, pred1)`,这样返回的就是一个张量,就可以使用 `to()` 方法进行类型转换了。例如:
```
cur_acc=torch.sum(torch.eq(labels,pred1).to(torch.float32))/preds.shape[0]
```
另外,需要注意 `torch.max(preds, axis=1)` 返回的是一个元组,包含两个张量,第一个张量是最大值,第二个张量是最大值对应的索引。如果需要获取最大值对应的索引,可以使用 `torch.max(preds, axis=1)[1]`。
相关问题
ttributeError Traceback (most recent call last) Cell In[3], line 160 158 print(f"epoch{t+1}\n----------") 159 # 训练模型 --> 160 train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer) 161 # 验证模型 162 val_loss, val_acc = val(val_dataloader, model, loss_fn) Cell In[3], line 125, in train(dataloader, model, loss_fn, optimizer) 122 cur_loss.backward() 123 optimizer.step() --> 125 loss+=loss.item() 126 current+=cur_acc.item() 127 n=n+1 AttributeError: 'int' object has no attribute 'item'
这个错误是因为你试图将一个Python整数转换为PyTorch张量,以便在计算损失时使用。请确保你将输入数据转换为PyTorch张量,以便进行计算。你可以使用`torch.tensor()`方法将Python对象转换为张量。在这种情况下,您可以尝试将`loss`变量初始化为一个张量,而不是整数,如下所示:
```
loss = torch.tensor(0.0)
```
这样就会将`loss`变量初始化为一个浮点数张量,并且可以使用`.item()`方法获得其值。
for epoch in range(num_epochs): model.train() train_loss = 0 train_mae = 0 for batch_data, batch_target in train_dataloader: optimizer.zero_grad() output = model(batch
_data) # 前向传播
loss = loss_fn(output, batch_target) # 计算loss
train_loss += loss.item()
mae = mean_absolute_error(output.detach().cpu().numpy(), batch_target.detach().cpu().numpy()) # 计算MAE
train_mae += mae
loss.backward() # 反向传播
optimizer.step() # 更新梯度
train_loss /= len(train_dataloader)
train_mae /= len(train_dataloader)
model.eval() # 切换到评估模式
eval_loss = 0
eval_mae = 0
with torch.no_grad():
for batch_data, batch_target in eval_dataloader:
output = model(batch_data)
loss = loss_fn(output, batch_target)
eval_loss += loss.item()
mae = mean_absolute_error(output.detach().cpu().numpy(), batch_target.detach().cpu().numpy())
eval_mae += mae
eval_loss /= len(eval_dataloader)
eval_mae /= len(eval_dataloader)
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train MAE: {train_mae:.4f}, Eval Loss: {eval_loss:.4f}, Eval MAE: {eval_mae:.4f}")
if eval_loss < best_eval_loss:
best_eval_loss = eval_loss
torch.save(model.state_dict(), save_path)
print("Save model successfully!")
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)