Implementing the A3C algorithm with PyTorch
Below is a simple example of implementing the A3C (Asynchronous Advantage Actor-Critic) algorithm with PyTorch, for reference only.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gym
import numpy as np
import torch.multiprocessing as mp  # drop-in replacement for multiprocessing that lets worker processes share model parameters
# Actor network: maps a state to a probability distribution over actions
class Actor(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.softmax(self.fc3(x), dim=1)
        return x
# Critic network: maps a state to a scalar state-value estimate
class Critic(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Agent: wraps the actor and critic together with their optimizers
class Agent:
    def __init__(self, input_size, hidden_size, output_size):
        self.actor = Actor(input_size, hidden_size, output_size)
        self.critic = Critic(input_size, hidden_size, 1)
        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=0.001)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=0.001)

    def act(self, state):
        # Sample an action from the current policy distribution
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.actor(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def update(self, rewards, log_probs, values):
        # Discounted returns with gamma = 0.99
        returns = []
        R = 0.0
        for r in reversed(rewards):
            R = r + 0.99 * R
            returns.insert(0, R)
        returns = torch.tensor(returns, dtype=torch.float32)
        log_probs = torch.stack(log_probs).squeeze(-1)  # shape (T,)
        values = torch.stack(values).squeeze(-1)        # shape (T,)
        # Detach the values in the advantage so the actor loss does not
        # backpropagate into the critic
        advs = returns - values.detach()
        loss_actor = (-log_probs * advs).mean()
        loss_critic = nn.MSELoss()(values, returns)
        self.optimizer_actor.zero_grad()
        loss_actor.backward()
        self.optimizer_actor.step()
        self.optimizer_critic.zero_grad()
        loss_critic.backward()
        self.optimizer_critic.step()
# Worker function: each process runs its own episodes and updates the shared networks.
# Uses the classic gym API (gym < 0.26): reset() returns the observation and
# step() returns four values.
def train(rank, agent, env, num_episodes):
    env.seed(rank)
    torch.manual_seed(rank)
    for i_episode in range(num_episodes):
        state = env.reset()
        done = False
        rewards = []
        log_probs = []
        values = []
        while not done:
            action, log_prob = agent.act(state)
            # Estimate the value of the current state before stepping the environment
            value = agent.critic(torch.from_numpy(state).float())
            next_state, reward, done, _ = env.step(action)
            rewards.append(reward)
            log_probs.append(log_prob)
            values.append(value)
            state = next_state
        agent.update(rewards, log_probs, values)
# Main: spawn one worker per CPU core, all updating the same shared networks
if __name__ == '__main__':
    env = gym.make('CartPole-v0')
    num_processes = mp.cpu_count()
    num_episodes = 1000
    agent = Agent(env.observation_space.shape[0], 128, env.action_space.n)
    # Move the parameters to shared memory so every worker updates the same model
    agent.actor.share_memory()
    agent.critic.share_memory()
    processes = []
    for rank in range(num_processes):
        # Give each worker its own environment instance
        worker_env = gym.make('CartPole-v0')
        p = mp.Process(target=train, args=(rank, agent, worker_env, num_episodes // num_processes))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
```
The code above is for demonstration purposes only, and the implementation details may need to be adapted to the problem at hand. Note also that A3C's performance typically depends on the choice of hyperparameters and the stability of training, so thorough experimentation and tuning are usually required.
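As a quick sanity check after training, the learned policy can be rolled out greedily in a single process. The sketch below is a minimal example, assuming the `Agent` instance from the script above and the same classic gym (< 0.26) API; `run_episode` is a name introduced here purely for illustration.

```python
import gym
import torch

def run_episode(agent, env_name='CartPole-v0'):
    # Greedy rollout: take the most probable action instead of sampling
    env = gym.make(env_name)
    state = env.reset()
    total_reward, done = 0.0, False
    while not done:
        with torch.no_grad():
            probs = agent.actor(torch.from_numpy(state).float().unsqueeze(0))
        action = int(probs.argmax(dim=1).item())
        state, reward, done, _ = env.step(action)
        total_reward += reward
    return total_reward

# Example usage after the worker processes have finished:
# print(run_episode(agent))
```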