如何拼接多个tensor,用python具体演示
时间: 2024-04-11 13:28:50 浏览: 64
可以使用`torch.cat`函数来拼接多个tensor。下面是一个具体的示例代码:
```python
import torch
# 创建两个示例tensor
tensor1 = torch.tensor([[1, 2, 3],
[4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9],
[10, 11, 12]])
# 在行维度上拼接两个tensor
concatenated_tensor = torch.cat((tensor1, tensor2), dim=0)
print(concatenated_tensor)
# 在列维度上拼接两个tensor
concatenated_tensor = torch.cat((tensor1, tensor2), dim=1)
print(concatenated_tensor)
```
运行以上代码,输出结果如下:
```
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
tensor([[ 1, 2, 3, 7, 8, 9],
[ 4, 5, 6, 10, 11, 12]])
```
`torch.cat`函数的第一个参数是一个tuple或list,包含要拼接的tensor。第二个参数`dim`指定了拼接的维度,0表示在行维度上拼接,1表示在列维度上拼接。以上示例分别在行维度和列维度上将两个tensor拼接在一起。
阅读全文