求音频的STFT pytorch
时间: 2023-11-04 11:14:08 浏览: 160
您好!我了解您的问题,以下是使用 PyTorch 进行音频 STFT 的代码示例:
首先,您需要导入所需库:
``` python
import torch
from torch.nn.functional import pad
# 如果您希望可视化输出,请安装 matplotlib 库
import matplotlib.pyplot as plt
```
接下来,您需要定义一个函数来计算短时傅里叶变换(STFT):
``` python
def stft(x, fft_size, hop_size, win_size=None, window='hann'):
# 将输入转换为可用于torch.fft的复数张量
x = x.unsqueeze(-1)
x_real = torch.cat([x, torch.zeros_like(x)], dim=-1)
# 创建窗口函数(默认为汉宁窗)
if win_size is None:
win_size = fft_size
win = torch.hann_window(win_size, device=x.device) if window == 'hann' else torch.ones(win_size, device=x.device)
# 零填充输入以获得所需帧长度
pad_size = (fft_size - hop_size) // 2
x_real = pad(x_real, [pad_size, pad_size], mode='reflect')
# 计算每个帧的STFT
stft_frames = []
for i in range((x_real.shape[-2] - fft_size) // hop_size + 1):
start = i * hop_size
end = start + fft_size
frame = x_real[..., start:end]
# 加窗
frame *= win
# 计算FFT和取模
frame = torch.fft.fft(frame, fft_size, dim=-2)
frame = frame.real**2 + frame.imag**2
stft_frames.append(frame)
# 将所有帧堆叠成一张张量
stft_tensor = torch.stack(stft_frames, dim=-1)
return stft_tensor
```
这个函数的输入是一个一维张量 x,表示原始音频信号。fft_size 是要计算的FFT大小,hop_size 是每个帧之间的跳跃量。如果未指定 win_size,则使用 fft_size,窗口函数可以是 'hann' 或 'rect',默认为 'hann'。输出是一张三维张量,其形状为(freq_bins,frames,1),其中 freq_bins 是 FFT 的时间长度,frames 是帧的数量。
例如,如果您有一个名为 audio 的一维张量,采样率为 16kHz,想要在窗口大小为 1024,每隔 256 个样本(即每 16ms)计算一个帧的情况下计算 256 点的FFT,您可以这样调用函数:
``` python
sample_rate = 16000
window_size = 1024
hop_size = 256
freq_bins = window_size // 2 + 1
audio = torch.randn(sample_rate * 10) # 10秒的随机信号
stft_matrix = stft(audio, window_size, hop_size)
# 可视化stft(仅供参考)
plt.imshow(stft_matrix[:, :, 0].log2().flip(0), cmap='jet', extent=[0, 10, 0, freq_bins], aspect='auto')
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.show()
```
这将计算 stft_matrix,其中 freq_bins 是 513,帧数为 (len(audio) - window_size) / hop_size + 1。
请注意,该函数实现了“重叠相加”的 STFT,这意味着每个帧之间存在一些重叠,以使输出保持平滑,并且第一个样本之前和最后一个样本之后采用“反射”零填充以避免边缘效应。
希望这可以回答您的问题!
阅读全文