pytorch 常用损失函数
时间: 2023-12-19 08:31:51 浏览: 98
Pytorch中torch.nn的损失函数
5星 · 资源好评率100%
PyTorch中常用的损失函数包括回归损失函数和分类损失函数。其中,回归损失函数包括L1Loss和MSELoss,分类损失函数包括CrossEntropyLoss。此外,还有一种常用的平滑L1损失函数,也被称为Huber损失函数。
以下是PyTorch中常用的损失函数的介绍和演示:
1. L1Loss:计算预测值和真实值之间的绝对误差的平均值。
```python
import torch.nn as nn
loss_fn = nn.L1Loss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
loss = loss_fn(input, target)
print(loss)
```
2. MSELoss:计算预测值和真实值之间的均方误差的平均值。
```python
import torch.nn as nn
loss_fn = nn.MSELoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
loss = loss_fn(input, target)
print(loss)
```
3. CrossEntropyLoss:用于多分类问题,计算预测值和真实值之间的交叉熵损失。
```python
import torch.nn as nn
loss_fn = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])
loss = loss_fn(input, target)
print(loss)
```
4. SmoothL1Loss:平滑L1损失函数,也被称为Huber损失函数,可以减少异常值的影响。
```python
import torch.nn as nn
loss_fn = nn.SmoothL1Loss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
loss = loss_fn(input, target)
print(loss)
```
阅读全文