DQN algorithm Python code
Below is a Python implementation of the DQN algorithm:
```python
import random
import numpy as np
import tensorflow as tf
from collections import deque


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)   # replay memory
        self.gamma = 0.95                  # discount factor
        self.epsilon = 1.0                 # initial exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.model = self._build_model()

    def _build_model(self):
        # simple fully connected Q-network: state -> one Q-value per action
        model = tf.keras.models.Sequential()
        model.add(tf.keras.layers.Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(tf.keras.layers.Dense(24, activation='relu'))
        model.add(tf.keras.layers.Dense(self.action_size, activation='linear'))
        model.compile(loss='mse',
                      optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))
        return model

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        # epsilon-greedy action selection
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(state, verbose=0)
        return np.argmax(act_values[0])

    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                # Bellman target: r + gamma * max_a' Q(s', a')
                target = reward + self.gamma * np.amax(self.model.predict(next_state, verbose=0)[0])
            target_f = self.model.predict(state, verbose=0)
            target_f[0][action] = target
            self.model.fit(state, target_f, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)
```
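One practical note on `replay()`: calling `model.predict` and `model.fit` once per transition is correct but slow. A common variant is to assemble the whole minibatch into arrays and make a single predict/fit call per batch. The sketch below is only an illustration of that idea (a hypothetical `replay_batched` method that would go inside the `DQNAgent` class body, assuming the same attributes and imports as above), not part of the original code:

```python
def replay_batched(self, batch_size):
    """Illustrative vectorized variant of replay(): same update, one network call per batch."""
    minibatch = random.sample(self.memory, batch_size)
    states = np.vstack([s for s, _, _, _, _ in minibatch])        # shape (batch, state_size)
    next_states = np.vstack([ns for _, _, _, ns, _ in minibatch])
    actions = np.array([a for _, a, _, _, _ in minibatch])
    rewards = np.array([r for _, _, r, _, _ in minibatch])
    dones = np.array([d for _, _, _, _, d in minibatch], dtype=bool)

    q_current = self.model.predict(states, verbose=0)             # Q(s, .) for all actions
    q_next = self.model.predict(next_states, verbose=0)           # Q(s', .) for all actions
    targets = rewards + self.gamma * np.max(q_next, axis=1) * (~dones)
    q_current[np.arange(batch_size), actions] = targets           # only the taken action's target changes

    self.model.fit(states, q_current, epochs=1, verbose=0)
    if self.epsilon > self.epsilon_min:
        self.epsilon *= self.epsilon_decay
```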
This code implements a DQNAgent class with the following methods:
- `__init__(self, state_size, action_size)`: initializes a DQNAgent instance, where state_size is the dimensionality of the state space and action_size is the number of discrete actions.
- `_build_model(self)`: builds the Q-network, a small fully connected model with an input layer, two hidden layers, and a linear output layer that produces one Q-value per action.
- `remember(self, state, action, reward, next_state, done)`: stores the current state, action, reward, next state, and done flag in the replay memory.
- `act(self, state)`: chooses an action for the current state with an epsilon-greedy policy: with probability epsilon a random action is taken, otherwise the action with the highest predicted Q-value.
- `replay(self, batch_size)`: samples batch_size transitions from the replay memory, builds the Bellman target for each one, and fits the model against those targets with an MSE loss (a small worked example of the target follows this list); it also decays epsilon after each call.
- `load(self, name)`: loads model weights from a file.
- `save(self, name)`: saves the model weights to a file.
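To make the Bellman target used in `replay` concrete, here is a tiny worked example with made-up numbers (assuming `gamma = 0.95` as in the class above):

```python
import numpy as np

gamma = 0.95
reward = 1.0
next_q_values = np.array([0.2, 1.3])   # hypothetical Q(s', a) output of the network
# target used in replay() when the episode is not done: r + gamma * max_a' Q(s', a')
target = reward + gamma * np.max(next_q_values)   # 1.0 + 0.95 * 1.3 = 2.235
print(target)
```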
Example code for training with the DQNAgent class:
```python
import gym

env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = DQNAgent(state_size, action_size)

EPISODES = 1000
batch_size = 32   # training minibatch size; 32 is a typical choice

for e in range(EPISODES):
    state = env.reset()                     # classic Gym API: reset() returns the observation only
    state = np.reshape(state, [1, state_size])
    done = False
    score = 0
    while not done:
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        next_state = np.reshape(next_state, [1, state_size])
        agent.remember(state, action, reward, next_state, done)
        state = next_state
        score += reward
        if done:
            print("episode: {}/{}, score: {}".format(e, EPISODES, score))
            break
    # train on a random minibatch once the replay memory is large enough
    if len(agent.memory) > batch_size:
        agent.replay(batch_size)

agent.save("cartpole-dqn.h5")
```
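Note that this training loop assumes the classic Gym API (gym versions before 0.26), where `env.reset()` returns only the observation and `env.step()` returns four values. On newer gym or gymnasium releases the two environment calls change; a rough adaptation of just the rollout part would look like this (everything else stays the same):

```python
# Sketch for gym >= 0.26 / gymnasium: only the reset/step calls differ.
state, _ = env.reset()                      # reset() now returns (observation, info)
state = np.reshape(state, [1, state_size])
done = False
while not done:
    action = agent.act(state)
    next_state, reward, terminated, truncated, _ = env.step(action)   # step() returns five values
    done = terminated or truncated
    next_state = np.reshape(next_state, [1, state_size])
    agent.remember(state, action, reward, next_state, done)
    state = next_state
```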
In this example we train on the CartPole environment from OpenAI Gym, so the DQNAgent sees a 4-dimensional state space and 2 discrete actions. We run 1000 episodes and print the score when each episode ends. After each episode, once enough transitions have been collected, we sample batch_size random transitions from the replay memory and train on them. Finally, the trained weights are saved to cartpole-dqn.h5.
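As a follow-up, the saved weights can be reloaded with `load()` and the policy evaluated greedily. A minimal evaluation sketch, assuming the same `env`, `state_size`, and `action_size` as above, with `epsilon` set to 0 so `act()` picks the highest Q-value:

```python
# Reload the trained weights and run one greedy episode (classic Gym API, as in the training loop).
eval_agent = DQNAgent(state_size, action_size)
eval_agent.load("cartpole-dqn.h5")
eval_agent.epsilon = 0.0                    # disable exploration for evaluation

state = np.reshape(env.reset(), [1, state_size])
done = False
score = 0
while not done:
    action = eval_agent.act(state)
    state, reward, done, _ = env.step(action)
    state = np.reshape(state, [1, state_size])
    score += reward
print("evaluation score:", score)
```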