torch.squeeze
时间: 2023-10-16 11:31:34 浏览: 6
`torch.squeeze(input, dim=None, out=None)` 函数可以从张量中删除大小为 1 的维度。如果输入张量 `input` 在给定维度 `dim` 上的大小为 1,则该维度将被删除。如果未指定维度,则删除所有大小为 1 的维度。
参数:
- `input`:输入张量。
- `dim`:要删除的维度。可以是一个整数或一个元组。默认值为 `None`,表示删除所有大小为 1 的维度。
- `out`:输出张量。如果指定,则将结果存储在此张量中。
返回值:返回一个张量,该张量与输入张量 `input` 具有相同的数据,但删除了大小为 1 的维度。
示例:
```python
import torch
x = torch.randn(1, 3, 1, 2)
print(x.shape) # torch.Size([1, 3, 1, 2])
y = torch.squeeze(x)
print(y.shape) # torch.Size([3, 2])
z = torch.squeeze(x, dim=0)
print(z.shape) # torch.Size([3, 1, 2])
w = torch.squeeze(x, dim=2)
print(w.shape) # torch.Size([1, 3, 2])
```
在上面的示例中,`x` 是一个形状为 `(1, 3, 1, 2)` 的张量,其中第一个维度、第三个维度的大小都为 1。调用 `torch.squeeze(x)` 后,删除了所有大小为 1 的维度,因此得到一个形状为 `(3, 2)` 的张量 `y`。调用 `torch.squeeze(x, dim=0)`,只删除第一个维度,因此得到一个形状为 `(3, 1, 2)` 的张量 `z`。调用 `torch.squeeze(x, dim=2)`,只删除第三个维度,因此得到一个形状为 `(1, 3, 2)` 的张量 `w`。
阅读全文