torch.cat dim参数的作用
时间: 2024-09-15 11:08:59 浏览: 59
torch.cat()函数的官方解释,详解以及例子
`torch.cat()`函数中的`dim`参数是关键,它控制了沿着哪个维度进行张量的拼接。具体来说:
- 当`dim=0`时(默认),张量将在第一个轴(通常是行或样本轴)上拼接,这对于逐层堆叠多维数组非常有用。
- 如果`dim`设置为一个正整数,则沿该特定维度进行拼接。例如,如果你有一个三维张量(N x C x H x W),设置`dim=1`会按通道(channels)合并张量。
- 对于负数索引,从最后一个维度开始计数。例如,`dim=-1`表示在深度(depth)方向上拼接。
举个例子[^2]:
```python
# 假设我们有三个一维张量,每个代表一个人的身高
tensor1 = torch.tensor([55])
tensor2 = torch.tensor([60])
tensor3 = torch.tensor([65])
# 使用dim=0拼接它们,相当于创建一个新的行向量,包含三个人的身高
result = torch.cat((tensor1, tensor2, tensor3), dim=0)
```
在这个例子中,`result`将会是一个形状为`(3,)`的新张量,包含了原始张量的内容。
阅读全文