pytorch 数据重排
时间: 2024-12-28 16:22:44 浏览: 12
### 如何在 PyTorch 中对数据进行重排操作
#### 使用 `permute()` 方法改变维度顺序
为了调整张量的维度顺序而不改变其内容,可以使用 `torch.Tensor.permute(*dims)` 函数。此函数返回一个新的视图对象,其中输入张量的给定维度被置换。
```python
import torch
# 创建一个形状为 [2, 3, 5, 7] 的零填充张量作为例子
x = torch.zeros([2, 3, 5, 7])
# 将原始张量的第一个维度(batch size)移到最后位置
y = x.permute(1, 2, 3, 0)
print(f'Original shape: {x.shape}')
print(f'Rearranged shape after permute(): {y.shape}')
```
上述代码创建了一个四维张量并调用了 `.permute()` 来交换这些轴的位置[^1]。
#### 利用 `reshape()` 或者 `view()` 改变张量尺寸
当需要更改张量的整体大小而不仅仅是重新排列现有维度时,则应该考虑使用 `torch.reshape(input, shape)` 或者更常用的 `tensor.view(shape)` 方法来实现这一点。这两个方法都可以用于扁平化多维数组或将一维向量转换成矩阵等形式。
需要注意的是,在某些情况下,`view()` 可能会失败因为它依赖于底层存储是否连续;此时应先通过`.contiguous()`.确保内存布局适合再尝试重塑。
```python
# 假设我们有一个形状为 [2, 3, 5, 7] 的张量 z
z = torch.randn((2, 3, 5, 7))
# 把它变成一个二维张量,保持 batch_size 不变
flattened_z = z.view(z.size(0), -1)
print(f'Shape before reshaping with view(): {z.shape}')
print(f'Shape after flattening using view(): {flattened_z.shape}')
# 如果遇到错误提示 "non-contiguous" ,则需加上 .contiguous()
safe_flatten = z.contiguous().view(-1)
```
这里展示了如何利用 `view()` 和 `reshape()` 对张量进行变形处理。
#### 结合两者完成复杂变换
有时可能既想要改变维度顺序又希望修改具体尺度。这时就可以把两种技术结合起来应用:
```python
original_tensor = torch.rand((2, 3, 5, 7))
reordered_and_resized = original_tensor.permute(0, 2, 3, 1).reshape(original_tensor.size(0), -1)
print(f'Final shape after both operations: {reordered_and_resized.shape}')
```
这段脚本首先改变了原张量内部各层之间的相对次序,接着将其余三个维度合并到一起形成新的单一层级结构。
阅读全文