DQN Code Implementation
DQN (Deep Q-Network) is a reinforcement learning algorithm that combines a deep neural network with Q-learning to solve decision-making problems with discrete action spaces. Below is a DQN implementation in Python with PyTorch:
```python
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# DQN network: maps a state to the Q-value of each action
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Experience replay buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        # Circular buffer: overwrite the oldest transition once full
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

# DQN agent
class DQNAgent:
    def __init__(self, state_dim, action_dim, lr, gamma, epsilon, buffer_capacity, batch_size):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.batch_size = batch_size
        self.buffer = ReplayBuffer(buffer_capacity)
        self.policy_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.loss_fn = nn.SmoothL1Loss()

    def act(self, state):
        # Epsilon-greedy exploration: random action with probability epsilon
        if np.random.rand() < self.epsilon:
            return random.randrange(self.action_dim)
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device)
            q_values = self.policy_net(state)
            action = q_values.argmax().item()
        return action

    def update(self):
        if len(self.buffer) < self.batch_size:
            return
        state, action, reward, next_state, done = self.buffer.sample(self.batch_size)
        state = torch.FloatTensor(state).to(self.device)
        action = torch.LongTensor(action).to(self.device)
        reward = torch.FloatTensor(reward).to(self.device)
        next_state = torch.FloatTensor(next_state).to(self.device)
        done = torch.FloatTensor(done).to(self.device)
        # Q(s, a) of the actions actually taken, from the policy network
        q_values = self.policy_net(state).gather(1, action.unsqueeze(-1)).squeeze(-1)
        # TD target: r + gamma * max_a' Q_target(s', a'), zeroed on terminal states
        next_q_values = self.target_net(next_state).max(1)[0]
        expected_q_values = reward + self.gamma * next_q_values * (1 - done)
        loss = self.loss_fn(q_values, expected_q_values.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save(self, path):
        torch.save(self.policy_net.state_dict(), path)

    def load(self, path):
        self.policy_net.load_state_dict(torch.load(path))
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
```
This implementation consists of the following components:
1. DQN network: a three-layer fully connected network that takes the state as input and outputs a Q-value for each action.
2. Replay buffer: stores the agent's interactions with the environment so that they can be sampled at random.
3. DQN agent: provides action selection, experience replay, and network updates.
During training, the agent interacts with the environment, obtains the state, action, reward, next state, and termination flag, and stores them in the replay buffer. The agent then samples a random batch of transitions from the buffer, computes the Q-value error, and updates the network parameters. At the same time, the policy network's parameters are periodically copied to the target network to stabilize training.
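The snippet above defines the agent but not the interaction loop it describes. Below is a minimal sketch of such a training loop, assuming a `gymnasium` environment (`CartPole-v1` here) and illustrative hyperparameter values; the environment name, episode count, and target-update interval are assumptions for demonstration, not part of the original code.
```python
import gymnasium as gym

# Assumed environment and hyperparameters, for illustration only
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

agent = DQNAgent(state_dim, action_dim, lr=1e-3, gamma=0.99,
                 epsilon=0.1, buffer_capacity=10000, batch_size=64)

num_episodes = 500           # illustrative value
target_update_interval = 10  # sync target network every 10 episodes

for episode in range(num_episodes):
    state, _ = env.reset()
    episode_reward = 0.0
    done = False
    while not done:
        # Epsilon-greedy action selection
        action = agent.act(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # Store the transition (done as float so batching to tensors is straightforward)
        agent.buffer.push(state, action, reward, next_state, float(done))
        # One gradient step on a sampled batch (no-op until the buffer is large enough)
        agent.update()
        state = next_state
        episode_reward += reward
    # Periodically copy policy-network weights to the target network
    if episode % target_update_interval == 0:
        agent.update_target()
    print(f"Episode {episode}, reward: {episode_reward}")
```
Note that `gymnasium`'s `reset`/`step` API returns a `(state, info)` pair and separate `terminated`/`truncated` flags; with the older `gym` API the unpacking would differ slightly.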