torch。stack
时间: 2023-11-27 08:46:35 浏览: 43
`torch.stack()`是一个PyTorch中的函数,用于将张量序列沿着新的维度堆叠起来。它的语法如下:
```python
torch.stack(sequence, dim=0, out=None)
```
其中,`sequence`是一个张量序列,`dim`是新维度的索引,`out`是输出张量。
下面是一个例子,假设我们有两个张量`x`和`y`,它们的形状都是`(3,)`,我们可以使用`torch.stack()`将它们沿着新的维度堆叠起来:
```python
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.stack([x, y], dim=0)
print(z)
```
输出结果为:
```
tensor([[1, 2, 3],
[4, 5, 6]])
```
可以看到,`torch.stack()`将`x`和`y`沿着新的维度`dim=0`堆叠起来,形成了一个新的张量`z`,它的形状为`(2, 3)`。
除了`dim=0`,我们还可以指定其他的维度来堆叠张量序列。下面是一个例子,假设我们有三个张量`a`、`b`和`c`,它们的形状都是`(3, 2)`,我们可以使用`torch.stack()`将它们沿着新的维度`dim=2`堆叠起来:
```python
import torch
import numpy as np
a = torch.from_numpy(np.array([[1, 2], [3, 4], [5, 6]]))
b = torch.from_numpy(np.array([[7, 8], [9, 10], [11, 12]]))
c = torch.from_numpy(np.array([[13, 14], [15, 16], [17, 18]]))
d = torch.stack([a, b, c], dim=2)
print(d)
```
输出结果为:
```
tensor([[[ 1, 7, 13],
[ 2, 8, 14]],
[[ 3, 9, 15],
[ 4, 10, 16]],
[[ 5, 11, 17],
[ 6, 12, 18]]])
```
可以看到,`torch.stack()`将`a`、`b`和`c`沿着新的维度`dim=2`堆叠起来,形成了一个新的张量`d`,它的形状为`(3, 2, 3)`。
阅读全文