torch.stack的用法
时间: 2023-07-19 11:36:49 浏览: 71
对PyTorch torch.stack的实例讲解
torch.stack是一个用于将一组张量沿着新的维度进行堆叠的函数。具体来说,它可以将一组张量沿着指定维度进行堆叠,生成一个新的张量。该函数的语法如下:
```python
torch.stack(sequence, dim=0, out=None)
```
其中,sequence是一个张量序列,dim是指定沿着哪个维度进行堆叠,out是可选的输出张量。
以下是一个示例:
```python
import torch
t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])
stacked_tensor = torch.stack([t1, t2, t3])
print(stacked_tensor)
```
输出结果为:
```
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
```
在上述示例中,我们将三个一维张量沿着第0个维度进行了堆叠,生成了一个三行三列的二维张量。
阅读全文