torch.stack(imgs)
时间: 2023-09-09 21:09:28 浏览: 93
看完秒懂torch.stack()
torch.stack(imgs)是PyTorch库中的一个函数,用于将一系列张量按照新的维度进行堆叠,创建一个新的张量。
函数的语法如下:
torch.stack(seq, dim=0, *, out=None) -> Tensor
其中:
- seq:要进行堆叠的张量序列,可以是一个列表、元组或者其他可迭代对象。
- dim:指定在哪个维度上进行堆叠,默认为0,即在新的维度0上进行堆叠。
- out(可选):输出张量。
示例用法:
```python
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.tensor([7, 8, 9])
stacked = torch.stack((x, y, z))
print(stacked)
# 输出:tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
```
在上述示例中,我们有三个张量x、y和z,它们都是形状为(3,)的一维张量。通过调用torch.stack函数,我们将它们按照维度0进行堆叠,创建一个新的二维张量stacked。stacked的形状为(3, 3),每行对应输入的一个张量。注意,在堆叠时,输入张量的形状必须保持一致。
通过torch.stack函数,我们可以方便地将多个张量按照指定维度进行堆叠,用于处理多个样本或者批量操作。
阅读全文