class STFT(torch.nn.Module):类代码详解
时间: 2023-07-16 22:14:46 浏览: 118
pytorch 中的重要模块化接口nn.Module的使用
`STFT` 是一个 PyTorch 模块,用于计算短时傅里叶变换(Short-Time Fourier Transform,STFT),是一种常用的信号处理技术。下面是这个模块的详细解释:
```python
class STFT(torch.nn.Module):
def __init__(self, filter_length=2048, hop_length=512, win_length=None,
window='hann', center=True, pad_mode='reflect', freeze_parameters=True):
super().__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.center = center
self.pad_mode = pad_mode
if win_length is None:
win_length = filter_length
self.win_length = win_length
self.window = get_window(window, win_length)
# Create filter kernel
fft_basis = np.fft.fft(np.eye(filter_length))
kernel = np.concatenate([np.real(fft_basis[:filter_length // 2 + 1, :]),
np.imag(fft_basis[:filter_length // 2 + 1, :])], 0)
self.register_buffer('kernel', torch.tensor(kernel, dtype=torch.float32))
# Freeze parameters
if freeze_parameters:
for name, param in self.named_parameters():
param.requires_grad = False
def forward(self, waveform):
assert (waveform.dim() == 1)
# Pad waveform
if self.center:
waveform = nn.functional.pad(waveform.unsqueeze(0),
(self.filter_length // 2, self.filter_length // 2),
mode='constant',
value=0)
else:
waveform = nn.functional.pad(waveform.unsqueeze(0),
(self.filter_length - self.hop_length, 0),
mode='constant',
value=0)
# Window waveform
if waveform.shape[-1] < self.win_length:
waveform = nn.functional.pad(waveform, (self.win_length - waveform.shape[-1], 0),
mode='constant',
value=0)
waveform = waveform.squeeze(0)
if self.window.device != waveform.device:
self.window = self.window.to(waveform.device)
windowed_waveform = waveform * self.window
# Pad for linear convolution
if self.center:
windowed_waveform = nn.functional.pad(windowed_waveform,
(self.filter_length // 2, self.filter_length // 2),
mode='constant',
value=0)
else:
windowed_waveform = nn.functional.pad(windowed_waveform,
(self.filter_length - self.hop_length, 0),
mode='constant',
value=0)
# Perform convolution
fft = torch.fft.rfft(windowed_waveform.unsqueeze(0), dim=1)
fft = torch.cat((fft.real, fft.imag), dim=1)
output = torch.matmul(fft, self.kernel)
# Remove redundant frequencies
output = output[:, :self.filter_length // 2 + 1, :]
return output
```
- `__init__` 方法:构造方法,用于初始化模块的各个参数。其中,`filter_length` 表示 STFT 的滤波器长度,`hop_length` 表示 STFT 的帧移(即相邻帧之间的采样点数),`win_length` 表示 STFT 的窗函数长度,`window` 是指定的窗函数类型(默认为汉宁窗),`center` 表示是否需要在信号两端填充 0 以保证 STFT 的中心位置与输入信号的中心位置对齐,`pad_mode` 是指定填充方式(默认为反射填充),`freeze_parameters` 表示是否需要冻结模块的参数。
- `forward` 方法:前向传播方法,用于计算输入信号的 STFT。其中,`waveform` 表示输入信号。首先,根据 `center` 和 `pad_mode` 对输入信号进行填充和窗函数处理,然后进行线性卷积,最后通过傅里叶变换计算 STFT。返回的 `output` 是一个张量,表示 STFT 系数。
阅读全文