Deep Reinforcement Learning PPO Model Code
Below is a reference implementation of the PPO algorithm in PyTorch. The environment and hyperparameter values at the end of the snippet (CartPole-v1 with common defaults) are illustrative assumptions added so the example runs end to end.
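As background for the `update` method, the actor is trained with the standard PPO clipped surrogate objective (Schulman et al., 2017), restated here for context:

$$
L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\Big[\min\big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big],
\qquad r_t(\theta)=\frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
$$

In the code below, `ratio` corresponds to $r_t(\theta)$ and `eps_clip` to $\epsilon$; the critic is trained with a smooth-L1 loss against the computed returns and added to the actor loss with weight 0.5.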
```python
import gym                      # assumed: a Gym environment with the classic (pre-0.26) reset/step API
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class ActorCritic(nn.Module):
    """Actor-critic network: the actor head outputs a categorical action
    distribution, the critic head outputs a scalar state value."""

    def __init__(self, state_dim, action_dim, hidden_dim):
        super(ActorCritic, self).__init__()
        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        action_probs = self.actor(state)
        state_value = self.critic(state)
        return action_probs, state_value

class PPO:
    def __init__(self, state_dim, action_dim, hidden_dim, lr, gamma, K_epochs, eps_clip):
        self.gamma = gamma
        self.K_epochs = K_epochs          # optimization epochs per collected batch
        self.eps_clip = eps_clip          # clipping range of the surrogate objective
        self.policy = ActorCritic(state_dim, action_dim, hidden_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

    def get_action(self, state):
        """Sample an action from the current policy and return (action, log_prob)."""
        state = torch.FloatTensor(state).to(device)
        with torch.no_grad():
            action_probs, _ = self.policy(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob.item()

    def update(self, memory):
        # All batch tensors are kept 1-D ([N]); discrete actions must be
        # integer-typed for Categorical.log_prob.
        states = torch.FloatTensor(np.array(memory.states)).to(device)
        actions = torch.LongTensor(memory.actions).to(device)
        old_log_probs = torch.FloatTensor(memory.log_probs).to(device)
        returns = torch.FloatTensor(memory.returns).to(device)
        advantages = torch.FloatTensor(memory.advantages).to(device)

        for _ in range(self.K_epochs):
            action_probs, state_values = self.policy(states)
            dist = Categorical(action_probs)
            log_probs = dist.log_prob(actions)
            ratio = torch.exp(log_probs - old_log_probs)
            # Clipped surrogate objective
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = F.smooth_l1_loss(state_values.squeeze(-1), returns)
            loss = actor_loss + 0.5 * critic_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        memory.clear_memory()

class Memory:
    """Rollout buffer holding one batch of on-policy experience."""

    def __init__(self):
        self.clear_memory()

    def add(self, state, action, log_prob, reward, done):
        self.states.append(state)
        self.actions.append(action)
        self.log_probs.append(log_prob)
        self.rewards.append(reward)
        self.dones.append(done)

    def calculate_returns(self, next_state, gamma, policy):
        """Compute discounted returns (bootstrapped from the critic's value of
        next_state when the last step is non-terminal) and the advantages
        returns - V(s)."""
        next_state = torch.FloatTensor(next_state).to(device)
        with torch.no_grad():
            _, next_value = policy(next_state)
        discounted_reward = next_value.item()
        returns = []
        for reward, done in zip(reversed(self.rewards), reversed(self.dones)):
            if done:
                discounted_reward = 0
            discounted_reward = reward + gamma * discounted_reward
            returns.insert(0, discounted_reward)
        returns = np.array(returns, dtype=np.float32)
        self.returns = returns
        self.advantages = returns - self.values(policy)

    def values(self, policy):
        """Critic value estimates for all stored states, as a flat numpy array."""
        states = torch.FloatTensor(np.array(self.states)).to(device)
        with torch.no_grad():
            _, values = policy(states)
        return values.cpu().numpy().flatten()

    def clear_memory(self):
        self.states = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.dones = []
        self.returns = []
        self.advantages = []

# --- Example setup: the environment and hyperparameter values below are
# --- illustrative assumptions chosen so the snippet runs end to end.
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
hidden_dim = 64
lr = 3e-4
gamma = 0.99
K_epochs = 4
eps_clip = 0.2
num_episodes = 500

agent = PPO(state_dim, action_dim, hidden_dim, lr, gamma, K_epochs, eps_clip)
memory = Memory()

for i in range(num_episodes):
    state = env.reset()
    episode_reward = 0
    done = False
    while not done:
        action, log_prob = agent.get_action(state)
        next_state, reward, done, _ = env.step(action)
        memory.add(state, action, log_prob, reward, done)
        episode_reward += reward
        state = next_state
    # Compute returns and advantages for the finished episode, then run a PPO update.
    memory.calculate_returns(next_state, gamma, agent.policy)
    agent.update(memory)
    print(f"episode {i}: reward {episode_reward}")
```
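After training, a quick way to sanity-check the learned policy is to roll it out for a few episodes and look at the returns. The sketch below reuses the `agent` and `env` objects from the example above (so it inherits the CartPole-v1 and classic Gym API assumptions) and simply samples actions from the learned policy:

```python
# Minimal evaluation sketch: roll out the trained (stochastic) policy.
# Assumes the `env` and `agent` objects created in the example above.
eval_episodes = 10
for _ in range(eval_episodes):
    state = env.reset()
    done = False
    total_reward = 0.0
    while not done:
        action, _ = agent.get_action(state)          # sample from the current policy
        state, reward, done, _ = env.step(action)
        total_reward += reward
    print(f"evaluation return: {total_reward:.1f}")
```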