粒子滤波python代码
时间: 2023-10-13 08:18:13 浏览: 61
以下是一个简单的粒子滤波器的 Python 代码实现:
```
import numpy as np
import matplotlib.pyplot as plt
def particle_filter(y, num_particles, time_steps, transition_model, observation_model, resample_interval=1):
# 初始化粒子
particles = np.zeros((num_particles, time_steps))
particles[:, 0] = np.random.normal(0, 1, size=num_particles)
# 初始化权重
weights = np.ones(num_particles) / num_particles
# 递推计算粒子和权重
for t in range(1, time_steps):
# 预测步骤:从上一时刻的粒子中采样并转移
particles[:, t] = transition_model(particles[:, t-1]) + np.random.normal(0, 1, size=num_particles)
# 更新步骤:计算每个粒子的权重
weights *= observation_model(y[t], particles[:, t])
# 规范化权重
weights /= np.sum(weights)
# 每隔 resample_interval 个时间步长,进行重采样
if t % resample_interval == 0:
particles, weights = resample(particles, weights)
return particles, weights
def resample(particles, weights):
num_particles = particles.shape[0]
# 计算每个粒子的累计权重
cum_weights = np.cumsum(weights)
# 生成新的粒子索引
new_particle_indices = np.zeros(num_particles, dtype=int)
u = np.random.uniform(size=num_particles)
for i in range(num_particles):
index = np.searchsorted(cum_weights, u[i])
new_particle_indices[i] = index
# 用新的粒子索引重采样粒子和权重
new_particles = particles[new_particle_indices]
new_weights = np.ones(num_particles) / num_particles
return new_particles, new_weights
# 例子
def transition_model(x):
return 0.5 * x + 25 * x / (1 + x ** 2) + 8 * np.cos(1.2 * (t - 1))
def observation_model(y, x):
return np.exp(-0.5 * (y - x ** 2) ** 2 / 10)
t = np.arange(0, 50)
y = 2 * t + 5 * np.random.normal(size=t.shape)
num_particles = 1000
particles, weights = particle_filter(y, num_particles, len(t), transition_model, observation_model)
# 绘制粒子和真实值的对比
plt.figure()
plt.plot(t, y, 'k-', label='True Value')
plt.plot(t, np.mean(particles, axis=0), 'b-', label='Particle Filter')
plt.legend()
plt.show()
```
在这个例子中,我们使用了一个简单的非线性状态空间模型,其中状态转移模型和观测模型都是非线性的。我们使用粒子滤波器来估计状态,并将其与真实值进行比较。