torch.stack怎么用
时间: 2023-11-09 21:09:27 浏览: 117
看完秒懂torch.stack()
`torch.stack`函数可以用来将一系列张量按照指定的维度进行堆叠。它的用法如下:
```python
torch.stack(seq, dim=0, out=None)
```
- `seq`是一个张量序列,可以是一个Python列表、元组或者一个张量的迭代器。
- `dim`是一个整数,表示堆叠的维度。默认值为0,表示在新创建的张量中增加一个维度。
- `out`是一个可选的输出张量,用于指定结果的存放位置。
例如,假设有两个形状为(3,)的张量`a`和`b`,我们可以使用`torch.stack`将它们沿着新创建的维度0进行堆叠:
```python
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.stack([a, b], dim=0)
print(result)
```
输出结果为:
```
tensor([[1, 2, 3],
[4, 5, 6]])
```
阅读全文