pytorch unsqueeze函数
时间: 2023-10-16 16:11:58 浏览: 133
详解pytorch中squeeze()和unsqueeze()函数介绍
`unsqueeze`函数是PyTorch中的一个操作,它可以在指定维度上增加一个新的维度。具体来说,`unsqueeze`可以在一个tensor上增加一个新的维度,使得原来的tensor变成一个更高维度的tensor。
`unsqueeze`的语法如下:
```python
torch.unsqueeze(input, dim)
```
其中,`input`是要增加维度的tensor,`dim`是要在哪个维度上增加新的维度。
下面是一个例子:
```python
import torch
# 创建一个2x3的tensor
x = torch.Tensor([[1, 2, 3], [4, 5, 6]])
# 在第0维上增加一个新的维度
y = torch.unsqueeze(x, 0)
# 打印y的形状
print(y.shape)
```
输出:
```
torch.Size([1, 2, 3])
```
在上面的例子中,我们创建了一个2x3的tensor,然后在第0维上增加了一个新的维度,所以输出的形状是`torch.Size([1, 2, 3])`。
阅读全文