PER-MADDPG code
The following is a code example of the PER-MADDPG algorithm (MADDPG with prioritized experience replay) implemented in PyTorch:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
# Define the Actor network
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dims=[64, 64]):
        super(Actor, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(state_dim, hidden_dims[0]))
        for i in range(1, len(hidden_dims)):
            self.layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
        self.layers.append(nn.Linear(hidden_dims[-1], action_dim))

    def forward(self, state):
        x = state
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        x = torch.tanh(self.layers[-1](x))
        return x
# Define the Critic network
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dims=[64, 64]):
        super(Critic, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(state_dim + action_dim, hidden_dims[0]))
        for i in range(1, len(hidden_dims)):
            self.layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
        self.layers.append(nn.Linear(hidden_dims[-1], 1))

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        x = self.layers[-1](x)
        return x
# Define the prioritized Replay Buffer
class ReplayBuffer:
    def __init__(self, max_size):
        self.max_size = max_size
        self.buffer = []
        self.priorities = []
        self.idx = 0

    def add(self, state, action, reward, next_state, done):
        experience = (state, action, reward, next_state, done)
        # New experiences get the current maximum priority so they are replayed at least once
        priority = max(self.priorities) if self.priorities else 1.0
        if len(self.buffer) < self.max_size:
            self.buffer.append(experience)
            self.priorities.append(priority)
        else:
            self.buffer[self.idx] = experience
            self.priorities[self.idx] = priority
            self.idx = (self.idx + 1) % self.max_size

    def sample(self, batch_size, alpha, beta):
        # Sample indices with probability proportional to priority^alpha
        scaled = np.array(self.priorities) ** alpha
        probs = scaled / scaled.sum()
        idxs = np.random.choice(len(self.buffer), batch_size, replace=False, p=probs)
        # Importance-sampling weights, normalized by the largest weight in the batch
        weights = (len(self.buffer) * probs[idxs]) ** (-beta)
        weights /= weights.max()
        states, actions, rewards, next_states, dones = zip(*[self.buffer[i] for i in idxs])
        return (np.stack(states), np.stack(actions), np.stack(rewards),
                np.stack(next_states), np.stack(dones), idxs, weights)

    def update_priorities(self, idxs, priorities):
        for idx, priority in zip(idxs, priorities):
            self.priorities[idx] = priority
# Define the PER-MADDPG agent
class PerMADDPG:
    def __init__(self, state_dim, action_dim, num_agents, gamma=0.99, tau=0.01,
                 lr_actor=0.001, lr_critic=0.001, buffer_size=int(1e6),
                 batch_size=64, alpha=0.6, beta=0.4, eps=1e-5):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.num_agents = num_agents
        self.gamma = gamma
        self.tau = tau
        self.lr_actor = lr_actor
        self.lr_critic = lr_critic
        self.batch_size = batch_size
        self.alpha = alpha
        self.beta = beta
        self.eps = eps
        self.actors = [Actor(state_dim, action_dim) for _ in range(num_agents)]
        self.critics = [Critic(state_dim*num_agents, action_dim*num_agents) for _ in range(num_agents)]
        self.target_actors = [Actor(state_dim, action_dim) for _ in range(num_agents)]
        self.target_critics = [Critic(state_dim*num_agents, action_dim*num_agents) for _ in range(num_agents)]
        for i in range(num_agents):
            self.target_actors[i].load_state_dict(self.actors[i].state_dict())
            self.target_critics[i].load_state_dict(self.critics[i].state_dict())
        self.actor_optimizers = [optim.Adam(actor.parameters(), lr=lr_actor) for actor in self.actors]
        self.critic_optimizers = [optim.Adam(critic.parameters(), lr=lr_critic) for critic in self.critics]
        self.replay_buffer = ReplayBuffer(max_size=buffer_size)

    def act(self, states, noise=0.0):
        actions = []
        for i in range(self.num_agents):
            state = torch.tensor(states[i], dtype=torch.float32)
            action = self.actors[i](state.unsqueeze(0)).squeeze(0).detach().numpy()
            action += noise * np.random.randn(self.action_dim)
            actions.append(np.clip(action, -1.0, 1.0))
        return actions
    def update(self):
        if len(self.replay_buffer.buffer) < self.batch_size:
            return
        # Sample a prioritized batch of experiences together with the sampled
        # indices and importance-sampling weights
        states, actions, rewards, next_states, dones, idxs, weights = \
            self.replay_buffer.sample(self.batch_size, self.alpha, self.beta)
        # Convert to PyTorch tensors
        states = torch.tensor(states, dtype=torch.float32)            # (batch, num_agents, state_dim)
        actions = torch.tensor(actions, dtype=torch.float32)          # (batch, num_agents, action_dim)
        rewards = torch.tensor(rewards, dtype=torch.float32)          # (batch, num_agents)
        next_states = torch.tensor(next_states, dtype=torch.float32)  # (batch, num_agents, state_dim)
        dones = torch.tensor(dones, dtype=torch.float32)              # (batch, num_agents)
        weights = torch.tensor(weights, dtype=torch.float32).unsqueeze(1)
        # Joint observations/actions for the centralized critics
        states_full = states.view(-1, self.state_dim * self.num_agents)
        actions_full = actions.view(-1, self.action_dim * self.num_agents)
        next_states_full = next_states.view(-1, self.state_dim * self.num_agents)
        # Target actions of all agents from the target actors
        with torch.no_grad():
            target_actions = torch.stack(
                [self.target_actors[j](next_states[:, j, :]) for j in range(self.num_agents)], dim=1)
            target_actions_full = target_actions.view(-1, self.action_dim * self.num_agents)
        new_priorities = np.zeros(self.batch_size)
        for i in range(self.num_agents):
            # Compute the TD target and TD error for agent i
            with torch.no_grad():
                target_q_values = self.target_critics[i](next_states_full, target_actions_full)
                target_q_values = rewards[:, i:i+1] + self.gamma * (1 - dones[:, i:i+1]) * target_q_values
            predicted_q_values = self.critics[i](states_full, actions_full)
            td_errors = target_q_values - predicted_q_values
            # Update the critic network with an importance-sampling-weighted MSE loss
            critic_loss = torch.mean(weights * td_errors.pow(2))
            self.critic_optimizers[i].zero_grad()
            critic_loss.backward()
            self.critic_optimizers[i].step()
            # Update the actor network: maximize the critic value when agent i
            # replaces its buffered action with the current policy's action
            current_actions = [actions[:, j, :] if j != i else self.actors[i](states[:, i, :])
                               for j in range(self.num_agents)]
            current_actions_full = torch.cat(current_actions, dim=1)
            actor_loss = -torch.mean(weights * self.critics[i](states_full, current_actions_full))
            self.actor_optimizers[i].zero_grad()
            actor_loss.backward()
            self.actor_optimizers[i].step()
            # Accumulate absolute TD errors for the priority update
            new_priorities += np.abs(td_errors.detach().numpy()).squeeze(1)
            # Soft-update the target networks
            for target_param, param in zip(self.target_actors[i].parameters(), self.actors[i].parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for target_param, param in zip(self.target_critics[i].parameters(), self.critics[i].parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        # Update the priorities in the replay buffer with the mean absolute TD error per transition
        self.replay_buffer.update_priorities(idxs, new_priorities / self.num_agents + self.eps)
    def save(self, filename):
        torch.save({
            'actor_params': [actor.state_dict() for actor in self.actors],
            'critic_params': [critic.state_dict() for critic in self.critics]
        }, filename)

    def load(self, filename):
        checkpoint = torch.load(filename)
        for i in range(self.num_agents):
            self.actors[i].load_state_dict(checkpoint['actor_params'][i])
            self.critics[i].load_state_dict(checkpoint['critic_params'][i])
            self.target_actors[i].load_state_dict(checkpoint['actor_params'][i])
            self.target_critics[i].load_state_dict(checkpoint['critic_params'][i])
```
In the code above, the `Actor` class defines the actor network, the `Critic` class defines the centralized critic network, the `ReplayBuffer` class implements the prioritized experience replay buffer, and the `PerMADDPG` class implements the PER-MADDPG algorithm.
In the `__init__` method of `PerMADDPG`, we set the hyperparameters, create the actor and critic networks together with their target networks and optimizers, and initialize the replay buffer.
In the `act` method, each agent's actor network produces an action from that agent's observation, with Gaussian exploration noise added before clipping to the action range.
In the `update` method, we first sample a prioritized batch of experiences along with their importance-sampling weights, compute the TD errors, and update the critic and actor networks using the weighted losses. The absolute TD errors are then written back to the buffer as new priorities, and the target networks are soft-updated.
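To make the prioritized sampling concrete, here is a toy calculation (with made-up TD errors, independent of the agent code) showing how the standard PER formulas turn priorities into sampling probabilities and importance-sampling weights, mirroring what `ReplayBuffer.sample` does:
```python
import numpy as np

# Hypothetical absolute TD errors of four stored transitions
priorities = np.array([0.1, 0.5, 1.0, 2.0]) + 1e-5       # |delta| + eps
alpha, beta = 0.6, 0.4

scaled = priorities ** alpha
probs = scaled / scaled.sum()                             # sampling probabilities P(i)
weights = (len(priorities) * probs) ** (-beta)            # importance-sampling weights
weights /= weights.max()                                  # normalize so the largest weight is 1

print(probs)    # the transition with the largest TD error is sampled most often
print(weights)  # ...and its update is down-weighted to correct the sampling bias
```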
Finally, the `save` and `load` methods store and restore the model parameters.
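Below is a minimal usage sketch. The environment interface (`env.reset()`, `env.step()`), the dimensions, and the episode settings are hypothetical placeholders; the point is the call order: select actions with `act`, store transitions with `replay_buffer.add`, and train with `update`.
```python
# Hypothetical multi-agent environment: env.reset() returns a list of per-agent
# observations, and env.step(actions) returns (next_states, rewards, dones, info),
# where rewards and dones each contain one entry per agent.
num_agents, state_dim, action_dim = 3, 24, 2
agent = PerMADDPG(state_dim, action_dim, num_agents)

for episode in range(1000):
    states = env.reset()
    for t in range(200):
        actions = agent.act(states, noise=0.1)                 # exploration noise
        next_states, rewards, dones, _ = env.step(actions)
        agent.replay_buffer.add(states, actions, rewards, next_states, dones)
        agent.update()                                         # one gradient step per environment step
        states = next_states
        if all(dones):
            break

agent.save('per_maddpg.pt')
```
In practice, `beta` is typically annealed from its initial value toward 1 over the course of training so that the importance-sampling correction reaches full strength as the policy converges.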