pytorch的tensor拼接代码
时间: 2023-10-29 10:41:54 浏览: 80
PyTorch中Tensor的拼接与拆分的实现
5星 · 资源好评率100%
可以使用torch.cat()方法进行tensor拼接,示例代码如下:
import torch
x = torch.randn(2, 3)
y = torch.randn(2, 3)
# 在第0维上拼接
z = torch.cat([x, y], dim=0)
print(z.size()) # 输出 [4, 3]
# 在第1维上拼接
z = torch.cat([x, y], dim=1)
print(z.size()) # 输出 [2, 6]
阅读全文