DQN PyTorch 代码
时间: 2023-12-20 09:32:07 浏览: 111
DQN 代码实战：Gym 经典 CartPole（小车倒立摆）模型，纯 PyTorch 框架，代码中包含 4 种 DQN 变体，注释清晰。
5星 · 资源好评率100%
以下是一个使用PyTorch实现的DQN算法的代码示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
# Define the DQN network.
class DQN(nn.Module):
    """Fully connected network with two 64-unit ReLU hidden layers.

    The final layer is linear (no activation), producing one value per
    output unit.
    """

    def __init__(self, input_size, output_size):
        """Build the three linear layers.

        Args:
            input_size: Number of input features fed to the first layer.
            output_size: Number of values produced by the last layer.
        """
        super().__init__()
        hidden_units = 64
        # Attribute names and creation order are kept so that state_dict keys
        # and seeded initialisation are unchanged.
        self.fc1 = nn.Linear(input_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, output_size)

    def forward(self, x):
        """Return the raw output of the last linear layer for input ``x``."""
        first_hidden = F.relu(self.fc1(x))
        second_hidden = F.relu(self.fc2(first_hidden))
        return self.fc3(second_hidden)
# Define the experience replay buffer.
class ReplayBuffer():
    """Fixed-capacity store of transitions that overwrites the oldest entry when full."""

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` transitions."""
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # index the next transition will be written to

    def push(self, state, action, reward, next_state, done):
        """Store one transition, replacing the oldest one once the buffer is full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            # While filling up, the write index is always the end of the list.
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        # Advance the write index, wrapping around at capacity.
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` randomly chosen transitions as five NumPy arrays.

        Returns:
            A tuple ``(states, actions, rewards, next_states, dones)``.
        """
        picked = random.sample(self.buffer, batch_size)
        # zip(*picked) regroups the transitions field by field.
        states, actions, rewards, next_states, dones = map(np.array, zip(*picked))
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)
# Define the DQN agent.
class DQNAgent():
    """DQN agent with an epsilon-greedy policy, a replay buffer and a target network.

    ``policy_net`` is trained by the optimizer; ``target_net`` is a copy that
    is only refreshed from ``policy_net`` every ``target_update`` steps.
    """

    def __init__(self, state_size, action_size, batch_size, gamma, epsilon, epsilon_decay, epsilon_min, learning_rate, target_update, memory_capacity=10000):
        """Set up networks, optimizer and replay memory.

        Args:
            state_size: Input size of both networks.
            action_size: Output size of both networks and number of actions
                random exploration chooses from.
            batch_size: Number of transitions sampled per ``update_model`` call.
            gamma: Discount factor applied to the next-state value.
            epsilon: Initial probability of choosing a random action.
            epsilon_decay: Factor ``epsilon`` is multiplied by on every
                ``select_action`` call.
            epsilon_min: Lower bound for ``epsilon``.
            learning_rate: Learning rate of the Adam optimizer.
            target_update: The target network is synced whenever
                ``steps_done`` is a multiple of this value.
            memory_capacity: Capacity of the replay buffer (default 10000,
                the value that was previously hard-coded).
        """
        self.state_size = state_size
        self.action_size = action_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.learning_rate = learning_rate
        self.target_update = target_update
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = DQN(state_size, action_size).to(self.device)
        self.target_net = DQN(state_size, action_size).to(self.device)
        # Start the target network as an exact copy and keep it in eval mode.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        self.memory = ReplayBuffer(memory_capacity)
        self.steps_done = 0

    def select_action(self, state):
        """Choose an action epsilon-greedily and decay ``epsilon``.

        Every call increments ``steps_done`` and multiplies ``epsilon`` by
        ``epsilon_decay`` (never going below ``epsilon_min``).

        Args:
            state: A single observation; it is converted to a float32 tensor
                and given a leading batch dimension.

        Returns:
            The chosen action index as a Python int.
        """
        if random.random() > self.epsilon:
            # Exploit: pick the action with the highest predicted value.
            with torch.no_grad():
                state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
                q_values = self.policy_net(state)
                action = q_values.max(1)[1].item()
        else:
            # Explore: pick a uniformly random action.
            action = random.randrange(self.action_size)
        self.steps_done += 1
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        return action

    def update_model(self):
        """Run one optimisation step on a batch sampled from the replay memory.

        Does nothing while the memory holds fewer than ``batch_size``
        transitions. After the step, the target network is synced when
        ``steps_done`` (advanced by ``select_action``) is a multiple of
        ``target_update``.
        """
        if len(self.memory) < self.batch_size:
            return
        states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)
        states = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(actions, dtype=torch.long, device=self.device).unsqueeze(1)
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states = torch.as_tensor(next_states, dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)

        # Predicted value of the action that was actually taken.
        q_values = self.policy_net(states).gather(1, actions)
        # The target is a constant for this update: computing it without
        # autograd avoids building a graph through the target network and
        # accumulating gradients on parameters the optimizer never updates.
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0].unsqueeze(1)
            # (1 - dones) removes the bootstrap term for terminal transitions.
            expected_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        loss = F.smooth_l1_loss(q_values, expected_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.steps_done % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
# Create the DQNAgent instance.
# NOTE(review): state_size=4 / action_size=2 presumably match the CartPole
# environment mentioned in the page text -- confirm against the actual env.
state_size = 4
action_size = 2
batch_size = 32
gamma = 0.99
epsilon = 1.0
epsilon_decay = 0.995
epsilon_min = 0.01
learning_rate = 0.001
target_update = 100
agent = DQNAgent(
    state_size=state_size,
    action_size=action_size,
    batch_size=batch_size,
    gamma=gamma,
    epsilon=epsilon,
    epsilon_decay=epsilon_decay,
    epsilon_min=epsilon_min,
    learning_rate=learning_rate,
    target_update=target_update,
)
# Train the DQN agent.
# NOTE(review): `env` is not created anywhere in this snippet; it must already
# exist before this loop runs (presumably a Gym CartPole environment -- confirm).
for episode in range(1000):
    reset_result = env.reset()
    # Newer Gym/Gymnasium return (observation, info) from reset(); older Gym
    # returns only the observation.
    state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
    done = False
    total_reward = 0
    while not done:
        action = agent.select_action(state)
        step_result = env.step(action)
        if len(step_result) == 5:
            # Newer API: (obs, reward, terminated, truncated, info).
            next_state, reward, terminated, truncated, _ = step_result
            done = terminated or truncated
        else:
            # Older API: (obs, reward, done, info).
            next_state, reward, terminated, _ = step_result
            done = terminated
        # Store only true termination as the terminal flag so that a
        # time-limit truncation still bootstraps from the next state.
        agent.memory.push(state, action, reward, next_state, terminated)
        state = next_state
        total_reward += reward
        agent.update_model()
    if episode % 100 == 0:
        print("Episode: {}, Total Reward: {}".format(episode, total_reward))
# Run one episode with the trained DQN.
# Actions are taken greedily from the policy network instead of through
# agent.select_action(), which would still explore with probability epsilon
# and would keep mutating the agent's epsilon / steps_done during evaluation.
# NOTE(review): `env` is not created anywhere in this snippet -- it must
# already exist at this point.
reset_result = env.reset()
# Newer Gym/Gymnasium return (observation, info) from reset(); older Gym
# returns only the observation.
state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
done = False
while not done:
    with torch.no_grad():
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=agent.device).unsqueeze(0)
        action = agent.policy_net(state_tensor).max(1)[1].item()
    step_result = env.step(action)
    if len(step_result) == 5:
        # Newer API: (obs, reward, terminated, truncated, info).
        state, reward, terminated, truncated, _ = step_result
        done = terminated or truncated
    else:
        # Older API: (obs, reward, done, info).
        state, reward, done, _ = step_result
    env.render()
# Release the environment's resources (e.g. the render window).
env.close()
```
阅读全文