Multi-agent formation actor-critic PyTorch code
Here is a simple PyTorch actor-critic implementation for multi-agent formation control:
```
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class Actor(nn.Module):
    """Policy network: maps an observation to a categorical action distribution."""

    def __init__(self, obs_size, act_size):
        super(Actor, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, act_size),
            nn.Softmax(dim=-1)
        )

    def forward(self, obs):
        return self.net(obs)


class Critic(nn.Module):
    """Value network: maps an observation to a scalar state value."""

    def __init__(self, obs_size):
        super(Critic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, obs):
        return self.net(obs)


class Agent:
    def __init__(self, obs_size, act_size, lr=0.001, gamma=0.99):
        self.actor = Actor(obs_size, act_size)
        self.critic = Critic(obs_size)
        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=lr)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=lr)
        self.gamma = gamma

    def act(self, state):
        # Sample an action from the current policy.
        state = torch.FloatTensor(state).unsqueeze(0)
        probs = self.actor(state)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        return action.item()

    def update(self, states, actions, rewards, next_states, dones):
        # Expects batched inputs (lists or arrays of transitions).
        states = torch.FloatTensor(np.asarray(states))
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        next_states = torch.FloatTensor(np.asarray(next_states))
        dones = torch.FloatTensor(dones).unsqueeze(1)

        # Update critic: regress V(s) towards the one-step TD target.
        values = self.critic(states)
        next_values = self.critic(next_states)
        target = rewards + (1 - dones) * self.gamma * next_values.detach()
        td_error = target - values
        critic_loss = td_error.pow(2).mean()
        self.optimizer_critic.zero_grad()
        critic_loss.backward()
        self.optimizer_critic.step()

        # Update actor: policy gradient with the TD error as the advantage estimate.
        probs = self.actor(states)
        dist = torch.distributions.Categorical(probs)
        log_probs = dist.log_prob(actions).unsqueeze(1)
        actor_loss = -(log_probs * td_error.detach()).mean()
        self.optimizer_actor.zero_grad()
        actor_loss.backward()
        self.optimizer_actor.step()


def run_episode(env, agents):
    # One episode: each agent acts on its own observation and is updated
    # independently on its own transition.
    states = env.reset()
    total_rewards = np.zeros(len(agents))
    while True:
        actions = [agent.act(state) for agent, state in zip(agents, states)]
        next_states, rewards, dones, _ = env.step(actions)
        for i, agent in enumerate(agents):
            # Wrap each single transition as a batch of size 1.
            agent.update([states[i]], [actions[i]], [rewards[i]],
                         [next_states[i]], [dones[i]])
            total_rewards[i] += rewards[i]
        states = next_states
        if all(dones):
            break
    return total_rewards


if __name__ == '__main__':
    # 'MultiDiscrete-v0' is a placeholder name; replace it with an actual
    # multi-agent formation environment that returns one observation,
    # reward, and done flag per agent.
    env = gym.make('MultiDiscrete-v0')
    obs_size = env.observation_space.shape[0]
    act_sizes = env.action_space.nvec.tolist()
    agents = [Agent(obs_size, act, lr=0.001, gamma=0.99) for act in act_sizes]
    for i in range(1000):
        rewards = run_episode(env, agents)
        print(f'episode {i}: {rewards}')
```
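The script above calls `gym.make('MultiDiscrete-v0')`, which is not an environment that ships with Gym, so the loop will not run as-is. Below is a minimal sketch of a stand-in environment, assuming each of `n_agents` agents receives its own observation vector, picks one discrete action, and gets an individual reward; the class name `ToyFormationEnv` and all its parameters are hypothetical, and the snippet reuses `Agent` and `run_episode` from the code above only to exercise the training loop end to end.

```
class ToyFormationEnv:
    """Hypothetical stand-in environment (random dynamics) for smoke-testing
    the training loop. Replace with a real formation-control environment."""

    def __init__(self, n_agents=3, obs_size=4, act_size=5, max_steps=50):
        self.n_agents = n_agents
        self.obs_size = obs_size
        self.act_size = act_size
        self.max_steps = max_steps
        self.t = 0

    def reset(self):
        self.t = 0
        # One observation per agent.
        return [np.random.randn(self.obs_size).astype(np.float32)
                for _ in range(self.n_agents)]

    def step(self, actions):
        self.t += 1
        next_states = [np.random.randn(self.obs_size).astype(np.float32)
                       for _ in range(self.n_agents)]
        rewards = [float(np.random.rand()) for _ in range(self.n_agents)]
        done = self.t >= self.max_steps
        return next_states, rewards, [done] * self.n_agents, {}


# Usage: bypass gym.make and hand the toy env to the same training loop.
env = ToyFormationEnv()
agents = [Agent(env.obs_size, env.act_size) for _ in range(env.n_agents)]
print(run_episode(env, agents))
```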