unsqueeze函数
时间: 2023-08-11 07:12:39 浏览: 92
详解pytorch中squeeze()和unsqueeze()函数介绍
unsqueeze函数是PyTorch中的一个函数,用于在张量的维度上添加一个维度。它的语法如下:
```python
torch.unsqueeze(input, dim)
```
其中,input是一个张量,dim是要在哪个维度上添加新维度的索引。例如,如果input是一个形状为(3,4)的张量,我们可以使用unsqueeze函数在第一个维度上添加新维度,使其形状变为(1,3,4),代码如下:
```python
import torch
x = torch.randn(3, 4)
y = torch.unsqueeze(x, 0)
print(y.shape) # 输出:torch.Size([1, 3, 4])
```
在这个例子中,我们在第0个维度上添加了一个新维度。我们也可以在其他维度上添加新维度。例如,如果我们想在第二个维度上添加新维度,我们可以这样做:
```python
import torch
x = torch.randn(3, 4)
y = torch.unsqueeze(x, 2)
print(y.shape) # 输出:torch.Size([3, 4, 1])
```
在这个例子中,我们在第2个维度上添加了一个新维度。注意,添加的新维度的大小是1。这是因为在PyTorch中,维度的大小必须是整数,不能是None或其他非整数值。
阅读全文