torch里面有什么函数可以直接输出识别准确率和损失吗
时间: 2024-02-04 14:04:02 浏览: 87
在PyTorch中,通常使用nn.CrossEntropyLoss计算损失,使用torch.max函数找到预测的类别,然后通过比较预测的类别和真实标签计算准确率。以下是一个示例代码:
```
import torch
import torch.nn as nn
# 定义模型和损失函数
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(num_epochs):
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 计算准确率
_, predicted = torch.max(outputs.data, 1)
total = labels.size(0)
correct = (predicted == labels).sum().item()
accuracy = correct / total
# 输出结果
print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
.format(epoch+1, num_epochs, loss.item(), accuracy*100))
```
在上面的代码中,我们使用torch.max函数找到每个样本的预测类别,并使用(predicted == labels).sum().item()计算正确分类的样本数。然后我们将正确分类的样本数除以总样本数,以得出准确率。
阅读全文