pytorch的cat函数代码
时间: 2023-12-29 11:02:28 浏览: 73
pyTorch.randn()、rand()、cat()、pow()、scatter_()、.squeeze() 、.unsqueeze()
5星 · 资源好评率100%
以下是使用PyTorch的cat函数进行张量拼接的示例代码:
```python
import torch
# 创建两个张量
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])
# 在最后一个维度上进行拼接
concatenated_tensor = torch.cat((tensor1, tensor2), dim=-1)
print(concatenated_tensor)
```
输出结果:
```
tensor([[ 1, 2, 3, 7, 8, 9],
[ 4, 5, 6, 10, 11, 12]])
```
在上述代码中,我们首先创建了两个形状为(2, 3)的张量`tensor1`和`tensor2`,然后使用`torch.cat`函数将它们在最后一个维度上进行拼接,即`dim=-1`。最后打印输出了拼接后的张量`concatenated_tensor`的结果。
阅读全文