用pytorch来拼接三维张量
时间: 2023-07-17 12:59:46 浏览: 156
在PyTorch中,可以使用`torch.cat()`函数来拼接三维张量。`torch.cat()`函数的用法如下:
```python
torch.cat(tensors, dim=0, *, out=None)
```
其中,`tensors`是要拼接的张量序列,`dim`是指定拼接的维度。下面是一个示例代码,演示如何拼接三维张量:
```python
import torch
# 创建三个三维张量
tensor1 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
tensor2 = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
tensor3 = torch.tensor([[[17, 18], [19, 20]], [[21, 22], [23, 24]]])
# 在第一个维度上拼接
concatenated_tensor = torch.cat((tensor1, tensor2, tensor3), dim=0)
print(concatenated_tensor)
```
输出结果:
```
tensor([[[ 1, 2],
[ 3, 4]],
[[ 5, 6],
[ 7, 8]],
[[ 9, 10],
[11, 12]],
[[13, 14],
[15, 16]],
[[17, 18],
[19, 20]],
[[21, 22],
[23, 24]]])
```
在上述代码中,我们首先创建了三个形状为(2, 2, 2)的三维张量`tensor1`、`tensor2`和`tensor3`,然后使用`torch.cat()`函数在第一个维度上将它们拼接在一起,即`dim=0`。最后打印输出了拼接后的三维张量`concatenated_tensor`的结果。
阅读全文