torch squeeze
时间: 2025-01-07 19:38:40 浏览: 5
### PyTorch `squeeze` 函数详解
在 PyTorch 中,`squeeze` 方法用于移除张量维度中的单维条目(即大小为 1 的维度)。这有助于调整数据形状以适应不同层的要求或简化计算。
#### 基本语法
```python
torch.squeeze(input, dim=None, *, out=None) -> Tensor
```
- 参数说明:
- `input`: 输入张量。
- `dim`: 可选参数;指定要压缩的维度。如果该维度不是单一维度,则不会发生改变。
当不提供 `dim` 参数时,所有尺寸为 1 的维度都将被删除[^1]。
#### 使用实例
下面通过几个例子来展示如何使用此功能:
##### 示例一:简单应用
创建一个三维张量并尝试去除其冗余的一维结构。
```python
import torch
x = torch.ones(2, 1, 3)
print("Original shape:", x.shape)
y = torch.squeeze(x)
print("Squeezed shape:", y.shape)
```
输出将是:
```
Original shape: torch.Size([2, 1, 3])
Squeezed shape: torch.Size([2, 3])
```
##### 示例二:特定维度操作
仅针对某一具体轴向执行挤压动作。
```python
z = torch.zeros((1, 4))
w = torch.squeeze(z, dim=0)
print(w.shape)
```
这段代码会打印出 `(4,)` 表明第一个维度已被成功消除。
阅读全文