masac算法代码
时间: 2023-07-03 19:27:32 浏览: 236
以下是 MASAC 算法的代码示例,其中包括了 actor 和 critic 的网络结构以及训练过程。
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from collections import deque
import random
class Actor(nn.Module):
    """Deterministic policy network mapping states to bounded actions.

    Two ReLU hidden layers feed a tanh output that is scaled by
    ``max_action``, so every action component lies in
    ``[-max_action, max_action]``.

    Args:
        state_dim: Number of input features per state.
        action_dim: Number of action components produced.
        max_action: Scale applied to the tanh output.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, max_action, hidden_size=256):
        super().__init__()
        # The layer attribute names double as state_dict keys
        # (load_state_dict is used to sync target copies), so keep them.
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_dim)
        self.max_action = max_action

    def forward(self, state):
        """Map state(s) whose last dimension is ``state_dim`` to actions."""
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        squashed = torch.tanh(self.fc3(hidden))
        return self.max_action * squashed
class Critic(nn.Module):
    """Q-network estimating a scalar value for a (state, action) pair.

    The state and action are concatenated along dim 1 and passed through
    two ReLU hidden layers to a single linear output unit.

    Args:
        state_dim: Number of state features.
        action_dim: Number of action components.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, state, action):
        """Return Q-values of shape ``(batch, 1)`` for 2-D ``state``/``action``."""
        joint = torch.cat((state, action), dim=1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(joint))))
        return self.fc3(hidden)
class MASAC:
    """Twin-critic actor-critic agent with a deterministic actor.

    Despite the name, this is not Soft Actor-Critic: the actor network used
    here is deterministic and neither loss contains an entropy term. The
    update resembles TD3 -- the TD target uses the minimum of two target
    critics, and Gaussian noise is added to the target action -- but the
    actor is updated on every ``train`` call rather than with a delay.

    Args:
        state_dim: Size of a flat observation vector.
        action_dim: Size of a flat action vector.
        max_action: Action bound; target actions are clamped to
            ``[-max_action, max_action]``.
        discount: Reward discount factor.
        tau: Polyak coefficient for the soft target-network updates.
        alpha: Standard deviation of the Gaussian noise added to target
            actions (target policy smoothing). It is NOT an entropy
            temperature.
        actor_lr: Adam learning rate for the actor.
        critic_lr: Adam learning rate for each critic.
        batch_size: Number of transitions sampled per ``train`` call.
        memory_size: Replay-buffer capacity; the oldest transitions are
            dropped once it is full.
    """

    def __init__(self, state_dim, action_dim, max_action, discount=0.99,
                 tau=0.005, alpha=0.2, actor_lr=1e-3, critic_lr=1e-3,
                 batch_size=256, memory_size=1000000):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.actor = Actor(state_dim, action_dim, max_action).to(self.device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)

        # Two independently initialised critics; the minimum of their target
        # copies is used in the TD target to curb Q-value over-estimation.
        self.critic1 = Critic(state_dim, action_dim).to(self.device)
        self.critic1_target = Critic(state_dim, action_dim).to(self.device)
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=critic_lr)

        self.critic2 = Critic(state_dim, action_dim).to(self.device)
        self.critic2_target = Critic(state_dim, action_dim).to(self.device)
        self.critic2_target.load_state_dict(self.critic2.state_dict())
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=critic_lr)

        self.discount = discount
        self.tau = tau
        self.alpha = alpha
        self.batch_size = batch_size
        # FIFO replay buffer: deque(maxlen=...) discards the oldest entry.
        self.memory = deque(maxlen=memory_size)

    def select_action(self, state):
        """Return the actor's action for a single state, without exploration noise.

        Args:
            state: Array-like observation; it is flattened into one row.

        Returns:
            A 1-D float32 NumPy array with one entry per action component.
        """
        state_t = torch.as_tensor(
            np.asarray(state, dtype=np.float32).reshape(1, -1), device=self.device
        )
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            action = self.actor(state_t)
        return action.cpu().numpy().flatten()

    def store_transition(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` tuple to the buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def train(self):
        """Run one gradient step on both critics and the actor, then soft-update targets.

        Does nothing (returns ``None``) until the replay buffer holds at
        least ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return
        batch = random.sample(self.memory, self.batch_size)
        state, action, reward, next_state, done = self._batch_to_tensors(batch)

        self._update_critics(state, action, reward, next_state, done)
        self._update_actor(state)

        self._soft_update(self.actor, self.actor_target, self.tau)
        self._soft_update(self.critic1, self.critic1_target, self.tau)
        self._soft_update(self.critic2, self.critic2_target, self.tau)

    def _batch_to_tensors(self, batch):
        """Stack a list of transitions into float32 tensors on ``self.device``.

        Rewards and done flags come back as ``(batch, 1)`` columns so they
        combine element-wise with the critics' ``(batch, 1)`` outputs; as
        flat ``(batch,)`` vectors they would broadcast to ``(batch, batch)``.
        """
        states, actions, rewards, next_states, dones = zip(*batch)

        def to_tensor(values):
            # Explicit float32 also converts bool ``done`` flags to 0.0 / 1.0.
            return torch.as_tensor(np.asarray(values, dtype=np.float32),
                                   device=self.device)

        action_t = to_tensor(actions)
        if action_t.dim() == 1:
            # Scalar actions -> (batch, 1) so Critic can concatenate on dim 1.
            action_t = action_t.unsqueeze(1)
        return (
            to_tensor(states),
            action_t,
            to_tensor(rewards).reshape(-1, 1),
            to_tensor(next_states),
            to_tensor(dones).reshape(-1, 1),
        )

    def _update_critics(self, state, action, reward, next_state, done):
        """Regress both critics toward the clipped double-Q TD target."""
        max_action = self.actor.max_action
        with torch.no_grad():
            next_action = self.actor_target(next_state)
            # Target policy smoothing: Gaussian noise with std ``alpha``.
            noise = torch.randn_like(next_action) * self.alpha
            next_action = (next_action + noise).clamp(-max_action, max_action)
            next_q = torch.min(
                self.critic1_target(next_state, next_action),
                self.critic2_target(next_state, next_action),
            )
            # (1 - done) drops the bootstrap term on terminal transitions.
            target = reward + self.discount * (1 - done) * next_q

        for critic, optimizer in ((self.critic1, self.critic1_optimizer),
                                  (self.critic2, self.critic2_optimizer)):
            loss = F.mse_loss(critic(state, action), target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def _update_actor(self, state):
        """Take one step that increases critic1's Q-value of the actor's actions."""
        # This backward also fills critic1's .grad; those are cleared by
        # critic1_optimizer.zero_grad() before the next critic step.
        actor_loss = -self.critic1(state, self.actor(state)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

    @staticmethod
    def _soft_update(source, target, tau):
        """In place, set ``target = tau * source + (1 - tau) * target`` parameter-wise."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)
```
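这里的实现虽然名为 MASAC,但实际上更接近 TD3,而不是标准的 SAC:actor 是确定性策略,损失函数中没有熵项,参数 alpha 被用作目标动作高斯噪声的标准差,而不是熵温度系数;代码中也没有任何多智能体(multi-agent)相关的部分。与 DDPG 相比(DDPG 本身也有 actor target 和 critic target),它的主要区别有两点:一是使用两个 critic,并取两个 target critic 输出中的较小值来计算 TD 目标,以减小 Q 值的高估;二是在目标动作上加入噪声(目标策略平滑)。每次训练依次执行三步:critic 更新、actor 更新、target 网络软更新,分别对应代码中的三段注释。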