torch,unsqueeze
时间: 2024-01-03 18:03:20 浏览: 67
详解pytorch中squeeze()和unsqueeze()函数介绍
torch.unsqueeze 是 PyTorch 中的一个函数,它用于在给定维度上对张量进行扩展。这个函数的作用是在指定维度上插入一个大小为1的新维度,从而增加张量的维度。
它的语法如下:
```
torch.unsqueeze(input, dim)
```
其中,input 是输入的张量,dim 是要插入新维度的位置。
举个例子,假设我们有一个形状为 (3,) 的一维张量,我们可以使用 unsqueeze 函数在第一个维度上插入一个新维度,将其转换为形状为 (1, 3) 的二维张量。代码如下:
```python
import torch
x = torch.tensor([1, 2, 3])
x = torch.unsqueeze(x, 0)
print(x.shape) # 输出:torch.Size([1, 3])
```
在这个例子中,我们将原来的一维张量 x 在第一个维度上插入了一个新维度,得到了一个形状为 (1, 3) 的二维张量。
阅读全文