PyTorch code for DQN
Below is an example of a DQN (Deep Q-Network) implemented with PyTorch, used to solve the CartPole problem from OpenAI Gym:
```python
import gym
import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque

# Define the DQN network: a small MLP mapping observations to Q-values, one per action
class DQN(nn.Module):
    def __init__(self, observation_space_size, action_space_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(observation_space_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_space_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Experience replay buffer: stores transitions and samples random mini-batches
class ReplayBuffer():
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

# DQN agent: policy network, target network, replay buffer and epsilon-greedy exploration
class DQNAgent():
    def __init__(self, observation_space_size, action_space_size, replay_buffer_capacity=10000,
                 batch_size=32, gamma=0.99, learning_rate=1e-3, eps_start=1.0, eps_end=0.01, eps_decay=200):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.observation_space_size = observation_space_size
        self.action_space_size = action_space_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        self.steps_done = 0
        self.policy_net = DQN(observation_space_size, action_space_size).to(self.device)
        self.target_net = DQN(observation_space_size, action_space_size).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        self.replay_buffer = ReplayBuffer(replay_buffer_capacity)

    def select_action(self, state):
        # Epsilon-greedy: epsilon decays exponentially from eps_start towards eps_end
        self.steps_done += 1
        epsilon = self.eps_end + (self.eps_start - self.eps_end) * math.exp(-1. * self.steps_done / self.eps_decay)
        if random.random() < epsilon:
            return random.randrange(self.action_space_size)
        else:
            with torch.no_grad():
                state_tensor = torch.tensor(state, dtype=torch.float32).to(self.device)
                q_values = self.policy_net(state_tensor).cpu().numpy()
            return q_values.argmax()

    def optimize_model(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)
        state_tensor = torch.tensor(state, dtype=torch.float32).to(self.device)
        action_tensor = torch.tensor(action, dtype=torch.long).to(self.device)
        reward_tensor = torch.tensor(reward, dtype=torch.float32).to(self.device)
        next_state_tensor = torch.tensor(next_state, dtype=torch.float32).to(self.device)
        done_tensor = torch.tensor(done, dtype=torch.float32).to(self.device)
        # Q(s, a) for the actions that were actually taken
        q_values = self.policy_net(state_tensor).gather(1, action_tensor.unsqueeze(1)).squeeze(1)
        # Bootstrapped target from the target network; detach so no gradients flow into it
        next_q_values = self.target_net(next_state_tensor).max(1)[0].detach()
        expected_q_values = reward_tensor + (1 - done_tensor) * self.gamma * next_q_values
        loss = F.mse_loss(q_values, expected_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def remember(self, state, action, reward, next_state, done):
        self.replay_buffer.push(state, action, reward, next_state, done)

# Training loop (classic Gym API: reset() returns an observation, step() returns 4 values)
def train_dqn(env, agent, num_episodes):
    episode_rewards = []
    for i in range(num_episodes):
        state = env.reset()
        episode_reward = 0
        while True:
            action = agent.select_action(state)
            next_state, reward, done, _ = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            agent.optimize_model()
            state = next_state
            episode_reward += reward
            if done:
                episode_rewards.append(episode_reward)
                break
        if i % 20 == 0:
            recent = episode_rewards[-20:]
            print("Episode: {}, average reward: {:.2f}".format(i, sum(recent) / len(recent)))
        if i % 10 == 0:
            agent.update_target_network()
    return episode_rewards

# Run the DQN agent on CartPole
if __name__ == '__main__':
    env = gym.make('CartPole-v0')
    observation_space_size = env.observation_space.shape[0]
    action_space_size = env.action_space.n
    agent = DQNAgent(observation_space_size, action_space_size)
    episode_rewards = train_dqn(env, agent, num_episodes=200)
    env.close()
```
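Note that the code above targets the classic Gym API, where `env.reset()` returns only the observation and `env.step()` returns four values. If you are running Gymnasium (or gym >= 0.26), the interaction loop inside `train_dqn` needs a small adjustment; here is a minimal sketch of the changed lines, assuming the same `env` and `agent` objects as above:

```python
# Sketch: adapting the episode loop to the Gymnasium / gym >= 0.26 API.
# Only the reset/step calls change; the agent itself is untouched.
state, _ = env.reset()                              # reset() now returns (obs, info)
episode_reward = 0
while True:
    action = agent.select_action(state)
    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated                  # treat either as end of episode
    agent.remember(state, action, reward, next_state, done)
    agent.optimize_model()
    state = next_state
    episode_reward += reward
    if done:
        break
```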
This is a simple DQN implementation that achieves decent results on the CartPole environment. To use it with other environments, adjust the input and output sizes of the network to match the environment's observation and action spaces, as sketched below.
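For example, another discrete-action environment can be plugged in simply by reading its observation and action space sizes. The sketch below uses Acrobot-v1 (6-dimensional observation, 3 discrete actions) with the same classic Gym API as above; the hyperparameter values are illustrative, not tuned:

```python
# Sketch: reusing the same agent and training loop on Acrobot-v1.
env = gym.make('Acrobot-v1')
observation_space_size = env.observation_space.shape[0]   # 6 for Acrobot-v1
action_space_size = env.action_space.n                    # 3 for Acrobot-v1
agent = DQNAgent(observation_space_size, action_space_size,
                 replay_buffer_capacity=50000)            # illustrative, not tuned
episode_rewards = train_dqn(env, agent, num_episodes=500)
env.close()
```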