torch.unsqueeze()的用法
时间: 2023-10-16 15:32:38 浏览: 135
详解pytorch中squeeze()和unsqueeze()函数介绍
`.unsqueeze()` 是 PyTorch 中的一个函数,用于在指定维度上增加一个维度。具体来说,它的作用是将一个张量(tensor)在指定的维度上增加一个大小为 1 的维度,从而扩展张量的形状。
函数的语法如下:
```python
torch.unsqueeze(input, dim)
```
其中,`input` 是要增加维度的张量,`dim` 是要增加的维度。`dim` 的取值范围是 `[-(input.dim() + 1), input.dim()]`,也就是说,可以在张量的任何一个维度上增加一个维度。
例如,假设有一个形状为 (3, 4) 的张量 `x`,要在第 0 维(即行)上增加一个维度,可以这样写:
```python
import torch
x = torch.randn(3, 4)
y = torch.unsqueeze(x, 0)
print(x.shape) # 输出 (3, 4)
print(y.shape) # 输出 (1, 3, 4)
```
可以看到,`y` 的形状是在 `x` 的第 0 维上增加了一个维度,大小为 1。这样做的效果是将原来的 3 行 4 列的矩阵变成了一个 1 个通道、3 行 4 列的张量。
阅读全文