torch.squeeze函数
时间: 2023-10-16 22:31:56 浏览: 151
Pytorch中torch.gather函数
5星 · 资源好评率100%
torch.squeeze函数是一个PyTorch中的函数,它的作用是从张量中移除维度为1的维度。它的语法如下:
```python
torch.squeeze(input, dim=None, out=None)
```
其中,参数input是要进行操作的张量,dim是要移除的维度的索引号,如果不指定,则移除所有维度为1的维度,out是输出张量,如果不指定,则返回一个新的张量。
示例代码:
```python
import torch
# 创建一个形状为[1, 3, 1, 2]的张量
x = torch.randn(1, 3, 1, 2)
# 移除所有维度为1的维度
y = torch.squeeze(x)
print("x的形状:", x.shape)
print("y的形状:", y.shape)
```
输出结果:
```
x的形状: torch.Size([1, 3, 1, 2])
y的形状: torch.Size([3, 2])
```
可以看到,原来形状为[1, 3, 1, 2]的张量经过squeeze函数后,形状变为了[3, 2]。
阅读全文