torch.nn.MSE用法
时间: 2023-03-24 19:01:29 浏览: 243
torch.nn.MSE是PyTorch中用于计算均方误差(Mean Squared Error,MSE)的函数。MSE通常用于衡量模型预测结果与真实值之间的误差。
使用torch.nn.MSE函数时,需要输入两个张量,分别是模型的预测值和真实值。该函数将返回一个标量,即这两个张量之间的均方误差。
下面是一个示例代码,演示如何使用torch.nn.MSE函数计算两个张量之间的均方误差:
``` python
import torch
# 模型预测值和真实值
pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([2.0, 4.0, 6.0])
# 计算均方误差
mse_loss = torch.nn.MSELoss()
loss = mse_loss(pred, target)
print(loss)
```
输出结果为:
```
tensor(4.6667)
```
以上代码中,我们首先定义了两个张量pred和target,分别表示模型的预测值和真实值。接着,我们创建了一个torch.nn.MSELoss对象mse_loss,并使用它来计算pred和target之间的均方误差。最后,我们打印了计算结果loss。
需要注意的是,torch.nn.MSE函数返回的是一个标量张量,而不是一个Python数值。如果需要将结果转换为Python数值,可以使用loss.item()方法。
阅读全文