unsqueeze()用法
时间: 2023-06-30 17:17:24 浏览: 33
PyTorch的torch.cat用法
5星 · 资源好评率100%
unsqueeze()是PyTorch中的一个函数,用于在张量的指定维度上插入一个新的维度。它的用法如下:
```python
unsqueeze(dim)
```
其中,`dim`参数是要插入新维度的维度索引,它应该是一个整数,范围从0到`tensor.dim()`,其中`tensor.dim()`是张量的维度数。新维度大小为1。
例如,如果我们有一个形状为`(3, 4)`的张量,我们可以使用unsqueeze()将它变成`(3, 4, 1)`的张量,代码如下:
```python
import torch
x = torch.randn(3, 4)
y = x.unsqueeze(2)
print(y.shape) # 输出 torch.Size([3, 4, 1])
```
在这个例子中,我们使用unsqueeze()函数在第2个维度上插入了一个新的维度。这将产生一个新的张量`y`,其形状为`(3, 4, 1)`。注意,第2个维度的大小现在变成了1。
阅读全文