举例说明torch.cat的语法是什么
时间: 2024-10-27 11:07:00 浏览: 64
`torch.cat()` 是 PyTorch 库中的一个功能函数,用于沿着指定轴将一维、二维或多维张量连接起来。其基本语法如下:
```python
torch.cat(tensors, dim=0, out=None)
```
参数解释:
- `tensors`: 一个张量列表或元组,它们要被连接在一起。
- `dim` (可选): 连接操作发生的维度,默认为 0,即沿行(对于二维及以上数据)。其他值(如 1 表示沿列)也可以选择。
- `out` (可选): 可选的目标张量,如果提供,则结果会存储在这个张量上而不是创建一个新的。
例如,如果你想把两个一维张量 `[1, 2, 3]` 和 `[4, 5, 6]` 沿着第一个轴(索引为 0)连接成 `[1, 2, 3, 4, 5, 6]`,你可以这样做:
```python
import torch
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
# 使用cat函数连接
concatenated_tensor = torch.cat((tensor1, tensor2), dim=0)
print(concatenated_tensor) # 输出: tensor([1, 2, 3, 4, 5, 6])
```
阅读全文