torch.squeeze
Date: 2023-10-16 16:25:03
torch.squeeze is a PyTorch function that removes all dimensions of size 1 from a tensor. The returned tensor has fewer dimensions and shares its underlying storage with the input, so no data is copied. An optional dim argument restricts the operation to a specific dimension, which is removed only if its size is 1.
For example, consider a tensor of shape (1, 3, 1, 5), where the second dimension has size 3 and the fourth dimension has size 5. Using torch.squeeze on this tensor will result in a new tensor of shape (3, 5), with the size-1 dimensions removed.
Here's an example usage of torch.squeeze:
```
import torch
# create a tensor with shape (1, 3, 1, 5)
x = torch.randn(1, 3, 1, 5)
# apply torch.squeeze to remove the size-1 dimensions
y = torch.squeeze(x)
print(x.shape) # prints torch.Size([1, 3, 1, 5])
print(y.shape) # prints torch.Size([3, 5])
```
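The call above squeezes every size-1 dimension at once. As a rough sketch of the other behaviors worth knowing, the snippet below uses the optional `dim` argument and shows that the result shares storage with the input. The tensor names and the value `7.0` are illustrative choices, not part of the original example.

```python
import torch

x = torch.zeros(1, 3, 1, 5)

# Squeeze only dimension 0; the size-1 dimension at index 2 is kept.
a = torch.squeeze(x, dim=0)
print(a.shape)  # torch.Size([3, 1, 5])

# Squeezing a dimension whose size is not 1 leaves the shape unchanged.
b = torch.squeeze(x, dim=1)
print(b.shape)  # torch.Size([1, 3, 1, 5])

# The result shares storage with the input, so an in-place write
# through the squeezed tensor is visible in the original.
y = torch.squeeze(x)
y[0, 0] = 7.0
print(x[0, 0, 0, 0].item())  # 7.0

# A common pitfall: a batch of size 1 loses its batch dimension
# unless the dimension to squeeze is named explicitly.
batch = torch.zeros(1, 3)
print(torch.squeeze(batch).shape)         # torch.Size([3])
print(torch.squeeze(batch, dim=1).shape)  # torch.Size([1, 3])
```

Passing `dim` explicitly is usually the safer choice in model code, since a bare `torch.squeeze` can silently drop a batch dimension whenever the batch happens to contain a single item.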