时间序列数据增强代码
时间: 2023-12-20 19:30:08 浏览: 34
时间序列数据增强是指通过对原始时间序列数据进行一系列变换,生成新的时间序列数据,以扩充数据集,提高模型的泛化能力。以下是一个Python的时间序列数据增强代码示例:
```python
import numpy as np
from scipy import signal
def time_series_augmentation(data, method='stretch', scale=0.1, window_size=5):
"""
对时间序列数据进行增强
:param data: 原始时间序列数据
:param method: 增强方法,包括'stretch'(时间拉伸),'squeeze'(时间压缩),'shift'(时间平移),'flip'(时间翻转),'noise'(加噪声),'filter'(滤波)
:param scale: 增强程度,对于'stretch'和'squeeze',表示时间缩放比例;对于'shift',表示时间平移量;对于'noise',表示噪声强度;对于'filter',表示滤波器截止频率
:param window_size: 滤波器窗口大小,仅对'filter'方法有效
:return: 增强后的时间序列数据
"""
if method == 'stretch':
# 时间拉伸
stretch_scale = 1 + scale * np.random.uniform(-1, 1)
new_length = int(len(data) * stretch_scale)
return signal.resample(data, new_length)
elif method == 'squeeze':
# 时间压缩
squeeze_scale = 1 - scale * np.random.uniform(0, 1)
new_length = int(len(data) * squeeze_scale)
return signal.resample(data, new_length)
elif method == 'shift':
# 时间平移
shift = int(scale * len(data) * np.random.uniform(-1, 1))
return np.roll(data, shift)
elif method == 'flip':
# 时间翻转
return np.flip(data)
elif method == 'noise':
# 加噪声
noise = np.random.normal(0, scale, len(data))
return data + noise
elif method == 'filter':
# 滤波
b, a = signal.butter(4, scale, 'lowpass')
return signal.filtfilt(b, a, data)
else:
raise ValueError('Invalid method')
# 示例
data = np.sin(np.arange(0, 10, 0.1))
aug_data = time_series_augmentation(data, method='stretch', scale=0.2)
```