torch.nn.functional详解
时间: 2024-05-08 17:17:37 浏览: 125
torch.nn.functional是PyTorch中的一个模块,用于实现各种神经网络的函数,包括卷积、池化、激活、损失函数等。该模块中的函数是基于Tensor进行操作的,可以灵活地组合使用。
常用函数:
1.卷积函数:torch.nn.functional.conv2d
该函数用于进行二维卷积操作,输入包括输入张量、卷积核张量和卷积核大小等参数。示例代码如下:
```python
import torch.nn.functional as F
input = torch.randn(1, 1, 28, 28)
conv1 = nn.Conv2d(1, 6, 5)
output = F.conv2d(input, conv1.weight, conv1.bias, stride=1, padding=2)
```
2.池化函数:torch.nn.functional.max_pool2d
该函数用于进行二维最大池化操作,输入包括输入张量、池化核大小等参数。示例代码如下:
```python
import torch.nn.functional as F
input = torch.randn(1, 1, 28, 28)
output = F.max_pool2d(input, kernel_size=2, stride=2)
```
3.激活函数:torch.nn.functional.relu
该函数用于进行ReLU激活操作,输入包括输入张量等参数。示例代码如下:
```python
import torch.nn.functional as F
input = torch.randn(1, 10)
output = F.relu(input)
```
4.损失函数:torch.nn.functional.cross_entropy
该函数用于计算交叉熵损失,输入包括预测结果和真实标签等参数。示例代码如下:
```python
import torch.nn.functional as F
input = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
output = F.cross_entropy(input, target)
```
以上是torch.nn.functional模块中的一些常用函数,除此之外还有很多其他函数,可以根据需要进行查阅。
阅读全文