unsqueeze函数什么用
时间: 2023-10-16 18:18:25 浏览: 72
pytorch张量索引以及sum函数mean函数unsqueeze函数
unsqueeze函数是PyTorch中的一个函数,用于将一个张量的维度进行扩展。具体来说,unsqueeze函数可以将一个长度为n的一维张量扩展为一个(n,1)的二维张量,或者将一个(n,m)的二维张量扩展为一个(n,m,1)的三维张量,以此类推。
举个例子,假设有一个shape为(3,)的一维张量a,我们可以使用unsqueeze函数将其变成一个shape为(3,1)的二维张量b:
```
import torch
a = torch.tensor([1, 2, 3])
b = torch.unsqueeze(a, dim=1)
print(b.shape) # output: torch.Size([3, 1])
```
在这个例子中,dim参数指定了要在哪个维度上进行扩展,这里指定为1,即在第二个维度上进行扩展。
阅读全文