只使用 tensor 类型、不使用 NumPy，编写一个可在 PyTorch 框架下运行的 EMD（经验模态分解）类的 Python 代码，并运行它画出各分量（IMF）图
时间: 2024-05-05 15:15:03 浏览: 132
emd经验模态分解程序代码
4星 · 用户满意度95%
以下是一个使用 PyTorch 框架编写的经验模态分解（EMD）类：
```python
import torch
import matplotlib.pyplot as plt
class EMD:
    """Empirical Mode Decomposition (EMD) using only PyTorch tensors (no NumPy).

    A signal is split into intrinsic mode functions (IMFs) by repeated
    "sifting": the mean of the upper and lower envelopes (natural cubic
    splines through the local maxima / minima) is subtracted from the
    candidate until that local mean is negligible. Whatever remains after the
    IMFs are removed is stored in ``self.residue``, so that
    ``sum(imfs) + self.residue`` reconstructs the input (up to rounding).

    Args:
        max_imfs: Maximum number of IMFs returned by :meth:`decompose`.
        eps: Absolute threshold. Decomposition stops once the residue's peak
            magnitude is below it; sifting stops once the local mean's peak
            magnitude is below it.
        device: Device that inputs are moved to before any computation.
        max_sifts: Upper bound on sifting iterations per IMF.
        sd_threshold: Sifting also stops when
            ``sum(local_mean**2) < sd_threshold * sum(candidate**2)``.
    """

    def __init__(self, max_imfs=10, eps=1e-5, device='cpu', max_sifts=100, sd_threshold=0.01):
        self.max_imfs = max_imfs
        self.eps = eps
        self.device = device
        self.max_sifts = max_sifts
        self.sd_threshold = sd_threshold
        # Residue left by the most recent decompose() call (None before any call).
        self.residue = None

    def _to_tensor(self, x):
        """Return ``x`` as a floating-point tensor on ``self.device``."""
        x = torch.as_tensor(x, device=self.device)
        if not torch.is_floating_point(x):
            x = x.to(torch.get_default_dtype())
        return x

    @staticmethod
    def _extrema_indices(x):
        """Return ``(maxima, minima)`` index tensors of the strict local extrema of 1-D ``x``.

        A sample is a maximum (minimum) when it is strictly greater (smaller)
        than both neighbours. Endpoints are never reported, and plateaus are
        not extrema.
        """
        centre = x[1:-1]
        rises = centre - x[:-2]  # > 0 where the sample is above its left neighbour
        falls = centre - x[2:]   # > 0 where the sample is above its right neighbour
        maxima = torch.nonzero((rises > 0) & (falls > 0)).flatten() + 1
        minima = torch.nonzero((rises < 0) & (falls < 0)).flatten() + 1
        return maxima, minima

    @staticmethod
    def _natural_cubic_spline(xk, yk, xq):
        """Evaluate the natural cubic spline through knots ``(xk, yk)`` at points ``xq``.

        ``xk`` must be strictly increasing and non-empty. Queries outside
        ``[xk[0], xk[-1]]`` are clamped, so the end values are held constant.
        """
        k = xk.numel()
        if k == 1:
            return yk.expand(xq.shape).clone()
        xq = xq.clamp(min=xk[0].item(), max=xk[-1].item())
        h = xk[1:] - xk[:-1]
        slope = (yk[1:] - yk[:-1]) / h
        # Second derivatives at the knots; zero at both ends ("natural" boundary).
        m = torch.zeros_like(yk)
        if k > 2:
            off = h[1:-1]
            system = torch.diag(2 * (h[:-1] + h[1:])) + torch.diag(off, 1) + torch.diag(off, -1)
            # NOTE: dense solve (O(k^2) memory); fine for the hundreds of extrema
            # of typical signals, a banded solver would scale better.
            m[1:-1] = torch.linalg.solve(system, 6 * (slope[1:] - slope[:-1]))
        # Segment index j such that xk[j] <= xq <= xk[j + 1].
        seg = (torch.searchsorted(xk, xq, right=True) - 1).clamp(0, k - 2)
        hs = h[seg]
        a = xk[seg + 1] - xq
        b = xq - xk[seg]
        m0, m1 = m[seg], m[seg + 1]
        return ((m0 * a ** 3 + m1 * b ** 3) / (6 * hs)
                + (yk[seg] / hs - m0 * hs / 6) * a
                + (yk[seg + 1] / hs - m1 * hs / 6) * b)

    def get_envelopes(self, x):
        """Return the ``(upper, lower)`` envelopes of ``x`` along its last dimension.

        The upper envelope's knots are the strict local maxima plus both
        endpoints. Each endpoint knot takes the larger of the signal's endpoint
        value and the nearest maximum, so the spline spans the whole support
        without extrapolating. The lower envelope mirrors this with the minima
        and the smaller value. Knots are joined by natural cubic splines.
        """
        x = self._to_tensor(x)
        if x.dim() == 0 or x.numel() == 0:
            return x.clone(), x.clone()
        n = x.shape[-1]
        # NaN marks "no knot here"; interpolate_nans fills these positions.
        upper = torch.full(x.shape, float('nan'), dtype=x.dtype, device=x.device)
        lower = torch.full_like(upper, float('nan'))
        rows = x.reshape(-1, n)
        upper_rows = upper.view(-1, n)
        lower_rows = lower.view(-1, n)
        for r in range(rows.shape[0]):
            row = rows[r]
            maxima, minima = self._extrema_indices(row)
            upper_rows[r, maxima] = row[maxima]
            lower_rows[r, minima] = row[minima]
            for end, nearest in ((0, 0), (n - 1, -1)):
                top = row[end]
                bottom = row[end]
                if maxima.numel() > 0:
                    top = torch.maximum(top, row[maxima[nearest]])
                if minima.numel() > 0:
                    bottom = torch.minimum(bottom, row[minima[nearest]])
                upper_rows[r, end] = top
                lower_rows[r, end] = bottom
        return self.interpolate_nans(upper), self.interpolate_nans(lower)

    def interpolate_nans(self, x):
        """Return a copy of ``x`` with NaNs filled by spline interpolation along the last dim.

        The non-NaN entries of each row are used as natural cubic spline knots.
        NaNs before the first or after the last knot take that knot's value.
        Rows that are entirely NaN are left unchanged. The input is not modified.
        """
        out = self._to_tensor(x).clone(memory_format=torch.contiguous_format)
        if out.dim() == 0 or out.numel() == 0:
            return out
        n = out.shape[-1]
        positions = torch.arange(n, device=out.device, dtype=out.dtype)
        flat = out.view(-1, n)
        for r in range(flat.shape[0]):
            row = flat[r]  # view: writes go straight into `out`
            missing = torch.isnan(row)
            if not missing.any() or missing.all():
                continue
            known = ~missing
            row[missing] = self._natural_cubic_spline(positions[known], row[known], positions[missing])
        return out

    def get_imf(self, x):
        """Sift a single IMF out of the 1-D signal ``x`` and return it.

        Each sifting step subtracts the mean of the upper and lower envelopes.
        Sifting stops after ``max_sifts`` steps, or earlier when:
        - the candidate lacks a maximum or a minimum (no envelopes can be formed),
        - the local mean's peak magnitude is below ``eps``, or
        - the local mean's energy is below ``sd_threshold`` times the candidate's energy.

        Raises:
            ValueError: if ``x`` is not 1-D.
        """
        h = self._to_tensor(x)
        if h.dim() != 1:
            raise ValueError(f'get_imf expects a 1-D signal, got shape {tuple(h.shape)}')
        h = h.clone()
        for _ in range(self.max_sifts):
            maxima, minima = self._extrema_indices(h)
            if maxima.numel() == 0 or minima.numel() == 0:
                break
            upper, lower = self.get_envelopes(h)
            mean_env = (upper + lower) / 2
            energy = torch.sum(h * h)
            h = h - mean_env
            if mean_env.abs().max() < self.eps or torch.sum(mean_env * mean_env) < self.sd_threshold * energy:
                break
        return h

    def decompose(self, x):
        """Decompose the 1-D signal ``x`` into at most ``max_imfs`` IMFs.

        Returns the IMFs as a list of tensors, in extraction order. The
        remainder is stored in ``self.residue``. Decomposition stops early when
        the residue's peak magnitude is below ``eps``, or when the residue has
        fewer than three local extrema (it is then treated as the trend).

        Raises:
            ValueError: if ``x`` is not 1-D.
        """
        residue = self._to_tensor(x)
        if residue.dim() != 1:
            raise ValueError(f'decompose expects a 1-D signal, got shape {tuple(residue.shape)}')
        residue = residue.clone()
        imfs = []
        while len(imfs) < self.max_imfs:
            if residue.numel() == 0 or residue.abs().max() < self.eps:
                break
            maxima, minima = self._extrema_indices(residue)
            if maxima.numel() + minima.numel() < 3:
                break
            imf = self.get_imf(residue)
            imfs.append(imf)
            residue = residue - imf
        self.residue = residue
        return imfs

    def plot_imfs(self, imfs, residue=None):
        """Plot each IMF (and, if given, the residue) in its own stacked subplot.

        Returns the matplotlib Figure.

        Raises:
            ValueError: if there is nothing to plot.
        """
        panels = [(f'IMF {i + 1}', imf) for i, imf in enumerate(imfs)]
        if residue is not None:
            panels.append(('Residue', residue))
        if not panels:
            raise ValueError('nothing to plot: imfs is empty')
        # squeeze=False keeps a 2-D Axes grid even when there is a single panel.
        fig, axs = plt.subplots(len(panels), 1, figsize=(10, 4 * len(panels)), squeeze=False)
        for ax, (title, series) in zip(axs[:, 0], panels):
            ax.plot(torch.as_tensor(series).detach().cpu().tolist())
            ax.set_title(title)
        fig.tight_layout()
        plt.show()
        return fig
```
下面是一个使用示例:
```python
# Test signal: three superposed sine waves (5 Hz, 10 Hz, 20 Hz) sampled at 1000 points on [0, 1].
freqs_hz = (5, 10, 20)
t = torch.linspace(0, 1, 1000)
x = sum(torch.sin(2 * torch.pi * f * t) for f in freqs_hz)

# Decompose into at most three IMFs and draw one subplot per component.
emd = EMD(max_imfs=3)
imfs = emd.decompose(x)
emd.plot_imfs(imfs)
```
这将生成一张分量图，其中最多包含三个 IMF（由 `max_imfs=3` 限定）；如果剩余信号提前变得可以忽略，分解会提前停止，IMF 数量可能少于三个。
阅读全文