torch.stack的用法
时间: 2023-10-20 22:14:03 浏览: 81
对PyTorch torch.stack的实例讲解
`torch.stack()` 是 PyTorch 中的一个函数,它可以将多个张量按照指定的维度进行堆叠(stack)操作,生成一个新的张量。
具体用法如下:
```python
torch.stack(tensors, dim=0, out=None)
```
其中,`tensors` 是一个由张量组成的序列,`dim` 是指定堆叠的维度,默认为 0,表示在新张量的第 0 维度上进行堆叠。`out` 是可选参数,如果指定了,输出将会被写入它,而不是新创建一个张量。
例如,假设有两个张量 `a` 和 `b`,它们的形状分别为 (3, 4) 和 (3, 4),我们可以使用 `torch.stack()` 将它们在第 0 维度上进行堆叠,生成一个形状为 (2, 3, 4) 的新张量:
```python
import torch
a = torch.randn(3, 4)
b = torch.randn(3, 4)
c = torch.stack([a, b], dim=0)
print(c.shape) # 输出:torch.Size([2, 3, 4])
```
需要注意的是,要使用 `torch.stack()` 进行堆叠操作,输入的张量必须有相同的形状,除了指定的维度外,其它维度的大小必须相等。
阅读全文