torch.rfft2用法详解
时间: 2023-10-06 13:08:24 浏览: 962
torch.rfft2 是 PyTorch 中的一个函数,用于进行二维实数输入的快速傅里叶变换(FFT)。该函数的返回值是一个复数张量,包含了输入数据的傅里叶变换结果。以下是该函数的详细用法。
函数定义:
```python
torch.rfft2(input, signal_ndim=2, normalized=False, onesided=True)
```
参数说明:
- `input`:输入张量,数据类型为实数,形状为 `(batch_size, channel, height, width)`。
- `signal_ndim`:表示输入张量的信号维度,默认值为 2,即二维信号。
- `normalized`:表示是否进行归一化,即除以 Fourier 变换的系数。默认值为 False,表示不进行归一化。
- `onesided`:表示是否只保留输出的单边频谱,即对称的频率部分是否只保留一半。默认值为 True,表示只保留单边频谱。
返回值:
- 输出张量,数据类型为复数,形状为 `(batch_size, channel, height, width // 2 + 1, 2)`。其中最后一个维度表示实部和虚部。
使用示例:
```python
import torch
# 定义输入数据
input = torch.randn(1, 3, 32, 32)
# 进行二维实数输入的快速傅里叶变换
output = torch.rfft2(input)
# 输出结果
print(output.shape) # torch.Size([1, 3, 17, 17, 2])
```
上述代码中,定义了一个形状为 `(1, 3, 32, 32)` 的输入张量,其中 1 表示 batch_size,3 表示通道数,32 表示高度和宽度。然后使用 `torch.rfft2` 函数对输入数据进行二维实数输入的快速傅里叶变换,得到一个形状为 `(1, 3, 17, 17, 2)` 的输出张量。其中 17 表示经过 FFT 变换后的频域大小,2 表示实部和虚部。
阅读全文