Write DDPG code with the ICM algorithm added, using PyTorch
Sure, here is DDPG code with the ICM algorithm added:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class ICM(nn.Module):
    def __init__(self, state_dim, action_dim, feature_dim=32, learning_rate=0.0001, beta=0.2, eta=0.01):
        super(ICM, self).__init__()
        self.beta = beta      # weight of the forward loss vs. the inverse loss
        self.eta = eta        # scale of the intrinsic (curiosity) reward
        # Feature encoder: maps a raw state to a compact feature vector phi(s)
        self.feature = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, feature_dim),
        )
        # Inverse model: predicts the action from phi(s) and phi(s')
        self.inv_fc1 = nn.Linear(feature_dim * 2, 256)
        self.inv_fc2 = nn.Linear(256, action_dim)
        # Forward model: predicts phi(s') from phi(s) and the action
        self.forward_fc1 = nn.Linear(feature_dim + action_dim, 256)
        self.forward_fc2 = nn.Linear(256, feature_dim)
        # One optimizer over all ICM parameters (encoder, inverse and forward models)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, state, next_state, action):
        state_feature = self.feature(state)
        next_state_feature = self.feature(next_state)
        # Inverse model: predict the action from the two state features
        pred_action = self.inv_fc2(F.relu(self.inv_fc1(
            torch.cat([state_feature, next_state_feature], 1))))
        # Forward model: predict the next state feature from phi(s) and the action
        pred_next_feature = self.forward_fc2(F.relu(self.forward_fc1(
            torch.cat([state_feature, action], 1))))
        return pred_action, pred_next_feature, next_state_feature

    def intrinsic_reward(self, state, next_state, action):
        # Curiosity bonus: the forward model's prediction error
        with torch.no_grad():
            _, pred_next_feature, next_state_feature = self(state, next_state, action)
            return self.eta * 0.5 * (pred_next_feature - next_state_feature).pow(2).sum(1, keepdim=True)

    def compute_loss(self, state, next_state, action):
        pred_action, pred_next_feature, next_state_feature = self(state, next_state, action)
        inv_loss = F.mse_loss(pred_action, action)
        forward_loss = F.mse_loss(pred_next_feature, next_state_feature.detach())
        # Weighting as in the ICM paper: (1 - beta) * L_I + beta * L_F
        return (1 - self.beta) * inv_loss + self.beta * forward_loss

    def train_model(self, state, next_state, action):
        loss = self.compute_loss(state, next_state, action)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim1=128, hidden_dim2=128):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action, hidden_dim1=128, hidden_dim2=128):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.max_action * torch.tanh(self.fc3(x))
class DDPG:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters())

        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters())

        self.icm = ICM(state_dim, action_dim)
        self.max_action = max_action

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1))
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, replay_buffer, batch_size=100, discount=0.99, tau=0.005, icm_update=1):
        state, action, next_state, reward, done = replay_buffer.sample(batch_size)
        state = torch.FloatTensor(state)
        action = torch.FloatTensor(action)
        next_state = torch.FloatTensor(next_state)
        reward = torch.FloatTensor(reward.reshape(-1, 1))
        done = torch.FloatTensor(done.reshape(-1, 1))

        # Add the ICM curiosity bonus to the extrinsic reward
        if icm_update:
            reward = reward + self.icm.intrinsic_reward(state, next_state, action)

        # Critic (Q-function) update
        target_Q = self.critic_target(next_state, self.actor_target(next_state))
        target_Q = reward + ((1 - done) * discount * target_Q).detach()
        current_Q = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q, target_Q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Actor update
        actor_loss = -self.critic(state, self.actor(state)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ICM update
        if icm_update:
            self.icm.train_model(state, next_state, action)

        # Soft update of the target networks
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
```
Note that the ICM here follows the paper "Curiosity-driven Exploration by Self-supervised Prediction": the forward model's prediction error is used as a curiosity bonus added to the environment reward, and `beta` weights the forward loss against the inverse loss. The `icm_update` argument of `update` controls whether the ICM is used and trained; setting it to 0 runs plain DDPG.
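For completeness, below is a minimal sketch of how the agent might be driven from a training loop. The `ReplayBuffer` class, the `Pendulum-v1` environment, the warm-up threshold, and the exploration-noise settings are illustrative assumptions, not part of the code above; the loop also assumes the classic `gym` API (reset returning only the observation, step returning a 4-tuple), so adjust it if you use `gymnasium`.
```python
import gym
import numpy as np


class ReplayBuffer:
    """Minimal FIFO replay buffer (illustrative; not part of the answer above)."""
    def __init__(self, capacity=100000):
        self.storage, self.capacity = [], capacity

    def add(self, state, action, next_state, reward, done):
        if len(self.storage) >= self.capacity:
            self.storage.pop(0)
        self.storage.append((state, action, next_state, reward, done))

    def sample(self, batch_size):
        idx = np.random.randint(0, len(self.storage), size=batch_size)
        batch = [self.storage[i] for i in idx]
        state, action, next_state, reward, done = map(np.array, zip(*batch))
        return state, action, next_state, reward.astype(np.float32), done.astype(np.float32)


env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

agent = DDPG(state_dim, action_dim, max_action)
buffer = ReplayBuffer()

for episode in range(100):
    state, done = env.reset(), False
    while not done:
        # Gaussian exploration noise on top of the deterministic policy
        action = agent.select_action(state)
        action = np.clip(action + np.random.normal(0, 0.1 * max_action, size=action_dim),
                         -max_action, max_action)
        next_state, reward, done, _ = env.step(action)
        buffer.add(state, action, next_state, reward, float(done))
        state = next_state
        # Start updating once the buffer holds enough transitions
        if len(buffer.storage) > 1000:
            agent.update(buffer, batch_size=100, icm_update=1)
```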