torch.transpose()与 torch.permute()例程
时间: 2023-11-10 22:09:33 浏览: 50
torch.transpose() 和 torch.permute() 都是 PyTorch 中用于转置张量的函数,但它们的用法略有不同。
torch.transpose() 用于交换张量的两个维度。例如,如果原始张量的形状为 (2,3,4),则可以使用 torch.transpose() 将其转置为 (2,4,3):
```python
import torch
x = torch.randn(2, 3, 4)
print(x.shape) # 输出 (2, 3, 4)
# 转置 x 的后两个维度
y = torch.transpose(x, 1, 2)
print(y.shape) # 输出 (2, 4, 3)
```
注意,torch.transpose() 只能交换两个维度,并且这两个维度的长度必须相同。
torch.permute() 则可以任意交换张量的维度。例如,可以使用 torch.permute() 将形状为 (2,3,4) 的张量转置为 (3,2,4):
```python
import torch
x = torch.randn(2, 3, 4)
print(x.shape) # 输出 (2, 3, 4)
# 使用 permute() 交换 x 的第一和第二个维度
y = x.permute(1, 0, 2)
print(y.shape) # 输出 (3, 2, 4)
```
需要注意的是,torch.permute() 操作会返回一个新的张量,而不是直接修改原始张量。如果需要在原始张量上进行操作,则需要使用 in-place 操作符“_”:
```python
import torch
x = torch.randn(2, 3, 4)
print(x.shape) # 输出 (2, 3, 4)
# 使用 in-place 操作符交换 x 的前两个维度
x.permute_(1, 0, 2)
print(x.shape) # 输出 (3, 2, 4)
```