torch中的unsqueeze如何使用
时间: 2023-10-16 10:29:07 浏览: 83
torch 中各种图像格式转换的实现方法
在PyTorch中,`unsqueeze()`函数可以用来在指定维度上增加一个维度。具体的使用方法如下:
```python
import torch
# 定义一个二维张量
tensor_2d = torch.tensor([[1, 2], [3, 4]])
# 将二维张量在第0维增加一个维度
tensor_3d = tensor_2d.unsqueeze(0)
print(tensor_3d)
```
运行结果如下:
```
tensor([[[1, 2],
[3, 4]]])
```
在上面的例子中,我们将二维张量`tensor_2d`在第0维增加一个维度,得到了一个三维张量`tensor_3d`。`unsqueeze()`函数的参数是要增加维度的索引位置,从0开始。如果要在最后一个维度上增加一个维度,可以使用`-1`作为参数:
```python
# 将二维张量在最后一个维度增加一个维度
tensor_3d = tensor_2d.unsqueeze(-1)
print(tensor_3d)
```
运行结果如下:
```
tensor([[[1],
[2]],
[[3],
[4]]])
```
可以看到,此时我们得到了一个三维张量,最后一个维度上增加了一个维度。
阅读全文