unsqueeze():
时间: 2024-05-18 07:11:05 浏览: 29
pytorch张量索引以及sum函数mean函数unsqueeze函数
unsqueeze() 是 PyTorch 中的一个函数,用于在指定维度上增加维度的大小为 1。它可以用来改变张量的形状,特别是在进行某些操作之前需要保持维度一致性时很有用。
具体来说,unsqueeze() 函数会在给定的维度上插入一个大小为 1 的新维度。例如,假设有一个形状为 (3,) 的一维张量 t,使用 unsqueeze(0) 将会在第 0 维插入一个新维度,得到一个形状为 (1, 3) 的二维张量。
示例代码:
```python
import torch
t = torch.tensor([1, 2, 3])
print(t.shape) # 输出: torch.Size([3])
t_unsqueezed = t.unsqueeze(0)
print(t_unsqueezed.shape) # 输出: torch.Size([1, 3])
t_unsqueezed_2 = t.unsqueeze(1)
print(t_unsqueezed_2.shape) # 输出: torch.Size([3, 1])
```
这样,我们就可以通过 unsqueeze() 函数在指定的维度上扩展张量的形状,从而满足某些操作对形状的要求。
阅读全文