torch.concatenate
时间: 2023-10-30 15:06:18 浏览: 136
PyTorch的torch.cat用法
5星 · 资源好评率100%
torch.concatenate 是 PyTorch 中用于拼接张量的函数。它可以将多个张量沿指定的维度进行拼接。你可以传入一个包含多个张量的列表,并指定拼接的维度。例如:
```python
import torch
# 创建两个张量
x1 = torch.tensor([[1, 2, 3],
[4, 5, 6]])
x2 = torch.tensor([[7, 8, 9],
[10, 11, 12]])
# 沿着第 0 维度拼接
result = torch.cat([x1, x2], dim=0)
print(result)
# 输出:
# tensor([[ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9],
# [10, 11, 12]])
# 沿着第 1 维度拼接
result = torch.cat([x1, x2], dim=1)
print(result)
# 输出:
# tensor([[ 1, 2, 3, 7, 8, 9],
# [ 4, 5, 6, 10, 11, 12]])
```
这样,你就可以使用 torch.concatenate 来拼接多个张量。希望这个回答对你有帮助!如果你还有其他问题,请随时提问。
阅读全文