torch.cat函数定义
时间: 2023-05-04 07:06:36 浏览: 57
torch.cat函数是PyTorch中的一个函数,它用于将多个张量按照指定的维度进行拼接,生成一个新的张量。
具体来说,torch.cat函数接受两个输入参数,第一个参数为要拼接的张量列表,第二个参数为拼接的维度。例如,如果要将两个形状为(3,4)的张量按照第0维拼接起来,可以使用以下代码:
```
import torch
a = torch.randn(3, 4)
b = torch.randn(3, 4)
c = torch.cat([a, b], dim=0)
print(a)
print(b)
print(c)
```
输出结果为:
```
tensor([[ 0.1207, 2.0490, 0.3733, 1.3520],
[ 0.0409, -0.7710, -0.0905, -0.7648],
[ 0.2075, -0.2875, 1.3011, 0.4901]])
tensor([[ 0.5455, -0.1474, -0.0314, -0.7324],
[ 1.1684, -0.5282, 1.1644, -0.2164],
[-0.4101, -0.3484, 0.9460, -0.5985]])
tensor([[ 0.1207, 2.0490, 0.3733, 1.3520],
[ 0.0409, -0.7710, -0.0905, -0.7648],
[ 0.2075, -0.2875, 1.3011, 0.4901],
[ 0.5455, -0.1474, -0.0314, -0.7324],
[ 1.1684, -0.5282, 1.1644, -0.2164],
[-0.4101, -0.3484, 0.9460, -0.5985]])
```
可以看到,torch.cat函数将两个(3,4)的张量按照第0维拼接成了一个(6,4)的张量。
需要注意的是,拼接的张量在指定维度以外的维度必须完全一致,否则会报错。另外,torch.cat函数并不会改变原始张量的值,而是返回一个新的张量。