重采样粒子滤波python代码
时间: 2023-11-08 08:58:15 浏览: 82
以下是一个简单的重采样粒子滤波的Python代码:
```python
import numpy as np
import matplotlib.pyplot as plt
# 系统模型
def system_model(x, u):
# 状态转移方程
x_next = x + u + 0.1*np.random.randn()
# 观测方程
y = x_next + 0.1*np.random.randn()
return x_next, y
# 初始化粒子
def init_particles(N):
particles = np.zeros((N, 1))
for i in range(N):
particles[i] = np.random.randn()
return particles
# 重采样
def resample(particles, w):
N = len(particles)
new_particles = np.zeros((N, 1))
new_w = np.zeros((N, 1))
# 计算归一化权重
w_norm = w / np.sum(w)
# 生成轮盘
wheel = np.cumsum(w_norm)
# 抽样
for i in range(N):
rand = np.random.uniform()
for j in range(N):
if rand < wheel[j]:
new_particles[i] = particles[j]
new_w[i] = w_norm[j]
break
# 归一化重采样后的权重
new_w = new_w / np.sum(new_w)
return new_particles, new_w
# 粒子滤波
def particle_filter(N, u, y):
particles = init_particles(N)
w = np.zeros((N, 1))
x_est = 0
for i in range(N):
particles[i], y_pred = system_model(particles[i], u)
w[i] = np.exp(-0.5*(y - y_pred)**2 / 0.1**2)
w_norm = w / np.sum(w)
x_est = np.sum(w_norm * particles)
# 判断是否需要重采样
if 1 / np.sum(w_norm**2) < N / 2:
particles, w_norm = resample(particles, w_norm)
return x_est, particles, w_norm
# 生成模拟数据
N = 100
T = 100
u = np.zeros((T, 1))
y = np.zeros((T, 1))
for t in range(T):
if t > 0:
u[t] = 0.5*np.sin(0.1*t) + 0.1*np.random.randn()
_, y[t] = system_model(0, u[t])
# 粒子滤波
x_est_pf = np.zeros((T, 1))
for t in range(T):
x_est_pf[t], _, _ = particle_filter(N, u[t], y[t])
# 绘制结果
plt.figure()
plt.plot(y, label='Observation')
plt.plot(x_est_pf, label='Particle Filter')
plt.legend()
plt.show()
```
以上代码实现了一个基于粒子滤波的状态估计器,用于估计一个简单的一维系统的状态。在该系统中,状态转移方程是 $x_{k+1} = x_k + u_k + w_k$,其中 $w_k$ 是高斯噪声,观测方程是 $y_k = x_{k+1} + v_k$,其中 $v_k$ 也是高斯噪声。粒子滤波的步骤包括初始化粒子、预测粒子、计算权重、重采样等。在本例中,如果重采样后的有效粒子数小于粒子数的一半,就进行重采样。最终,使用粒子的加权平均值作为状态估计值。在本例中,我们生成了一组模拟数据,并将粒子滤波估计的状态值与观测值进行比较。
阅读全文