DDPG combined with GCN for prediction: example code
DDPG (Deep Deterministic Policy Gradient) can be combined with a GCN (Graph Convolutional Network) to handle prediction tasks on graph-structured data. Below is a simple example showing how to implement a DDPG-GCN model in PyTorch, using the `GCNConv` layer from PyTorch Geometric:
``` python
import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


# Actor network of DDPG-GCN: maps per-node features to actions in [-1, 1]
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.gcn = GCNConv(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, state, edge_index):
        # state: [num_nodes, state_dim]; edge_index: LongTensor of shape [2, num_edges]
        x = F.relu(self.fc1(state))
        x = F.relu(self.gcn(x, edge_index))
        return torch.tanh(self.fc2(x))


# Critic network of DDPG-GCN: estimates Q(s, a) per node
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.gcn1 = GCNConv(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action, edge_index):
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.gcn1(x, edge_index))
        return self.fc2(x)


# DDPG-GCN agent
class Agent:
    def __init__(self, state_dim, action_dim, hidden_dim, gamma=0.99, tau=1e-2,
                 lr_actor=1e-3, lr_critic=1e-3, buffer_size=100000, batch_size=64):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.actor = Actor(state_dim, action_dim, hidden_dim).to(self.device)
        self.actor_target = Actor(state_dim, action_dim, hidden_dim).to(self.device)
        self.critic = Critic(state_dim, action_dim, hidden_dim).to(self.device)
        self.critic_target = Critic(state_dim, action_dim, hidden_dim).to(self.device)
        # Start the target networks from the same weights as the online networks
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr_actor)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr_critic)
        self.buffer = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau

    # Select an action with the policy network (Actor)
    def select_action(self, state, edge_index):
        state = torch.FloatTensor(state).to(self.device)
        edge_index = torch.as_tensor(edge_index, dtype=torch.long).to(self.device)
        self.actor.eval()
        with torch.no_grad():
            action = self.actor(state, edge_index).cpu().numpy()
        self.actor.train()
        return action

    # Store a (state, action, reward, next_state, edge_index) tuple in the replay buffer
    def remember(self, state, action, reward, next_state, edge_index):
        state = torch.FloatTensor(state).to(self.device)
        action = torch.FloatTensor(action).to(self.device)
        reward = torch.FloatTensor([[reward]]).to(self.device)
        next_state = torch.FloatTensor(next_state).to(self.device)
        edge_index = torch.as_tensor(edge_index, dtype=torch.long).to(self.device)
        self.buffer.append((state, action, reward, next_state, edge_index))

    # Sample a random mini-batch from the buffer and train
    def train(self):
        if len(self.buffer) < self.batch_size:
            return
        # Random sample from the replay buffer
        batch = random.sample(self.buffer, self.batch_size)
        state, action, reward, next_state, edge_index = zip(*batch)
        # All transitions are assumed to share the same node count; merge the sampled
        # graphs into one disconnected batch graph by offsetting node indices
        num_nodes = state[0].size(0)
        edge_index = torch.cat([e + i * num_nodes for i, e in enumerate(edge_index)], dim=1)
        state = torch.cat(state)
        action = torch.cat(action)
        next_state = torch.cat(next_state)
        # Broadcast each transition's scalar reward to all of its nodes
        reward = torch.cat(reward).repeat_interleave(num_nodes, dim=0)
        # Compute the Q target with the target networks
        next_action = self.actor_target(next_state, edge_index)
        q_target = reward + self.gamma * self.critic_target(next_state, next_action, edge_index).detach()
        # Update the Critic network
        q_value = self.critic(state, action, edge_index)
        critic_loss = F.mse_loss(q_value, q_target)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        # Update the Actor network
        actor_loss = -self.critic(state, self.actor(state, edge_index), edge_index).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        # Soft-update the target networks
        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    # Save the model
    def save(self, filename):
        torch.save({
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
        }, filename)

    # Load the model
    def load(self, filename):
        checkpoint = torch.load(filename)
        self.actor.load_state_dict(checkpoint['actor_state_dict'])
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])
```
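As a quick sanity check that is not part of the original snippet, the short sketch below runs a single forward pass of the `Actor` on a hypothetical 4-node toy graph. It assumes PyTorch Geometric is installed and uses `dense_to_sparse` to convert a dense adjacency matrix into the `edge_index` format that `GCNConv` expects:
``` python
import torch
from torch_geometric.utils import dense_to_sparse

# Hypothetical sizes: 4 nodes, 8 state features per node, 2 action dimensions
state_dim, action_dim, hidden_dim = 8, 2, 16
actor = Actor(state_dim, action_dim, hidden_dim)

# Toy 4-node ring graph given as a dense adjacency matrix
adj_matrix = torch.tensor([[0., 1., 0., 1.],
                           [1., 0., 1., 0.],
                           [0., 1., 0., 1.],
                           [1., 0., 1., 0.]])
edge_index, _ = dense_to_sparse(adj_matrix)  # LongTensor of shape [2, num_edges]

state = torch.randn(4, state_dim)            # one feature vector per node
action = actor(state, edge_index)            # -> [4, action_dim], values in [-1, 1]
print(action.shape)
```
The same `edge_index` tensor is what `select_action` and `remember` receive during training.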
In the code above, a GCN layer is embedded in both the Actor and the Critic networks of the DDPG agent, and the agent class also provides a replay buffer, a training routine, and save/load helpers. The model can then be trained with a loop such as the following:
``` python
# env, num_episodes, num_steps, state_dim, action_dim, hidden_dim and edge_index
# are placeholders that must be defined for the task at hand
agent = Agent(state_dim, action_dim, hidden_dim)

for episode in range(num_episodes):
    state = env.reset()
    for step in range(num_steps):
        action = agent.select_action(state, edge_index)
        next_state, reward, done, _ = env.step(action)
        agent.remember(state, action, reward, next_state, edge_index)
        agent.train()
        if done:
            break
        state = next_state

agent.save('ddpg_gcn.pt')
```
During training, `select_action` queries the agent's Actor network for an action, `remember` stores the (state, action, reward, next state) tuple together with the graph in the replay buffer, and `train` samples a random mini-batch from that buffer to update the networks. Finally, `save` writes the trained model to disk.
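For the prediction step itself, a minimal sketch (reusing the hypothetical `env`, `state_dim`, `action_dim`, `hidden_dim`, and `edge_index` placeholders from above) would reload the checkpoint and query the trained Actor directly:
``` python
# Rebuild the agent with the same dimensions and restore the trained weights
agent = Agent(state_dim, action_dim, hidden_dim)
agent.load('ddpg_gcn.pt')

# Deterministic prediction from the trained Actor (no exploration noise)
state = env.reset()
action = agent.select_action(state, edge_index)
```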