介绍torch.stack()函数
时间: 2023-10-20 22:24:54 浏览: 95
torch.stack()函数是一个将张量序列按照指定维度堆叠起来的函数。它的语法如下:
```
torch.stack(seq, dim=0, out=None)
```
其中,seq是一个张量序列,dim是堆叠维度,默认为0,out是输出张量,默认为None。
举个例子,假设我们有两个张量:
```
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
```
如果我们想将它们按照第0维堆叠起来,可以使用torch.stack()函数:
```
c = torch.stack((a, b), dim=0)
```
这样,c就变成了一个2×3的张量:
```
tensor([[1, 2, 3],
[4, 5, 6]])
```
torch.stack()函数也可以用于高维张量的堆叠。例如,假设我们有一个3×2×2的张量a和一个3×2×2的张量b:
```
a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]])
b = torch.tensor([[[13, 14], [15, 16]], [[17, 18], [19, 20]], [[21, 22], [23, 24]]])
```
如果我们想将它们按照第0维堆叠起来,可以使用torch.stack()函数:
```
c = torch.stack((a, b), dim=0)
```
这样,c就变成了一个2×3×2×2的张量:
```
tensor([[[[ 1, 2],
[ 3, 4]],
[[ 5, 6],
[ 7, 8]],
[[ 9, 10],
[11, 12]]],
[[[13, 14],
[15, 16]],
[[17, 18],
[19, 20]],
[[21, 22],
[23, 24]]]])
```
需要注意的是,torch.stack()函数会在堆叠维度上创建一个新的维度。在上面的例子中,a和b都是3×2×2的张量,而堆叠后的张量c是2×3×2×2的张量。如果我们不想在堆叠维度上创建新的维度,可以使用torch.cat()函数。
阅读全文