Deep Reinforcement Learning Code in Python
Below is a simple deep reinforcement learning code example written in Python, using Keras and OpenAI Gym:
```python
import gym
import random
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

# Note: this targets the classic Gym API (gym < 0.26), where
# step() returns (observation, reward, done, info).
env = gym.make('CartPole-v1')
env.reset()

goal_steps = 500          # maximum steps per episode
score_requirement = 60    # keep only games scoring at least this much
initial_games = 10000     # number of random games to collect data from


def model_data_preparation():
    """Play random games and keep (observation, action) pairs
    from games whose total score meets the requirement."""
    training_data = []
    accepted_scores = []
    for game_index in range(initial_games):
        score = 0
        game_memory = []
        previous_observation = []
        for step_index in range(goal_steps):
            action = random.randrange(0, 2)
            observation, reward, done, info = env.step(action)
            if len(previous_observation) > 0:
                game_memory.append([previous_observation, action])
            previous_observation = observation
            score += reward
            if done:
                break
        if score >= score_requirement:
            accepted_scores.append(score)
            for data in game_memory:
                # One-hot encode the action: 1 -> [0, 1], 0 -> [1, 0]
                if data[1] == 1:
                    output = [0, 1]
                elif data[1] == 0:
                    output = [1, 0]
                training_data.append([data[0], output])
        env.reset()
    print(accepted_scores)
    return training_data


def build_model(input_size, output_size):
    """Build a small fully connected network mapping observations to actions."""
    model = Sequential()
    model.add(Dense(128, input_dim=input_size, activation='relu'))
    model.add(Dense(52, activation='relu'))
    model.add(Dense(output_size, activation='linear'))
    model.compile(loss='mse', optimizer=Adam())
    return model


def train_model(training_data):
    """Train the network on the collected (observation, action) pairs."""
    X = np.array([i[0] for i in training_data]).reshape(-1, len(training_data[0][0]))
    y = np.array([i[1] for i in training_data]).reshape(-1, len(training_data[0][1]))
    model = build_model(input_size=len(X[0]), output_size=len(y[0]))
    model.fit(X, y, epochs=10)
    return model


training_data = model_data_preparation()
trained_model = train_model(training_data)

# Evaluate the trained model over 100 games.
scores = []
choices = []
for each_game in range(100):
    score = 0
    game_memory = []
    prev_obs = []
    env.reset()
    for _ in range(goal_steps):
        env.render()
        if len(prev_obs) == 0:
            # No observation yet on the first step, so act randomly.
            action = random.randrange(0, 2)
        else:
            action = np.argmax(trained_model.predict(prev_obs.reshape(-1, len(prev_obs)))[0])
        choices.append(action)
        new_observation, reward, done, info = env.step(action)
        prev_obs = new_observation
        game_memory.append([new_observation, action])
        score += reward
        if done:
            break
    scores.append(score)

print(scores)
print('Average Score:', sum(scores) / len(scores))
print('Choice 1:{} Choice 0:{}'.format(choices.count(1) / len(choices),
                                       choices.count(0) / len(choices)))
```
This code uses the CartPole-v1 environment, where the goal is to keep the pole from falling. Rather than classic value-based reinforcement learning, it performs a simple form of behavior cloning: it collects random games, keeps the (observation, action) pairs from games above a score threshold, and trains a neural network on them to predict the next action. Finally, the code prints the model's average score over 100 games and the frequency with which each action was chosen.
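Note that the code above targets the classic Gym API (gym < 0.26), where `step()` returns four values. On newer gym releases, or with the gymnasium package (an assumption here; any gym >= 0.26 behaves the same), `reset()` and `step()` have different signatures. A minimal sketch of the required changes to the data-collection loop:

```python
# Minimal sketch of the random-play loop under the newer API
# (gymnasium, or gym >= 0.26); only the reset/step calls change.
import random
import gymnasium as gym  # assumption: gymnasium is installed

# Rendering is now requested at creation time, e.g.
# gym.make('CartPole-v1', render_mode='human')
env = gym.make('CartPole-v1')

observation, info = env.reset()  # reset() now returns (obs, info)
done = False
while not done:
    action = random.randrange(0, 2)
    observation, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated  # 'done' is split into two flags
env.close()
```

The same substitutions apply in the evaluation loop; everything else (data preparation, model, training) is unaffected by the API change.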