Given that setup, I want to implement what I just asked: computing the KL divergence inside each episode's loop.
If you want to compute the KL divergence within each episode's loop, you can calculate it right before each parameter update. Here is the modified code:
```python
import numpy as np
import torch
import matplotlib.pyplot as plt

print('Training PPO2...')
actor = Actor()
critic = Critic()
all_ep_r = []
for episode in range(EP_MAX):
    observation = env.reset()  # reset the environment
    buffer_s, buffer_a, buffer_r, buffer_a_logp = [], [], [], []
    reward_totle = 0
    kl_divergence = 0
    for timestep in range(EP_LEN):
        if RENDER:
            env.render()
        action, action_logprob = actor.choose_action(observation)
        observation_, reward, done, info = env.step(action)
        buffer_s.append(observation)
        buffer_a.append(action)
        buffer_r.append((reward + 8) / 8)  # normalize reward, found to be useful
        buffer_a_logp.append(action_logprob)
        observation = observation_
        reward_totle += reward
        if (timestep + 1) % BATCH == 0 or timestep == EP_LEN - 1:
            v_observation_ = critic.get_v(observation_)
            discounted_r = []
            for r in buffer_r[::-1]:
                v_observation_ = r + GAMMA * v_observation_
                discounted_r.append(v_observation_.detach().numpy())
            discounted_r.reverse()
            bs, ba, br, bap = np.vstack(buffer_s), np.vstack(buffer_a), np.array(discounted_r), np.vstack(buffer_a_logp)
            buffer_s, buffer_a, buffer_r, buffer_a_logp = [], [], [], []
            advantage = critic.learn(bs, br)  # critic update
            # compute the KL divergence before updating the actor's parameters
            with torch.no_grad():
                new_action_logprob = actor.evaluate_actions(bs, ba)
                old_action_logprob = torch.as_tensor(bap, dtype=torch.float32)
                # F.kl_div expects log-probabilities as input and probabilities as target
                kl_divergence = torch.nn.functional.kl_div(
                    new_action_logprob, old_action_logprob.exp(), reduction='batchmean'
                ).item()
            actor.learn(bs, ba, advantage, bap)  # actor update
            actor.update_oldpi()  # copy the new policy's parameters into the old policy
    if episode == 0:
        all_ep_r.append(reward_totle)
    else:
        all_ep_r.append(all_ep_r[-1] * 0.9 + reward_totle * 0.1)
    print("\rEp: {} |rewards: {}|kl_divergence: {}".format(episode, reward_totle, kl_divergence), end="")
    if episode % 50 == 0 and episode > 100:
        save_data = {'net': actor.old_pi.state_dict(), 'opt': actor.optimizer.state_dict(), 'i': episode}
        torch.save(save_data, r"E:\PPO2_model_actor.pth")
        save_data = {'net': critic.critic_v.state_dict(), 'opt': critic.optimizer.state_dict(), 'i': episode}
        torch.save(save_data, r"E:\PPO2_model_critic.pth")
env.close()
plt.plot(np.arange(len(all_ep_r)), all_ep_r)
plt.xlabel('Episode')
plt.ylabel('Moving averaged episode reward')
plt.show()
```
In the modified code, the KL divergence is computed inside the batch-update branch, right before the actor's parameters are updated. Note that `torch.nn.functional.kl_div` expects log-probabilities as its first argument and probabilities as its second, so the stored old log-probabilities are exponentiated before being passed as the target. Because this computation sits directly before every update, the KL divergence is evaluated within each episode's loop and printed together with the episode reward.
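For reference, here is a minimal, self-contained sketch of two common ways to estimate the KL divergence from stored log-probabilities. The tensors `old_logp` and `new_logp` are hypothetical stand-ins for `bap` and the output of `actor.evaluate_actions(bs, ba)` above (that method's exact return value is not shown in your code, so this assumes it returns per-sample log-probabilities). Since only the log-probability of each executed action is stored, both expressions give a sample-based estimate rather than the exact KL over the full action distribution.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins: log-probabilities of the sampled actions under the
# old policy (collected during the rollout) and under the current policy.
old_logp = torch.log(torch.tensor([[0.20], [0.50], [0.90]]))
new_logp = torch.log(torch.tensor([[0.25], [0.45], [0.85]]))

# Option 1: F.kl_div -- the first argument must be log-probabilities and the
# second must be probabilities, hence the .exp() on the stored old log-probs.
kl_via_kldiv = F.kl_div(new_logp, old_logp.exp(), reduction='batchmean')

# Option 2: the cheap estimator often used in PPO implementations,
# KL(old || new) ~= mean(old_logp - new_logp) over actions sampled from the old policy.
kl_approx = (old_logp - new_logp).mean()

print(kl_via_kldiv.item(), kl_approx.item())
```

The second form avoids the argument-order pitfall of `F.kl_div` entirely and is the estimator typically used for logging or early stopping in PPO, so it is a reasonable drop-in if you only need the KL value for monitoring.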