torch.transpose()与 torch.permute()区别
时间: 2023-12-02 19:58:51 浏览: 142
torch.transpose() 和 torch.permute() 都是 PyTorch 中用于变换张量维度的函数,它们的区别在于:
torch.transpose() 函数是对张量的转置操作,只支持二维(矩阵)和三维(批量矩阵)张量的转置。对于二维张量,可以使用该函数将其转置;对于三维张量,可以指定转置的两个维度进行转置。例如:
```
import torch
x = torch.randn(3, 4)
print(x)
# tensor([[ 2.0514, 1.8634, -0.0375, -1.1761],
# [-1.3048, 0.0905, 0.3555, 0.2998],
# [-0.5564, 0.4689, -0.8771, 0.9331]])
y = torch.transpose(x, 0, 1)
print(y)
# tensor([[ 2.0514, -1.3048, -0.5564],
# [ 1.8634, 0.0905, 0.4689],
# [-0.0375, 0.3555, -0.8771],
# [-1.1761, 0.2998, 0.9331]])
z = torch.randn(2, 3, 4)
print(z)
# tensor([[[ 0.4673, 0.1136, -1.3368, -0.1990],
# [-0.5333, 0.0768, -0.6232, -0.8085],
# [-1.2438, 1.8073, -0.6008, -1.0195]],
#
# [[ 0.4232, -0.2118, 1.2122, 0.7345],
# [-1.3797, -0.3909, -0.2965, -1.3328],
# [-0.8473, -0.6902, 0.1941, -0.8746]]])
w = torch.transpose(z, 0, 1)
print(w)
# tensor([[[ 0.4673, 0.1136, -1.3368, -0.1990],
# [ 0.4232, -0.2118, 1.2122, 0.7345]],
#
# [[-0.5333, 0.0768, -0.6232, -0.8085],
# [-1.3797, -0.3909, -0.2965, -1.3328]],
#
# [[-1.2438, 1.8073, -0.6008, -1.0195],
# [-0.8473, -0.6902, 0.1941, -0.8746]]])
```
torch.permute() 函数可以对张量的任意维度进行重新排列,可以对任意维度的张量进行操作。例如:
```
import torch
x = torch.randn(2, 3, 4)
print(x)
# tensor([[[-1.1819, 0.6551, 0.3769, -0.7894],
# [ 0.2819, -0.5533, -0.2477, -0.8851],
# [-0.6033, -1.4273, -0.2664, -0.5853]],
#
# [[-1.0286, -1.7711, -0.5401, 0.5417],
# [-0.4733, 2.4179, 0.0829, 0.7627],
# [-0.4668, 0.0573, 0.2071, -0.5846]]])
y = x.permute(2, 0, 1)
print(y)
# tensor([[[-1.1819, 0.2819, -0.6033],
# [-1.0286, -0.4733, -0.4668]],
#
# [[ 0.6551, -0.5533, -1.4273],
# [-1.7711, 2.4179, 0.0573]],
#
# [[ 0.3769, -0.2477, -0.2664],
# [-0.5401, 0.0829, 0.2071]],
#
# [[-0.7894, -0.8851, -0.5853],
# [ 0.5417, 0.7627, -0.5846]]])
```
可以看到,torch.permute() 函数可以对张量的任意维度进行重新排列,而 torch.transpose() 函数只能对二维和三维张量进行转置。
阅读全文