DQN in Python
时间: 2023-10-19 15:33:15 浏览: 156
DQN (Deep Q-Network) is a widely used reinforcement learning algorithm that trains an agent to choose actions in environments with discrete action spaces by approximating the action-value function Q(s, a) with a neural network. In Python, you can implement DQN with deep learning libraries such as TensorFlow or PyTorch.
Here's a simple example of how to implement DQN in Python using the PyTorch library:
1. Install the required libraries:
```bash
pip install gym torch numpy
```
2. Import the necessary libraries:
```python
from collections import deque

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
```
3. Define the Q-network:
```python
class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one Q-value per action.

    Args:
        state_size: Length of the input state vector.
        action_size: Number of discrete actions (size of the output layer).
        hidden_size: Width of each of the two hidden layers (default 64).
    """

    def __init__(self, state_size, action_size, hidden_size=64):
        super().__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)

    def forward(self, x):
        """Return Q-values of shape ``(..., action_size)`` for states ``x``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the output layer: Q-values are unbounded estimates.
        return self.fc3(x)
```
4. Initialize the environment and hyperparameters:
```python
# Environment: a 1-D observation vector and a discrete action set.
env = gym.make("CartPole-v0")
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

# Hyperparameters.
batch_size = 32  # transitions sampled per gradient step
gamma = 0.99  # discount factor for future rewards
epsilon = 1.0  # initial exploration rate (always act randomly at first)
epsilon_decay = 0.995  # multiplicative decay applied after each episode
epsilon_min = 0.01  # lower bound for the exploration rate
memory_size = 10_000  # replay-buffer capacity
learning_rate = 0.001

# Fixed-capacity replay buffer: once full, each append evicts the oldest
# transition, so memory use stays bounded during long training runs.
memory = deque(maxlen=memory_size)
model = QNetwork(state_size, action_size)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
```
5. Define the replay memory and epsilon-greedy exploration:
```python
def remember(state, action, reward, next_state, done):
    """Append a single transition to the shared replay ``memory``."""
    transition = (state, action, reward, next_state, done)
    memory.append(transition)
def choose_action(state):
    """Select an action epsilon-greedily.

    With probability ``epsilon`` (read from module scope) a random action is
    drawn from ``env.action_space``. Otherwise the action with the highest
    Q-value under ``model`` is returned as a Python int.
    """
    if np.random.rand() <= epsilon:
        return env.action_space.sample()
    # Pure inference: no_grad skips building an autograd graph that would
    # only be discarded.
    with torch.no_grad():
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        q_values = model(state_tensor)
    return torch.argmax(q_values).item()
```
6. Define the experience-replay update and the training loop:
```python
def replay_experience():
    """Sample a minibatch from ``memory`` and take one gradient step on ``model``.

    Does nothing until ``memory`` holds at least ``batch_size`` transitions.
    Each sampled transition is regressed (Huber loss) toward
    ``reward + gamma * max_a' Q(next_state, a') * (1 - done)``.
    """
    if len(memory) < batch_size:
        return
    indices = np.random.choice(len(memory), batch_size, replace=False)
    states, actions, rewards, next_states, dones = zip(*[memory[i] for i in indices])
    # Stack the per-step ndarrays into one array first: building a tensor
    # directly from a sequence of ndarrays is very slow.
    states = torch.as_tensor(np.array(states), dtype=torch.float32)
    next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
    actions = torch.tensor(actions, dtype=torch.long)
    rewards = torch.tensor(rewards, dtype=torch.float32)
    dones = torch.tensor(dones, dtype=torch.float32)

    # Q(s, a) for the actions that were actually taken.
    q_values = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # The bootstrap target must act as a constant. Without no_grad the loss
    # would also backpropagate through Q(s', .) and drag it toward Q(s, a).
    # NOTE(review): the original DQN computes this with a separate,
    # periodically synchronized target network; this version reuses `model`.
    with torch.no_grad():
        next_q_values = model(next_states).max(dim=1).values
        expected_q_values = rewards + gamma * next_q_values * (1 - dones)

    loss = F.smooth_l1_loss(q_values, expected_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
def train_agent(num_episodes):
    """Run ``num_episodes`` episodes of epsilon-greedy interaction and learning.

    After every environment step the transition is stored and one replay
    update is attempted. ``epsilon`` is decayed once per episode, and the
    total reward of every 10th episode is printed. The environment is closed
    when training finishes.

    Supports both the classic gym API (``reset`` returns the observation,
    ``step`` returns a 4-tuple) and the gym>=0.26 API (``reset`` returns
    ``(observation, info)``, ``step`` returns a 5-tuple).
    """
    global epsilon
    for episode in range(num_episodes):
        state = env.reset()
        if isinstance(state, tuple):  # gym>=0.26: (observation, info)
            state = state[0]
        total_reward = 0
        while True:
            action = choose_action(state)
            step_result = env.step(action)
            if len(step_result) == 5:
                # gym>=0.26 separates true termination from time-limit truncation.
                next_state, reward, terminated, truncated, _ = step_result
                done = terminated or truncated
            else:
                next_state, reward, done, _ = step_result
                terminated = done
            # Only a true terminal state should zero the bootstrap term in
            # replay_experience; a truncated episode still has future value.
            remember(state, action, reward, next_state, terminated)
            total_reward += reward
            state = next_state
            replay_experience()
            if done:
                break
        epsilon = max(epsilon_min, epsilon_decay * epsilon)
        if (episode + 1) % 10 == 0:
            print(f"Episode: {episode + 1}, Reward: {total_reward}")
    env.close()
```
7. Train the agent:
```python
# Train the agent for a fixed budget of episodes.
num_episodes = 100
train_agent(num_episodes=num_episodes)
```
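This is a basic implementation of DQN for the CartPole-v0 environment. Unlike the original DQN algorithm, it computes the bootstrap targets with the same network it is training, not with a separate, periodically synchronized target network; adding one is a common first step toward more stable training. You can further optimize and customize the code for your specific requirements, or try it with other environments.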
阅读全文