Reinforcement Learning: DDPG Code
DDPG (Deep Deterministic Policy Gradient) is a reinforcement learning algorithm for problems with continuous action spaces. It trains a deterministic actor (policy) together with a critic (Q-function), and stabilizes learning with target networks and an experience replay buffer. Below is a simple DDPG code example:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Actor network: maps a state to a deterministic action in [-1, 1]
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, action_dim)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))  # squash the action to [-1, 1]
        return x
# Critic network: estimates Q(s, a) for a state-action pair
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)  # concatenate state and action
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# DDPG agent: actor/critic with target networks, replay memory and soft updates
class DDPG:
    def __init__(self, state_dim, action_dim):
        self.actor = Actor(state_dim, action_dim)
        self.actor_target = Actor(state_dim, action_dim)
        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        # start the target networks from the same weights as the online networks
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=0.001)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=0.001)
        self.loss_fn = nn.MSELoss()
        self.memory = []        # replay buffer (unbounded here; cap its size in practice)
        self.batch_size = 64
        self.gamma = 0.99       # discount factor
        self.tau = 0.001        # soft-update rate for the target networks

    def select_action(self, state):
        # deterministic policy output; add exploration noise externally if needed
        state = torch.FloatTensor(state)
        action = self.actor(state).detach().numpy()
        return action

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        if len(self.memory) < self.batch_size:
            return
        # sample a random mini-batch of transitions
        batch = np.random.choice(len(self.memory), self.batch_size, replace=False)
        state_batch = torch.FloatTensor(np.array([self.memory[i][0] for i in batch]))
        action_batch = torch.FloatTensor(np.array([self.memory[i][1] for i in batch]))
        reward_batch = torch.FloatTensor([self.memory[i][2] for i in batch]).unsqueeze(1)
        next_state_batch = torch.FloatTensor(np.array([self.memory[i][3] for i in batch]))
        done_batch = torch.FloatTensor([self.memory[i][4] for i in batch]).unsqueeze(1)
        # update the Critic: regress Q(s, a) toward the bootstrapped target
        next_action_batch = self.actor_target(next_state_batch)
        target_q = reward_batch + self.gamma * (1 - done_batch) * self.critic_target(next_state_batch, next_action_batch).detach()
        q_values = self.critic(state_batch, action_batch)
        critic_loss = self.loss_fn(q_values, target_q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        # update the Actor: maximize Q(s, actor(s)), i.e. minimize its negative
        policy_loss = -self.critic(state_batch, self.actor(state_batch)).mean()
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()
        # soft-update the target networks toward the online networks
        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
# Create the DDPG agent (state_dim / action_dim must match the environment)
state_dim = 4
action_dim = 2
ddpg = DDPG(state_dim, action_dim)

# Train the agent (assumes `env` is an already-created Gym-style environment
# with a continuous action space matching state_dim / action_dim)
for episode in range(1000):
    state = env.reset()
    done = False
    total_reward = 0
    while not done:
        action = ddpg.select_action(state)
        next_state, reward, done, _ = env.step(action)
        ddpg.remember(state, action, reward, next_state, done)
        ddpg.replay()
        state = next_state
        total_reward += reward
    print("Episode: {}, Total Reward: {}".format(episode, total_reward))
```
This code implements a simple DDPG algorithm: the Actor and Critic network definitions, the critic and actor updates, soft updates of the target networks, and experience replay. You can modify and extend it to suit your own needs; a sketch of how it might be wired to an actual environment follows below.
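As a concrete example of such an extension, here is a minimal sketch of running the agent above against a real environment. The environment name `Pendulum-v1`, the noise scale `noise_std`, and the use of the classic Gym API (`env.reset()` returning only the observation, `env.step()` returning four values) are assumptions, not part of the original code; the sketch also adds Gaussian exploration noise and rescales the actor's tanh output to the environment's action bounds, both of which the training loop above omits.
```python
import gym
import numpy as np

# Hypothetical setup: Pendulum-v1 and the classic Gym API are assumptions here.
env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]    # 3 for Pendulum-v1
action_dim = env.action_space.shape[0]        # 1 for Pendulum-v1
max_action = float(env.action_space.high[0])  # 2.0 for Pendulum-v1

ddpg = DDPG(state_dim, action_dim)
noise_std = 0.1  # scale of the Gaussian exploration noise (assumed value)

for episode in range(200):
    state = env.reset()
    done = False
    total_reward = 0
    while not done:
        # actor output in [-1, 1] plus exploration noise, clipped back to [-1, 1]
        action = ddpg.select_action(state)
        action = np.clip(action + np.random.normal(0, noise_std, size=action_dim), -1.0, 1.0)
        # rescale to the environment's action range only when stepping
        next_state, reward, done, _ = env.step(max_action * action)
        ddpg.remember(state, action, reward, next_state, done)
        ddpg.replay()
        state = next_state
        total_reward += reward
    print("Episode: {}, Total Reward: {:.1f}".format(episode, total_reward))
```
Storing the normalized action (in [-1, 1]) in the replay buffer keeps it consistent with what the actor and critic see during updates; the rescaling by `max_action` happens only at the environment boundary.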