如何使用torch.squeeze()减少
时间: 2024-06-09 14:05:48 浏览: 84
squeezenet-torch:将SqueezeNet移植到PyTorch中;主要是为了我自己的学习
torch.squeeze()是PyTorch中的一个函数,用于减少张量的维度。当张量中存在大小为1的维度时,torch.squeeze()可以去除这些维度,从而减少张量的维度数量。
使用torch.squeeze()的方法很简单,只需要将要操作的张量作为输入参数传递给该函数即可。下面是一个示例:
```python
import torch
# 创建一个大小为(1, 3, 1, 2)的张量
x = torch.randn(1, 3, 1, 2)
print("原始张量的形状:", x.shape)
# 使用torch.squeeze()去除大小为1的维度
y = torch.squeeze(x)
print("去除维度后的形状:", y.shape)
```
输出结果如下:
```
原始张量的形状: torch.Size([1, 3, 1, 2])
去除维度后的形状: torch.Size([3, 2])
```
在上述示例中,原始张量x的形状为[1, 3, 1, 2],其中存在两个大小为1的维度。使用torch.squeeze()函数后,这两个大小为1的维度被去除,最终得到了形状为[3, 2]的张量y。
阅读全文