PyTorch implementation of the PPO algorithm
Below is a simple PyTorch implementation of the PPO algorithm, largely following the implementation in OpenAI baselines, using CartPole-v1 as the example environment:
```python
import gym
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler


class ActorCritic(nn.Module):
    """Policy (actor) and state-value (critic) networks for a discrete action space."""

    def __init__(self, obs_shape, action_space):
        super(ActorCritic, self).__init__()
        # Actor: observation -> action logits
        self.actor_fc1 = nn.Linear(obs_shape[0], 64)
        self.actor_fc2 = nn.Linear(64, action_space.n)
        # Critic: observation -> scalar state value
        self.critic_fc1 = nn.Linear(obs_shape[0], 64)
        self.critic_fc2 = nn.Linear(64, 1)

    def act(self, obs):
        # Sample an action from the current policy; the log-probability is kept
        # as the "old policy" term of the PPO probability ratio.
        logits = self.actor_fc2(torch.tanh(self.actor_fc1(obs)))
        dist = Categorical(logits=logits)
        action = dist.sample()
        return action.item(), dist.log_prob(action).item()

    def evaluate(self, obs, actions):
        # Re-evaluate a batch of stored actions under the *current* policy.
        logits = self.actor_fc2(torch.tanh(self.actor_fc1(obs)))
        dist = Categorical(logits=logits)
        log_probs = dist.log_prob(actions)
        values = self.critic_fc2(torch.tanh(self.critic_fc1(obs))).squeeze(-1)
        return log_probs, values

    def get_value(self, obs):
        # Critic-only forward pass, used to bootstrap the value of the final state.
        return self.critic_fc2(torch.tanh(self.critic_fc1(obs))).squeeze(-1)


class PPO:
    def __init__(self, env_name, batch_size=64, mini_batch_size=16, gamma=0.99,
                 lam=0.95, clip_param=0.2, ppo_epoch=10, lr=3e-4, eps=1e-5):
        self.env = gym.make(env_name)
        self.batch_size = batch_size          # rollout length between updates
        self.mini_batch_size = mini_batch_size
        self.gamma = gamma                    # discount factor
        self.lam = lam                        # GAE lambda
        self.clip_param = clip_param          # PPO clipping range
        self.ppo_epoch = ppo_epoch            # optimisation epochs per rollout
        self.net = ActorCritic(self.env.observation_space.shape, self.env.action_space)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr, eps=eps)
        self.clear_memory()

    def clear_memory(self):
        # Rollout buffers, filled in run() and consumed by learn().
        self.obs_buf, self.act_buf, self.logp_buf = [], [], []
        self.rew_buf, self.mask_buf = [], []

    def get_batch(self, last_obs):
        obs = torch.tensor(np.array(self.obs_buf), dtype=torch.float32)
        actions = torch.tensor(self.act_buf, dtype=torch.int64)
        old_log_probs = torch.tensor(self.logp_buf, dtype=torch.float32)
        rewards = torch.tensor(self.rew_buf, dtype=torch.float32)
        masks = torch.tensor(self.mask_buf, dtype=torch.float32)   # 0 where an episode ended
        with torch.no_grad():
            values = self.net.get_value(obs)
            last_value = self.net.get_value(torch.tensor(last_obs, dtype=torch.float32))
        return obs, actions, old_log_probs, rewards, masks, values, last_value

    def learn(self, obs, actions, old_log_probs, rewards, masks, values, last_value):
        # Generalized Advantage Estimation (GAE), computed backwards over the rollout.
        advantages = torch.zeros_like(rewards)
        gae = 0.0
        next_value = last_value
        for i in reversed(range(len(rewards))):
            delta = rewards[i] + self.gamma * masks[i] * next_value - values[i]
            gae = delta + self.gamma * self.lam * masks[i] * gae
            advantages[i] = gae
            next_value = values[i]
        returns = advantages + values
        # Normalise advantages (a common PPO stabilisation trick).
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(self.ppo_epoch):
            sampler = BatchSampler(SubsetRandomSampler(range(len(rewards))),
                                   self.mini_batch_size, drop_last=False)
            for ind in sampler:
                log_probs, value = self.net.evaluate(obs[ind], actions[ind])
                # Clipped surrogate objective.
                ratio = torch.exp(log_probs - old_log_probs[ind])
                surr1 = ratio * advantages[ind]
                surr2 = torch.clamp(ratio, 1 - self.clip_param,
                                    1 + self.clip_param) * advantages[ind]
                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = (returns[ind] - value).pow(2).mean()
                loss = actor_loss + 0.5 * critic_loss
                # optimize
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
        self.clear_memory()

    def run(self, max_iter=10000):
        # Assumes the classic Gym API (gym < 0.26): reset() returns obs,
        # step() returns (obs, reward, done, info).
        obs = self.env.reset()
        episode_reward = 0
        for i in range(1, max_iter + 1):
            with torch.no_grad():
                action, log_prob = self.net.act(torch.tensor(obs, dtype=torch.float32))
            next_obs, reward, done, _ = self.env.step(action)
            self.obs_buf.append(obs)
            self.act_buf.append(action)
            self.logp_buf.append(log_prob)
            self.rew_buf.append(reward)
            self.mask_buf.append(0.0 if done else 1.0)
            obs = next_obs
            episode_reward += reward
            if done:
                print('Step {}, episode reward: {:.2f}'.format(i, episode_reward))
                episode_reward = 0
                obs = self.env.reset()
            if i % self.batch_size == 0:
                self.learn(*self.get_batch(obs))
```
In the code above, we first define an ActorCritic class containing an actor and a critic. The actor maps the current state to a probability distribution over actions and samples from it; the critic outputs the value of the current state. PPO updates the actor and the critic together. The core of PPO is computing the advantage (you can refer to the explanation in the first answer); here it is estimated with GAE over the collected rollout. A BatchSampler with a SubsetRandomSampler is used to draw random mini-batches from that rollout. In the run function we pick an action for the current state, execute it, and store the transition in the rollout buffers; when an episode ends, the environment is reset. Every batch_size steps, learn performs the PPO update and then clears the buffers.
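For reference, a minimal way to train on CartPole-v1 with the classes above might look like the sketch below; the rollout length and step count are only illustrative values, not tuned hyperparameters:

```python
if __name__ == '__main__':
    # Collect 128 steps per update and train for 50,000 environment steps in total.
    agent = PPO('CartPole-v1', batch_size=128)
    agent.run(max_iter=50000)
```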