torch.cat((A,B), dim=1)和torch.cat((A,B), dim=-1)分别代表什么?
时间: 2024-10-09 07:11:28 浏览: 61
`torch.cat((A, B), dim=1)` 和 `torch.cat((A, B), dim=-1)` 是PyTorch库中用于连接(concatenate)张量(tensors)的两个常用方式。
1. 当 `dim=1` 时[^1],`torch.cat` 沿着张量的列(columns)方向堆叠 A 和 B。这意味着每个输入张量的第 i 行会被拼接到一起,形成一个新的张量,其宽度(width)是原张量的宽度之和,高度(height)保持不变。这对于合并具有相同高度但不同宽度的一维或多维张量非常有用。
2. 当 `dim=-1` 或不指定 `dim` 时[^2],默认情况下,`torch.cat` 会沿张量的深度(depth)方向堆叠 A 和 B。这通常适用于二维以上的张量(如卷积神经网络中的多通道图像),在这种情况下,它会沿着最后一个维度(通常是颜色通道数或特征通道数)进行堆叠,扩展张量的高度和宽度而不改变深度。
举个例子:
```python
# 假设A和B都是形状为(2, 3)的张量
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
B = torch.tensor([[7, 8, 9], [10, 11, 12]])
# dim=1 堆叠列
cat_dim1 = torch.cat((A, B), dim=1)
print(cat_dim1.shape) # 输出 (2, 6)
# dim=-1 或不指定dim 堆叠深度
cat_dim_minus1 = torch.cat((A, B))
print(cat_dim_minus1.shape) # 输出 (4, 3)
```
阅读全文