池化pytorch代码
时间: 2023-12-07 11:38:42 浏览: 68
对Pytorch中Tensor的各种池化操作解析
以下是PyTorch中池化的代码示例:
1. 最大池化
```python
import torch
import torch.nn as nn
# 定义输入数据
input_data = torch.randn(1, 3, 5, 5)
# 定义最大池化层
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
# 进行最大池化操作
output_data = max_pool(input_data)
# 输出结果
print(output_data)
```
2. 平均池化
```python
import torch
import torch.nn as nn
# 定义输入数据
input_data = torch.randn(1, 3, 5, 5)
# 定义平均池化层
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
# 进行平均池化操作
output_data = avg_pool(input_data)
# 输出结果
print(output_data)
```
3. 全局平均池化
```python
import torch
import torch.nn as nn
# 定义输入数据
input_data = torch.randn(1, 3, 5, 5)
# 定义全局平均池化层
global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
# 进行全局平均池化操作
output_data = global_avg_pool(input_data)
# 输出结果
print(output_data)
```
阅读全文