DDPG algorithm code in PyTorch
Below is a DDPG (Deep Deterministic Policy Gradient) implementation based on PyTorch.
First, import the required packages and pick a compute device:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque

# Use a GPU if available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
Next, define the Actor and Critic networks:
```python
# Actor network: maps a state to a deterministic action in [-max_action, max_action]
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        # tanh squashes the output to [-1, 1]; scale it to the action range
        x = self.max_action * torch.tanh(self.fc3(x))
        return x

# Critic network: maps a (state, action) pair to a scalar Q-value
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, 1)

    def forward(self, state, action):
        # Concatenate state and action along the feature dimension
        x = torch.cat([state, action], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
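As a quick sanity check (a minimal sketch; the batch size of 8 and the dimensions here are arbitrary), you can push a dummy batch through both networks and confirm the output shapes:
```python
# Hypothetical shape check with random inputs (not part of the original code)
state_dim, action_dim, max_action = 3, 1, 1.0
actor = Actor(state_dim, action_dim, max_action).to(device)
critic = Critic(state_dim, action_dim).to(device)

dummy_state = torch.randn(8, state_dim).to(device)   # batch of 8 states
dummy_action = actor(dummy_state)                      # -> shape (8, action_dim)
dummy_q = critic(dummy_state, dummy_action)            # -> shape (8, 1)
print(dummy_action.shape, dummy_q.shape)
```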
Next, we define the main DDPG agent class:
```python
class DDPG(object):
    def __init__(self, state_dim, action_dim, max_action):
        # Actor and its target network
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)
        # Critic and its target network
        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)
        # Hyperparameters
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.replay_buffer = deque(maxlen=1000000)
        self.batch_size = 64
        self.discount = 0.99   # discount factor gamma
        self.tau = 0.005       # soft-update rate for the target networks

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def add_to_buffer(self, state, action, reward, next_state, done):
        self.replay_buffer.append((state, action, reward, next_state, done))

    def train(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        # Sample a batch of transitions from the replay buffer
        batch = random.sample(self.replay_buffer, self.batch_size)
        state = torch.FloatTensor(np.array([e[0] for e in batch])).to(device)
        action = torch.FloatTensor(np.array([e[1] for e in batch])).to(device)
        # Reshape reward and done to (batch_size, 1) so they broadcast correctly with target_Q
        reward = torch.FloatTensor(np.array([e[2] for e in batch])).reshape(-1, 1).to(device)
        next_state = torch.FloatTensor(np.array([e[3] for e in batch])).to(device)
        done = torch.FloatTensor(np.array([e[4] for e in batch])).reshape(-1, 1).to(device)
        # Compute the target Q-value with the target networks
        target_Q = self.critic_target(next_state, self.actor_target(next_state))
        target_Q = reward + ((1 - done) * self.discount * target_Q).detach()
        # Compute the current Q-value
        current_Q = self.critic(state, action)
        # Critic loss: mean squared TD error
        critic_loss = F.mse_loss(current_Q, target_Q)
        # Update the Critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        # Actor loss: maximize Q(s, pi(s)), i.e. minimize its negative
        actor_loss = -self.critic(state, self.actor(state)).mean()
        # Update the Actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        # Soft-update the target networks toward the online networks
        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
```
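One practical note: `select_action` above is purely deterministic, while DDPG normally adds exploration noise to actions during training. A minimal sketch of a Gaussian-noise wrapper (the `expl_noise` scale is an assumed hyperparameter, not part of the class above):
```python
# Hypothetical exploration helper; expl_noise is an assumed hyperparameter
def select_action_with_noise(agent, state, expl_noise=0.1):
    action = agent.select_action(state)
    noise = np.random.normal(0, expl_noise * agent.max_action, size=agent.action_dim)
    # Clip so the noisy action stays within the valid range
    return np.clip(action + noise, -agent.max_action, agent.max_action)
```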
Finally, we can train and test the agent. Note that the loop below feeds randomly generated transitions rather than a real environment, so it only verifies that the code runs; a sketch of training on an actual Gym environment follows after the block.
```python
# Environment dimensions (placeholder values for this demo)
state_dim = 3
action_dim = 1
max_action = 1

# Create the DDPG agent
ddpg = DDPG(state_dim, action_dim, max_action)

# "Train" on randomly generated transitions (smoke test only)
for i in range(10000):
    state = np.random.randn(state_dim)
    action = ddpg.select_action(state)
    next_state = np.random.randn(state_dim)
    reward = np.random.randn(1)
    done = np.random.randint(0, 2)
    ddpg.add_to_buffer(state, action, reward, next_state, done)
    ddpg.train()

# Query the trained policy for a single state
state = np.random.randn(state_dim)
action = ddpg.select_action(state)
print(action)
```
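To train on a real task instead of random data, the same agent can be plugged into a Gym-style control environment. The following is a minimal sketch assuming the `gymnasium` package and the `Pendulum-v1` environment (observation dim 3, action dim 1, max action 2.0); the episode count and noise scale are arbitrary choices, not values from the original code:
```python
# Sketch: training the DDPG agent on Pendulum-v1 via gymnasium (assumed installed)
import gymnasium as gym

env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]      # 3 for Pendulum-v1
action_dim = env.action_space.shape[0]          # 1 for Pendulum-v1
max_action = float(env.action_space.high[0])    # 2.0 for Pendulum-v1

agent = DDPG(state_dim, action_dim, max_action)

for episode in range(100):                      # episode count chosen arbitrarily
    state, _ = env.reset()
    episode_return = 0.0
    done = False
    while not done:
        # Add Gaussian exploration noise and clip to the valid action range
        action = agent.select_action(state)
        action = np.clip(action + np.random.normal(0, 0.1 * max_action, size=action_dim),
                         -max_action, max_action)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.add_to_buffer(state, action, reward, next_state, float(done))
        agent.train()
        state = next_state
        episode_return += reward
    print(f"episode {episode}: return {episode_return:.1f}")
```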
That is a complete DDPG implementation based on PyTorch.