x.unsqueeze用法
时间: 2023-10-16 16:26:00 浏览: 39
在 PyTorch 中,unsqueeze() 是一个用于增加张量维度的函数。具体来说,unsqueeze() 函数可以将一个张量的维度扩展一个指定的大小。例如,如果你有一个形状为 (3, 4) 的张量 x,你可以使用 x.unsqueeze(0) 将其扩展为形状为 (1, 3, 4) 的三维张量。同样地,使用 x.unsqueeze(1) 可以将其扩展为形状为 (3, 1, 4) 的三维张量。
unsqueeze() 函数的使用格式为:unsqueeze(dim),其中 dim 参数指定要扩展的维度。具体来说,dim 参数应该是一个整数,表示要扩展的维度在输出张量中的位置。例如,dim=0 表示要将新维度添加到张量的最前面,dim=1 表示要将新维度添加到张量的第二个位置,以此类推。
下面是一个例子,展示了如何使用 unsqueeze() 函数将一个二维张量扩展为三维张量:
```python
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 将 x 扩展为形状为 (1, 2, 3) 的三维张量
y = x.unsqueeze(0)
print(x.shape) # 输出:torch.Size([2, 3])
print(y.shape) # 输出:torch.Size([1, 2, 3])
```
在上面的例子中,我们首先定义一个形状为 (2, 3) 的二维张量 x。然后,我们使用 unsqueeze(0) 将 x 扩展为形状为 (1, 2, 3) 的三维张量 y。最后,我们输出 x 和 y 的形状,以检查 unsqueeze() 函数的效果。