DQN code in PyTorch
Below is a DQN implementation in PyTorch for training and testing on Atari games:
```python
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from collections import deque
import gym
# Hyperparameters
BUFFER_SIZE = 10000     # replay buffer capacity
BATCH_SIZE = 32         # mini-batch size
GAMMA = 0.99            # discount factor
EPS_START = 1.0         # initial epsilon for the epsilon-greedy policy
EPS_END = 0.01          # final epsilon
EPS_DECAY = 1000000     # epsilon decay rate (in steps)
TARGET_UPDATE = 1000    # target network update frequency (in steps)
# Experience replay buffer
class ReplayBuffer(object):
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)
# DQN model: three convolutional layers followed by a linear output layer
class DQN(nn.Module):
    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        # Compute the number of inputs to the linear layer
        def conv2d_size_out(size, kernel_size=5, stride=2):
            return (size - (kernel_size - 1) - 1) // stride + 1
        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        linear_input_size = convw * convh * 32
        self.head = nn.Linear(linear_input_size, outputs)

    def forward(self, x):
        x = x / 255.0  # scale pixel values to [0, 1]
        x = nn.functional.relu(self.bn1(self.conv1(x)))
        x = nn.functional.relu(self.bn2(self.conv2(x)))
        x = nn.functional.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))
# Agent: wraps the policy/target networks, the optimizer and the replay buffer
class Agent:
    def __init__(self, env):
        self.env = env
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # For Atari, observation_space.shape is (H, W, C), e.g. (210, 160, 3)
        h, w = env.observation_space.shape[0], env.observation_space.shape[1]
        self.policy_net = DQN(h, w, env.action_space.n).to(self.device)
        self.target_net = DQN(h, w, env.action_space.n).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.RMSprop(self.policy_net.parameters())
        self.memory = ReplayBuffer(BUFFER_SIZE)
        self.steps_done = 0

    @staticmethod
    def _to_tensor(state):
        # Convert an (H, W, C) frame into a (1, C, H, W) float tensor for the conv layers
        return torch.FloatTensor(np.array(state)).permute(2, 0, 1).unsqueeze(0)
    def select_action(self, state):
        # Epsilon-greedy: epsilon decays exponentially from EPS_START to EPS_END
        eps_threshold = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * self.steps_done / EPS_DECAY)
        self.steps_done += 1
        if random.random() > eps_threshold:
            with torch.no_grad():
                return self.policy_net(state.to(self.device)).max(1)[1].view(1, 1)
        else:
            return torch.tensor([[random.randrange(self.env.action_space.n)]], device=self.device, dtype=torch.long)
    def optimize_model(self):
        if len(self.memory) < BATCH_SIZE:
            return
        state, action, reward, next_state, done = self.memory.sample(BATCH_SIZE)
        state_batch = torch.cat(state).to(self.device)
        action_batch = torch.cat(action).to(self.device)
        reward_batch = torch.cat(reward).to(self.device)
        next_state_batch = torch.cat(next_state).to(self.device)
        done_batch = torch.cat(done).to(self.device)
        # Q(s, a) for the actions that were actually taken
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)
        # Bootstrap target: r + gamma * max_a' Q_target(s', a'), zero for terminal states
        next_state_values = torch.zeros(BATCH_SIZE, device=self.device)
        next_state_values[~done_batch] = self.target_net(next_state_batch[~done_batch]).max(1)[0].detach()
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch
        loss = nn.functional.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradients to stabilize training
        for param in self.policy_net.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
    def train(self, num_episodes):
        for i_episode in range(num_episodes):
            state = self.env.reset()
            for t in range(10000):
                state_tensor = self._to_tensor(state)
                action = self.select_action(state_tensor)
                next_state, reward, done, _ = self.env.step(action.item())
                reward = torch.tensor([reward], device=self.device, dtype=torch.float32)
                self.memory.push(state_tensor, action, reward,
                                 self._to_tensor(next_state),
                                 torch.tensor([done], device=self.device, dtype=torch.bool))
                state = next_state
                self.optimize_model()
                # Periodically sync the target network with the policy network
                if self.steps_done % TARGET_UPDATE == 0:
                    self.target_net.load_state_dict(self.policy_net.state_dict())
                if done:
                    break
            print('Episode: {} steps: {}'.format(i_episode, t))
    def test(self):
        state = self.env.reset()
        for t in range(10000):
            # Act greedily with the trained policy network
            with torch.no_grad():
                action = self.policy_net(self._to_tensor(state).to(self.device)).max(1)[1].view(1, 1)
            next_state, reward, done, _ = self.env.step(action.item())
            state = next_state
            self.env.render()
            if done:
                break
        self.env.close()
# Run training and testing (assumes the old Gym API: reset() returns an observation
# and step() returns 4 values, i.e. gym < 0.26 with the Atari ROMs installed)
if __name__ == '__main__':
    env = gym.make('Pong-v0')
    env = env.unwrapped
    agent = Agent(env)
    agent.train(num_episodes=100)
    agent.test()
```
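The listing above feeds raw RGB frames straight into the network. In practice, DQN on Atari is usually trained on downsampled grayscale frames. Below is a minimal preprocessing sketch using `torchvision.transforms` (imported as `T` in the listing but otherwise unused there); it is an assumption about how you might preprocess frames, not part of the original code. If you adopt it, the first conv layer would need `in_channels=1` (or 4 with frame stacking), and the `x / 255.0` scaling in `forward` should be dropped because `ToTensor` already scales to [0, 1].

```python
import torchvision.transforms as T

# Hypothetical preprocessing pipeline: RGB (210, 160, 3) uint8 frame -> (1, 1, 84, 84) float tensor
preprocess = T.Compose([
    T.ToPILImage(),                      # numpy HWC uint8 array -> PIL image
    T.Grayscale(num_output_channels=1),  # drop color information
    T.Resize((84, 84)),                  # downsample to the size used in the original DQN paper
    T.ToTensor(),                        # PIL image -> float tensor in [0, 1], shape (1, 84, 84)
])

def frame_to_tensor(frame):
    """Convert a raw gym frame into a (1, 1, 84, 84) batch tensor (hypothetical helper)."""
    return preprocess(frame).unsqueeze(0)
```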
This is a simple DQN implementation that you can modify, extend, and optimize as needed; one common extension is sketched below.
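Double DQN, for example, selects the next action with the policy network but evaluates it with the target network, which reduces Q-value overestimation. The sketch below shows only the target computation; the function name and signature are illustrative, and the batch tensors are assumed to be shaped as in `optimize_model` above.

```python
import torch

def double_dqn_target(policy_net, target_net, next_state_batch, reward_batch, done_batch, gamma=0.99):
    """Compute the Double DQN bootstrap target for a batch of transitions (illustrative helper)."""
    with torch.no_grad():
        # Select the greedy next action with the online (policy) network ...
        next_actions = policy_net(next_state_batch).argmax(dim=1, keepdim=True)
        # ... but evaluate that action with the target network
        next_q = target_net(next_state_batch).gather(1, next_actions).squeeze(1)
        next_q[done_batch] = 0.0  # no bootstrapping from terminal states
    return reward_batch + gamma * next_q
```

If you wire this into `optimize_model`, it would replace the `expected_state_action_values` computation, with everything else (loss, gradient clipping, optimizer step) unchanged.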