torch.cat()当dim=-1和dim=2时如何拼接
时间: 2024-09-26 22:04:56 浏览: 35
torch.cat()函数的官方解释,详解以及例子
`torch.cat()`函数是PyTorch库中的一个操作,用于将一维、二维或多维张量按照指定维度(dim)连接起来。当你设置`dim=-1`时,它会在最后一个(即最右边)维度上进行拼接,这意味着沿着列方向添加新的元素到现有的张量。例如,如果你有一个形状为`(batch_size, channels, height, width)`的张量列表,`torch.cat(tensors, dim=-1)`会沿深度(通道)方向堆叠所有张量。
另一方面,当你设置`dim=2`时,这表示在第二维度(如果张量有三维或以上的话,通常对应于宽度或列数)进行拼接。这对于将一系列长度相同的行向量(如时间序列数据)横向拼接到一起非常有用。
举个例子:
```python
# 假设我们有两个3x4的张量list
tensor_list_1 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
tensor_list_2 = torch.tensor([[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]])
# 当dim=-1时
concat_dim_minus_one = torch.cat(tensor_list_1, dim=-1)
print(concat_dim_minus_one.shape) # 输出 (3, 8)
# 当dim=2时
concat_dim_two = torch.cat(tensor_list_1, dim=2)
print(concat_dim_two.shape)
阅读全文