Implement a deep reinforcement learning demo in Python with the PyTorch framework
Posted: 2023-09-13 11:06:00
Sure, here is a deep reinforcement learning demo using Python and the PyTorch framework. It uses the DQN algorithm to train an agent to score well on the CartPole-v1 task. The steps are as follows:
1. Install the dependencies
```
pip install gym torch numpy
```
2. Import the required libraries
```python
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
```
3. Define the deep Q-network model
```python
class DQN(nn.Module):
    """A simple two-hidden-layer MLP that maps an observation to one Q-value per action."""

    def __init__(self, obs_size, hidden_size, action_size):
        super().__init__()
        self.fc1 = nn.Linear(obs_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)

    def forward(self, obs):
        x = nn.functional.relu(self.fc1(obs))
        x = nn.functional.relu(self.fc2(x))
        return self.fc3(x)
```
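As a quick sanity check (a minimal sketch, not part of the original demo), the network can be instantiated with CartPole-v1's dimensions (4 observation features, 2 actions) and run on a dummy input to confirm the output shape:

```python
import torch
import torch.nn as nn


class DQN(nn.Module):
    """Same architecture as above: two hidden layers, one Q-value per action."""

    def __init__(self, obs_size, hidden_size, action_size):
        super().__init__()
        self.fc1 = nn.Linear(obs_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)

    def forward(self, obs):
        x = nn.functional.relu(self.fc1(obs))
        x = nn.functional.relu(self.fc2(x))
        return self.fc3(x)


net = DQN(obs_size=4, hidden_size=16, action_size=2)  # CartPole-v1 sizes
q = net(torch.zeros(1, 4))  # batch of one dummy observation
print(q.shape)              # one row of Q-values, one column per action
```

The greedy policy simply takes `torch.argmax` over the last dimension of this output.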
4. Define the training function
```python
def train_dqn(env, dqn, num_episodes=1000, batch_size=32, gamma=0.99,
              eps_start=1.0, eps_end=0.01, eps_decay=0.995):
    optimizer = optim.Adam(dqn.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    memory = []  # simple replay buffer (a collections.deque would avoid O(n) pops)
    eps = eps_start
    for episode in range(num_episodes):
        obs = env.reset()  # classic Gym API (gym < 0.26); newer versions return (obs, info)
        done = False
        total_reward = 0.0
        while not done:
            # epsilon-greedy action selection
            if np.random.random() < eps:
                action = env.action_space.sample()
            else:
                obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
                with torch.no_grad():
                    q_values = dqn(obs_tensor)
                action = torch.argmax(q_values, dim=1).item()
            next_obs, reward, done, _ = env.step(action)  # gym < 0.26 returns a 4-tuple
            total_reward += reward
            memory.append((obs, action, reward, next_obs, done))
            if len(memory) > 10000:
                memory.pop(0)
            if len(memory) > batch_size:
                batch = np.random.choice(len(memory), batch_size, replace=False)
                obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = zip(
                    *[memory[i] for i in batch])
                obs_tensor = torch.tensor(np.array(obs_batch), dtype=torch.float32)
                next_obs_tensor = torch.tensor(np.array(next_obs_batch), dtype=torch.float32)
                action_tensor = torch.tensor(action_batch, dtype=torch.int64).unsqueeze(1)
                reward_tensor = torch.tensor(reward_batch, dtype=torch.float32).unsqueeze(1)
                done_tensor = torch.tensor(done_batch, dtype=torch.float32).unsqueeze(1)
                q_values = dqn(obs_tensor).gather(1, action_tensor)
                # the TD target must not propagate gradients, hence no_grad
                with torch.no_grad():
                    next_q_values = dqn(next_obs_tensor).max(1)[0].unsqueeze(1)
                target_q_values = reward_tensor + gamma * next_q_values * (1 - done_tensor)
                loss = criterion(q_values, target_q_values)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            obs = next_obs
        eps = max(eps_end, eps_decay * eps)
        if (episode + 1) % 100 == 0:
            print("Episode:", episode + 1, "Total reward:", total_reward)
```
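The core of each update is the TD target `r + γ · max_a' Q(s', a')`, with the bootstrap term zeroed out for terminal transitions. A tiny standalone illustration with hand-picked numbers (not taken from the demo itself) shows how the `(1 - done)` mask works:

```python
import torch

gamma = 0.99
# three hypothetical transitions; the third is terminal (done = 1)
reward = torch.tensor([[1.0], [1.0], [1.0]])
next_q = torch.tensor([[2.0], [3.0], [4.0]])  # max_a' Q(s', a') per transition
done = torch.tensor([[0.0], [0.0], [1.0]])

target = reward + gamma * next_q * (1 - done)
print(target.squeeze(1))  # ≈ [2.98, 3.97, 1.0]: the terminal row keeps only the reward
```

For the terminal transition the future value is discarded entirely, which is what keeps Q-values from diverging at episode boundaries.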
5. Create the environment and model, then start training
```python
env = gym.make("CartPole-v1")
dqn = DQN(env.observation_space.shape[0], 16, env.action_space.n)
train_dqn(env, dqn)
```
After training finishes, you can test the model with the following code:
```python
obs = env.reset()  # gym < 0.26 API
done = False
total_reward = 0.0
while not done:
    env.render()
    obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():          # pure evaluation, no gradients needed
        q_values = dqn(obs_tensor)
    action = torch.argmax(q_values, dim=1).item()  # greedy policy at test time
    obs, reward, done, _ = env.step(action)
    total_reward += reward
print("Total reward:", total_reward)
env.close()
```
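If you want to keep the trained weights between runs, the usual PyTorch pattern is a `state_dict` save/load round trip. A minimal sketch (the file name and the stand-in network are illustrative; the same calls apply to the `DQN` instance above):

```python
import torch
import torch.nn as nn

# stand-in network with the same save/load pattern as the trained DQN
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
torch.save(model.state_dict(), "dqn_cartpole.pt")  # illustrative file name

# loading requires a model with the same architecture
restored = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
restored.load_state_dict(torch.load("dqn_cartpole.pt"))
restored.eval()  # switch to evaluation mode before testing
```

Saving only the `state_dict` (rather than the whole module) is the approach recommended by the PyTorch documentation, since it does not depend on the class definition being importable under the same path.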
This demo should give you a better understanding of how deep reinforcement learning is implemented. You can modify and extend it as needed to fit other environments and tasks.