什么是torch.stack函数?
时间: 2024-03-27 19:33:35 浏览: 183
torch.stack()的官方解释,详解以及例子
torch.stack函数是PyTorch中的一个函数,用于将多个张量按照指定的维度进行堆叠。它接受一个张量列表作为输入,并返回一个新的张量,其中包含了输入张量在指定维度上的堆叠结果。
具体来说,torch.stack函数的作用是将多个张量在指定维度上进行拼接,生成一个新的张量。拼接的方式是将输入张量按照指定维度进行堆叠,形成一个新的维度。这个新的维度的大小等于输入张量个数。
下面是torch.stack函数的使用示例:
```
import torch
# 创建两个张量
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
# 使用torch.stack函数将两个张量在0维度上进行堆叠
z = torch.stack([x, y], dim=0)
print(z)
```
输出结果为:
```
tensor([[1, 2, 3],
[4, 5, 6]])
```
在这个示例中,我们创建了两个形状相同的一维张量x和y。然后使用torch.stack函数将它们在0维度上进行堆叠,生成了一个新的二维张量z。z的形状为(2, 3),其中第一行是x的元素,第二行是y的元素。
阅读全文