Writing DDPG with the ICM algorithm added, in PyTorch

Sure, here is DDPG code with the ICM algorithm added:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class ICM(nn.Module):
    """Intrinsic Curiosity Module: an inverse model plus a forward model over state features."""

    def __init__(self, state_dim, action_dim, feature_dim=32, learning_rate=0.0001, beta=0.2):
        super(ICM, self).__init__()
        self.beta = beta
        # Inverse model: predicts the action from the features of state and next_state.
        self.inv_fc1 = nn.Linear(feature_dim * 2, 256)
        self.inv_fc2 = nn.Linear(256, action_dim)
        # Forward network: maps (state, action) to a next-state feature.
        # A state on its own is encoded by feeding a zero action.
        self.forward_fc1 = nn.Linear(state_dim + action_dim, 256)
        self.forward_fc2 = nn.Linear(256, feature_dim)
        self.inverse_optimizer = torch.optim.Adam(
            list(self.inv_fc1.parameters()) + list(self.inv_fc2.parameters()), lr=learning_rate)
        self.forward_optimizer = torch.optim.Adam(
            list(self.forward_fc1.parameters()) + list(self.forward_fc2.parameters()), lr=learning_rate)

    def _encode(self, state, action):
        # Feature of a single state: pass it through the forward network with a zero action.
        return self.forward_fc2(F.relu(self.forward_fc1(torch.cat([state, torch.zeros_like(action)], 1))))

    def forward(self, state, next_state, action):
        # Inverse model: predict the executed action from phi(s_t) and phi(s_{t+1}).
        state_feature = self._encode(state, action)
        next_state_feature = self._encode(next_state, action)
        pred_action = self.inv_fc2(F.relu(self.inv_fc1(torch.cat([state_feature, next_state_feature], 1))))
        return pred_action

    def compute_loss(self, state, next_state, action, pred_action):
        # Forward model: predict phi(s_{t+1}) from (s_t, a_t) and compare with the actual feature.
        pred_next_feature = self.forward_fc2(F.relu(self.forward_fc1(torch.cat([state, action], 1))))
        next_state_feature = self._encode(next_state, action)
        inv_loss = F.mse_loss(pred_action, action)
        forward_loss = F.mse_loss(pred_next_feature, next_state_feature.detach())
        return self.beta * inv_loss + (1 - self.beta) * forward_loss

    def intrinsic_reward(self, state, next_state, action, eta=0.01):
        # Curiosity bonus: eta/2 * ||phi_hat(s_{t+1}) - phi(s_{t+1})||^2, computed per sample.
        with torch.no_grad():
            pred_next_feature = self.forward_fc2(F.relu(self.forward_fc1(torch.cat([state, action], 1))))
            next_state_feature = self._encode(next_state, action)
            return eta * 0.5 * (pred_next_feature - next_state_feature).pow(2).sum(dim=1, keepdim=True)

    def train_model(self, state, next_state, action):
        self.inverse_optimizer.zero_grad()
        self.forward_optimizer.zero_grad()
        pred_action = self(state, next_state, action)
        loss = self.compute_loss(state, next_state, action, pred_action)
        loss.backward()
        self.inverse_optimizer.step()
        self.forward_optimizer.step()


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim1=128, hidden_dim2=128):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action, hidden_dim1=128, hidden_dim2=128):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.max_action * torch.tanh(self.fc3(x))


class DDPG:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters())
        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters())
        self.icm = ICM(state_dim, action_dim)
        self.max_action = max_action

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1))
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, replay_buffer, batch_size=100, discount=0.99, tau=0.005, icm_update=1):
        state, action, next_state, reward, done = replay_buffer.sample(batch_size)
        state = torch.FloatTensor(state)
        action = torch.FloatTensor(action)
        next_state = torch.FloatTensor(next_state)
        reward = torch.FloatTensor(reward.reshape(-1, 1))
        done = torch.FloatTensor(done.reshape(-1, 1))

        # Curiosity bonus: add the ICM forward-model prediction error to the extrinsic reward.
        if icm_update:
            reward = reward + self.icm.intrinsic_reward(state, next_state, action)

        # Q-function update
        target_Q = self.critic_target(next_state, self.actor_target(next_state))
        target_Q = reward + ((1 - done) * discount * target_Q).detach()
        current_Q = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q, target_Q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Actor update
        actor_loss = -self.critic(state, self.actor(state)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ICM update
        if icm_update:
            self.icm.train_model(state, next_state, action)

        # Soft update of the target networks
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
```

Note that the ICM here follows the paper "Curiosity-driven Exploration by Self-supervised Prediction". The `icm_update` argument controls whether the ICM model (and its curiosity bonus) is used; setting it to 0 runs plain DDPG only.
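The code above assumes a `replay_buffer` whose `sample(batch_size)` returns NumPy arrays `(state, action, next_state, reward, done)`, but no buffer or training loop is defined. Below is a minimal sketch of how the agent could be driven, assuming the classic Gym API (`reset()` returns an observation, `step()` returns four values) and the `Pendulum-v1` environment; the `ReplayBuffer` class, exploration-noise scale, warm-up threshold, and episode count are illustrative choices, not part of the original answer:

```python
import gym
import numpy as np

class ReplayBuffer:
    """Simple ring buffer whose sample() matches the interface expected by DDPG.update()."""
    def __init__(self, capacity=100000):
        self.buffer, self.capacity, self.position = [], capacity, 0

    def add(self, state, action, next_state, reward, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, next_state, reward, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        idx = np.random.randint(0, len(self.buffer), size=batch_size)
        state, action, next_state, reward, done = zip(*[self.buffer[i] for i in idx])
        return (np.array(state), np.array(action), np.array(next_state),
                np.array(reward, dtype=np.float32), np.array(done, dtype=np.float32))

env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

agent = DDPG(state_dim, action_dim, max_action)
buffer = ReplayBuffer()

for episode in range(200):
    state, done = env.reset(), False
    while not done:
        # Gaussian exploration noise on top of the deterministic policy.
        action = agent.select_action(state)
        action = (action + np.random.normal(0, 0.1 * max_action, size=action_dim)).clip(-max_action, max_action)
        next_state, reward, done, _ = env.step(action)
        buffer.add(state, action, next_state, reward, float(done))
        state = next_state
        # Start updating once the buffer holds enough transitions.
        if len(buffer.buffer) >= 1000:
            agent.update(buffer, batch_size=100, icm_update=1)
```

With `icm_update=1` each update both trains the ICM and adds its curiosity bonus to the sampled rewards; pass `icm_update=0` to compare against plain DDPG on the same loop.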
