Code for solving a maze with Q-learning
Time: 2023-08-01 10:14:28
Below is a Python example that solves a maze with the Q-learning algorithm:
```python
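import random

import numpy as np

# Maze layout: 0 = free cell, 1 = wall.
# Start is the top-left cell (0, 0); goal is the bottom-right cell (7, 7).
maze = np.array([
    [0, 0, 0, 0, 0, 0, 0, 1],
    [0, 1, 1, 0, 1, 1, 0, 1],
    [0, 0, 0, 1, 0, 0, 0, 1],
    [0, 1, 0, 0, 1, 1, 0, 0],
    [0, 1, 1, 1, 0, 0, 1, 1],
    [0, 1, 0, 0, 0, 1, 0, 1],
    [0, 0, 0, 1, 0, 0, 0, 1],
    [0, 1, 1, 1, 0, 1, 0, 0]
])

# One state per cell (index = row * n_cols + col); 4 actions: up, down, left, right
num_states = maze.shape[0] * maze.shape[1]
num_actions = 4
start_state = 0
goal_state = num_states - 1

# Q table, initialized to 0
q_table = np.zeros((num_states, num_actions))

# Hyperparameters
lr = 0.8            # learning rate (alpha)
gamma = 0.95        # discount factor
epsilon = 0.1       # exploration probability for epsilon-greedy
num_episodes = 1000
max_steps = 200     # step cap per training episode


def get_next_state(current_state, action):
    """Return the state reached by `action`; walls and edges leave the agent in place."""
    row, col = divmod(current_state, maze.shape[1])
    next_row, next_col = row, col
    if action == 0:    # up
        next_row -= 1
    elif action == 1:  # down
        next_row += 1
    elif action == 2:  # left
        next_col -= 1
    elif action == 3:  # right
        next_col += 1
    if (next_row < 0 or next_row >= maze.shape[0]
            or next_col < 0 or next_col >= maze.shape[1]
            or maze[next_row][next_col] == 1):
        # Out of bounds or a wall: stay where we are
        return current_state
    return next_row * maze.shape[1] + next_col


def greedy_action(state):
    """Return an action with the highest Q value, breaking ties at random.

    np.argmax returns the first maximum, so on an all-zero table it would always
    pick action 0 ("up") and the agent would hardly explore.
    """
    q = q_table[state]
    best = np.flatnonzero(q == q.max())
    return int(random.choice(best))


# Train the Q table
for episode in range(num_episodes):
    current_state = start_state  # reset to the start cell
    for step in range(max_steps):
        # Epsilon-greedy action selection
        if random.random() < epsilon:
            action = random.randrange(num_actions)  # explore
        else:
            action = greedy_action(current_state)   # exploit
        # Take the action; the only reward is 1 for reaching the goal
        next_state = get_next_state(current_state, action)
        done = next_state == goal_state
        reward = 1.0 if done else 0.0
        # Q-learning update; the goal is terminal, so it has no bootstrap term
        target = reward if done else reward + gamma * np.max(q_table[next_state])
        q_table[current_state, action] += lr * (target - q_table[current_state, action])
        if done:
            break
        current_state = next_state

# Follow the learned greedy policy from the start and record the path.
# The step cap prevents an infinite loop if the policy has not converged.
current_state = start_state
path = [current_state]
for _ in range(num_states):
    action = int(np.argmax(q_table[current_state]))
    current_state = get_next_state(current_state, action)
    path.append(current_state)
    if current_state == goal_state:
        break

if path[-1] == goal_state:
    # Print the path as (row, col) cells
    print("Path:", [divmod(s, maze.shape[1]) for s in path])
    print("Steps:", len(path) - 1)
else:
    print("Goal not reached; try more episodes or a higher epsilon.")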
```
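Because the aim is the shortest route, it helps to know what the shortest route actually is. A minimal sketch, assuming the maze stays exactly as defined above: breadth-first search over the same grid returns a shortest path on an unweighted grid, so its length is the number the Q-learning path should match. The function name `bfs_shortest_path` is made up for this sketch.

```python
from collections import deque

import numpy as np

# Same layout as the maze above: 0 = free cell, 1 = wall
maze = np.array([
    [0, 0, 0, 0, 0, 0, 0, 1],
    [0, 1, 1, 0, 1, 1, 0, 1],
    [0, 0, 0, 1, 0, 0, 0, 1],
    [0, 1, 0, 0, 1, 1, 0, 0],
    [0, 1, 1, 1, 0, 0, 1, 1],
    [0, 1, 0, 0, 0, 1, 0, 1],
    [0, 0, 0, 1, 0, 0, 0, 1],
    [0, 1, 1, 1, 0, 1, 0, 0]
])


def bfs_shortest_path(grid, start, goal):
    """Return a shortest list of (row, col) cells from start to goal, or None if unreachable."""
    n_rows, n_cols = grid.shape
    parent = {start: None}          # also serves as the visited set
    queue = deque([start])
    while queue:
        cell = queue.popleft()
        if cell == goal:
            # Walk the parent links back to the start, then reverse
            path = []
            while cell is not None:
                path.append(cell)
                cell = parent[cell]
            return path[::-1]
        row, col = cell
        for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):  # up, down, left, right
            nxt = (row + d_row, col + d_col)
            if (0 <= nxt[0] < n_rows and 0 <= nxt[1] < n_cols
                    and grid[nxt] == 0 and nxt not in parent):
                parent[nxt] = cell
                queue.append(nxt)
    return None


shortest = bfs_shortest_path(maze, (0, 0), (7, 7))
print("BFS shortest path length:", len(shortest) - 1)  # prints: BFS shortest path length: 16
```

If the Q-learning run has converged, its step count should equal this BFS length; a longer route suggests the Q table needs more training.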
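The code trains a Q table with Q-learning and then follows it greedily from the start cell (top-left) to the goal (bottom-right). Since the only reward is 1 for reaching the goal and `gamma` < 1 discounts it on every step, a fully converged greedy policy follows a shortest path, and running the code prints that path. If training has not converged, the greedy walk can stall or loop, so more episodes or more exploration may be needed.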
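The learning step in the training loop is the standard Q-learning update, with `lr` as the step size $\alpha$ and `gamma` as the discount $\gamma$; the goal is treated as terminal, so its Q values stay 0:

$$
Q(s,a) \leftarrow Q(s,a) + \alpha\left[r + \gamma \max_{a'} Q(s',a') - Q(s,a)\right]
$$

Its fixed point explains the shortest-path claim. If action $a$ moves the agent from $s$ to a cell $s'$ whose shortest distance to the goal is $k$ steps ($k = 0$ when $s'$ is the goal), then by induction on $k$:

$$
Q^*(s,a) = \begin{cases} 1 = \gamma^{0}, & k = 0,\\ 0 + \gamma \max_{a'} Q^*(s',a') = \gamma \cdot \gamma^{k-1} = \gamma^{k}, & k \ge 1. \end{cases}
$$

Because $\gamma < 1$, fewer remaining steps means a larger value, and bumping into a wall gives $\gamma \max_{a'} Q^*(s,a') < \max_{a'} Q^*(s,a')$, so the greedy action always shortens the remaining distance by one. In this maze the shortest route has 16 steps, so the best first move leaves $k = 15$ and $\max_a Q^*(\text{start},a) = 0.95^{15} \approx 0.463$.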