torch.unsqueeze()函数的用法
时间: 2023-10-16 12:03:41 浏览: 148
`torch.unsqueeze()`函数是PyTorch中用于增加张量维度的函数。它可以在指定位置插入一个大小为1的新维度,从而改变张量的形状。
以下是`torch.unsqueeze()`函数的用法示例:
```python
import torch
# 创建一个2维张量
x = torch.tensor([[1,2, 3],
[4, 5, 6]])
print(x.size())
# 输出: torch.Size([2, 3])
# 在第二个维度插入一个新维度
x_new = torch.unsqueeze(x, dim=1)
print(x_new.size())
# 输出: torch.Size([2, 1, 3])
# 在第一个维度插入一个新维度
x_new2 = torch.unsqueeze(x, dim=0)
print(x_new2.size())
# 输出: torch.Size([1, 2, 3])
```
在上述示例中,我们首先创建了一个形状为`(2, 3)`的2维张量`x`。然后,我们使用`torch.unsqueeze()`函数在不同的维度位置插入了新的维度。
- `torch.unsqueeze(x, dim=1)`表示在第二个维度上插入一个新维度。结果是一个形状为`(2, 1, 3)`的张量`x_new`。
- `torch.unsqueeze(x, dim=0)`表示在第一个维度上插入一个新维度。结果是一个形状为`(1, 2, 3)`的张量`x_new2`。
通过使用`torch.unsqueeze()`函数,我们可以根据需要在不同的位置插入新的维度,从而改变张量的形状。这在某些情况下非常有用,例如在进行广播操作或与其他形状不同的张量进行运算时。
阅读全文