介绍torch.stack使用方法
时间: 2024-09-27 20:15:36 浏览: 45
`torch.stack()` 是 PyTorch 库中用于将一系列张量沿着指定维度堆叠在一起的操作。其基本语法如下:
```python
torch.stack(tensors, dim=0, out=None)
```
- `tensors`: 一个包含张量元素的列表、元组或其他可迭代的对象,这些张量需要有相同的形状除了在`dim`所指定的维度。
- `dim` (可选): 默认值为0,表示在新的轴上添加这些张量,如果你想要沿现有的轴堆叠,可以设置为非零整数。例如,如果你有一个二维张量列表,`dim=1`会将它们横着堆叠。
- `out` (可选): 可选的目标输出张量,如果提供,则结果将被存储在此张量中,而不是创建一个新的。
使用示例:
```python
import torch
# 创建一些一维张量
tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
# 使用dim=0将它们堆叠起来
stacked_tensor = torch.stack(tensors, dim=0)
print(stacked_tensor) # 输出: tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
```
如果你想沿其他维度堆叠,比如对于二维张量:
```python
tensors = [torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]])]
stacked_tensor = torch.stack(tensors, dim=1) # 沿列堆叠
```
这将得到一个三阶张量,第一维是原始的张量数目,第二和第三维是原来的列。
阅读全文