Actor-Critic
时间: 2024-01-18 12:04:02 浏览: 94
Actor-Critic是一种强化学习算法,它结合了策略梯度和值函数的优点。在Actor-Critic算法中,Actor使用策略函数生成动作并与环境交互,Critic使用价值函数评估Actor的表现并指导Actor下一步的动作。Actor和Critic都是神经网络,需要进行梯度更新,互相依赖。Actor-Critic算法可以用于解决连续动作空间的问题,例如机器人控制和游戏AI等领域。
以下是一个Actor-Critic算法的示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import gym
# Policy network ("actor"): maps an observation vector to a probability
# distribution over the discrete actions.
class Actor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Actor, self).__init__()
        # One 128-unit hidden layer, then a linear head squashed by softmax.
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        probs = self.softmax(self.fc2(hidden))
        return probs
# Value network ("critic"): maps an observation vector to a single
# state-value estimate (no activation on the output).
class Critic(nn.Module):
    def __init__(self, input_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        value = self.fc2(hidden)
        return value
# One-step actor-critic (TD(0)) agent for discrete-action environments.
class ActorCritic:
    """Online actor-critic trainer.

    The actor samples actions from its softmax policy; the critic's one-step
    TD error both trains the critic (squared error) and serves as the
    advantage signal for the actor's policy-gradient update.
    """

    def __init__(self, env):
        # NOTE(review): assumes the pre-0.26 gym API, where reset() returns
        # the observation directly and step() returns a 4-tuple — confirm
        # against the installed gym version.
        self.env = env
        self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.actor = Actor(self.obs_dim, self.action_dim)
        self.critic = Critic(self.obs_dim)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=0.01)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=0.01)

    def train(self, max_episodes=1000, max_steps=1000, gamma=0.99):
        """Train actor and critic online, one environment step at a time.

        Args:
            max_episodes: number of episodes to run.
            max_steps: step cap per episode.
            gamma: discount factor for the TD target.
        """
        for episode in range(max_episodes):
            obs = self.env.reset()
            total_reward = 0
            for _ in range(max_steps):
                state = torch.FloatTensor(obs)
                # Sample an action from the current policy.
                action_probs = self.actor(state)
                action = torch.multinomial(action_probs, 1).item()
                next_obs, reward, done, _ = self.env.step(action)

                # --- Critic update (TD(0)) ---
                # BUG FIX: the original extracted value/next_value with
                # .item(), making td_error a plain float, so
                # critic_loss.backward() crashed and the critic was never
                # trained. Keep `value` in the autograd graph and compute
                # the bootstrap target without gradients.
                value = self.critic(state)
                with torch.no_grad():
                    # Terminal states bootstrap to 0 (original also missed this).
                    next_value = 0.0 if done else self.critic(torch.FloatTensor(next_obs))
                td_error = reward + gamma * next_value - value
                critic_loss = td_error.pow(2).sum()
                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                self.critic_optimizer.step()

                # --- Actor update: policy gradient with TD error as advantage ---
                # Detach so the actor update does not backprop into the critic.
                advantage = td_error.detach()
                log_prob = torch.log(action_probs[action])
                actor_loss = (-log_prob * advantage).sum()
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                obs = next_obs
                total_reward += reward
                if done:
                    break
            print("Episode {}: Total reward = {}".format(episode, total_reward))
# Smoke-test the agent on the classic CartPole balancing task.
# NOTE(review): assumes the pre-0.26 gym step/reset API — confirm version.
cartpole_env = gym.make('CartPole-v0')
agent = ActorCritic(cartpole_env)
agent.train()
```
阅读全文