torch.squeeze()函数的用法
时间: 2023-12-21 10:05:20 浏览: 70
Pytorch中torch.gather函数
5星 · 资源好评率100%
`torch.squeeze()`函数是PyTorch中用于压缩张量维度的函数。它可以去除张量中维度大小为1的维度,从而减少张量的维度数。
以下是`torch.squeeze()`函数的用法示例:
```python
import torch
# 创建一个维度为1的张量
x = torch.tensor([[[1, 2, 3]]])
print(x.size())
# 输出: torch.Size([1, 1, 3])
# 压缩维度大小为1的维度
x_squeezed = torch.squeeze(x)
print(x_squeezed.size())
# 输出: torch.Size([3])
# 创建一个维度为1的张量
y = torch.tensor([[[[4, 5, 6]]]])
print(y.size())
# 输出: torch.Size([1, 1, 1, 3])
# 压缩维度大小为1的维度
y_squeezed = torch.squeeze(y)
print(y_squeezed.size())
# 输出: torch.Size([3])
```
在上述示例中,我们首先创建了一个形状为`(1, 1, 3)`的3维张量`x`和一个形状为`(1, 1, 1, 3)`的4维张量`y`。然后,我们使用`torch.squeeze()`函数去除了这些张量中维度大小为1的维度。
- 对于张量`x`,`torch.squeeze(x)`会返回一个形状为`(3,)`的张量`x_squeezed`,因为在维度1和2上的大小为1的维度被去除了。
- 对于张量`y`,`torch.squeeze(y)`会返回一个形状为`(3,)`的张量`y_squeezed`,因为在维度1、2和3上的大小为1的维度被去除了。
通过使用`torch.squeeze()`函数,我们可以根据需要去除张量中不需要的维度,从而减少张量的维度数。这在某些情况下非常有用,例如在进行模型输出后处理或与其他形状不同的张量进行运算时。
阅读全文