A PyTorch Code Example for MADDPG

Below is a MADDPG code example implemented with PyTorch:

```python
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128  # minibatch size used when sampling from the replay buffer


# Actor and Critic networks
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))  # actions squashed into [-1, 1]
        return x


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# MADDPG algorithm
class MADDPG:
    def __init__(self, n_agents, state_dim, action_dim, hidden_dim,
                 discount_factor=0.99, tau=0.01, critic_lr=0.001, actor_lr=0.001):
        self.n_agents = n_agents
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.discount_factor = discount_factor
        self.tau = tau
        # Centralized critics: each one takes the joint state and joint action of all agents
        self.critic_local = [Critic(state_dim * n_agents, action_dim * n_agents, hidden_dim).to(device)
                             for _ in range(n_agents)]
        self.critic_target = [Critic(state_dim * n_agents, action_dim * n_agents, hidden_dim).to(device)
                              for _ in range(n_agents)]
        # Decentralized actors: each one sees only its own observation
        self.actor_local = [Actor(state_dim, action_dim, hidden_dim).to(device) for _ in range(n_agents)]
        self.actor_target = [Actor(state_dim, action_dim, hidden_dim).to(device) for _ in range(n_agents)]
        self.critic_optim = [optim.Adam(self.critic_local[i].parameters(), lr=critic_lr) for i in range(n_agents)]
        self.actor_optim = [optim.Adam(self.actor_local[i].parameters(), lr=actor_lr) for i in range(n_agents)]
        self.memory = ReplayBuffer()

    def act(self, state):
        actions = []
        for i in range(self.n_agents):
            state_tensor = torch.tensor(state[i], dtype=torch.float32).unsqueeze(0).to(device)
            action = self.actor_local[i](state_tensor).detach().cpu().numpy()[0]
            actions.append(action)
        return np.array(actions)

    def reset(self):
        pass  # hook for resetting per-episode state such as exploration noise

    def step(self, state, action, reward, next_state, done):
        self.memory.add(state, action, reward, next_state, done)
        if len(self.memory) > BATCH_SIZE:
            experiences = self.memory.sample(BATCH_SIZE)
            self.learn(experiences)

    def learn(self, experiences):
        # states / next_states: (batch, n_agents, state_dim)
        # actions: (batch, n_agents, action_dim); rewards / dones: (batch, n_agents)
        states, actions, rewards, next_states, dones = experiences
        batch_size = states.shape[0]
        full_states = states.reshape(batch_size, -1)
        full_actions = actions.reshape(batch_size, -1)
        full_next_states = next_states.reshape(batch_size, -1)

        # Joint next action from all target actors, shared by every centralized critic
        with torch.no_grad():
            actions_next = torch.cat([self.actor_target[j](next_states[:, j, :])
                                      for j in range(self.n_agents)], dim=1)

        for i in range(self.n_agents):
            rewards_i = rewards[:, i].reshape(-1, 1)
            dones_i = dones[:, i].reshape(-1, 1)

            # TD target: r_i + gamma * Q_i'(x', a_1', ..., a_N') * (1 - done_i)
            with torch.no_grad():
                q_next = self.critic_target[i](full_next_states, actions_next)
                q_target_i = rewards_i + self.discount_factor * q_next * (1 - dones_i)

            # Critic loss and update
            q_local_i = self.critic_local[i](full_states, full_actions)
            critic_loss_i = nn.MSELoss()(q_local_i, q_target_i)
            self.critic_optim[i].zero_grad()
            critic_loss_i.backward()
            self.critic_optim[i].step()

            # Actor loss: only agent i's own action keeps its gradient
            actions_pred = [self.actor_local[j](states[:, j, :]) if j == i
                            else self.actor_local[j](states[:, j, :]).detach()
                            for j in range(self.n_agents)]
            actions_pred = torch.cat(actions_pred, dim=1)
            actor_loss_i = -self.critic_local[i](full_states, actions_pred).mean()
            self.actor_optim[i].zero_grad()
            actor_loss_i.backward()
            self.actor_optim[i].step()

            # Soft-update the Critic and Actor target networks
            self.soft_update(self.critic_local[i], self.critic_target[i], self.tau)
            self.soft_update(self.actor_local[i], self.actor_target[i], self.tau)

    def soft_update(self, local_model, target_model, tau):
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)


# Experience replay buffer storing joint transitions
class ReplayBuffer:
    def __init__(self, buffer_size=int(1e6)):
        self.buffer_size = buffer_size
        self.buffer = []
        self.position = 0

    def add(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.buffer_size:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.buffer_size

    def sample(self, batch_size=BATCH_SIZE):
        batch = zip(*random.sample(self.buffer, batch_size))
        return [torch.tensor(np.array(item), dtype=torch.float32).to(device) for item in batch]

    def __len__(self):
        return len(self.buffer)


# Training loop. The variables env, brain_name, num_agents, state_size and action_size
# are assumed to be provided by a Unity ML-Agents style environment set up elsewhere.
def train(agent, env, n_episodes=5000, max_t=1000):
    scores = []
    for i_episode in range(1, n_episodes + 1):
        env_info = env.reset(train_mode=True)[brain_name]
        states = env_info.vector_observations
        agent.reset()
        score = np.zeros(num_agents)
        for t in range(max_t):
            actions = agent.act(states)
            env_info = env.step(actions)[brain_name]
            next_states = env_info.vector_observations
            rewards = env_info.rewards
            dones = env_info.local_done
            agent.step(states, actions, rewards, next_states, dones)
            states = next_states
            score += rewards
            if np.any(dones):
                break
        scores.append(np.max(score))
        print('\rEpisode {}\tScore: {:.2f}'.format(i_episode, np.max(score)), end="")
        if i_episode % 100 == 0:
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores[-100:])))
        if np.mean(scores[-100:]) >= 0.5:
            print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(
                i_episode - 100, np.mean(scores[-100:])))
            for i in range(agent.n_agents):
                torch.save(agent.actor_local[i].state_dict(), 'checkpoint_actor_{}.pth'.format(i))
                torch.save(agent.critic_local[i].state_dict(), 'checkpoint_critic_{}.pth'.format(i))
            break


# Create a MADDPG instance and train it
agent = MADDPG(n_agents=num_agents, state_dim=state_size, action_dim=action_size, hidden_dim=128)
train(agent, env)
```

In the example above, both the Actor and Critic networks are plain fully connected networks; each Actor outputs continuous action values squashed into [-1, 1] by a tanh activation, while each agent's Critic is centralized and scores the joint state and joint action of all agents. The MADDPG class implements the algorithm: it holds every agent's local and target Actor and Critic networks and defines the action-selection, learning, and soft-update procedures. The ReplayBuffer class implements the experience replay cache, and the train function ties these pieces together to train the agents.
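The actors in the snippet are purely deterministic, so during training MADDPG normally perturbs their output with exploration noise. Below is a minimal sketch of one common choice, zero-mean Gaussian noise clipped back into the tanh range; the `act_with_noise` helper and its `noise_scale` parameter are illustrative assumptions, not part of the original code.

```python
import numpy as np

def act_with_noise(agent, states, noise_scale=0.1):
    """Gaussian exploration on top of the deterministic MADDPG actions.

    `noise_scale` is a hypothetical hyperparameter introduced here for
    illustration; actions are clipped to [-1, 1], the range produced by
    the tanh output layer of the Actor networks above.
    """
    actions = agent.act(states)                           # deterministic actions from the local actors
    noise = noise_scale * np.random.randn(*actions.shape)  # zero-mean Gaussian perturbation
    return np.clip(actions + noise, -1.0, 1.0)
```

In the training loop, `agent.act(states)` could be swapped for `act_with_noise(agent, states)`, typically with `noise_scale` decayed over episodes and dropped entirely at evaluation time.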
