Write the SARSA algorithm in Python
Time: 2023-03-13 11:16:33 Views: 143
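Here is a Python implementation of the SARSA algorithm. It assumes an environment with the classic Gym-style interface: env.nA is the number of actions, env.reset() returns the initial state, env.step(action) returns (next_state, reward, done, info), and states are hashable. Actions are chosen epsilon-greedily (eps is the exploration rate), and each update bootstraps from the action actually chosen in the next state. That is what makes this SARSA; bootstrapping from the maximum over next actions would be Q-learning.

    import sys
    from collections import defaultdict

    import numpy as np


    def epsilon_greedy(Q, state, nA, eps):
        # Explore with probability eps, otherwise act greedily on Q
        if np.random.random() < eps:
            return np.random.randint(nA)
        return int(np.argmax(Q[state]))


    def sarsa(env, num_episodes, alpha, gamma=1.0, eps=0.1, plot_every=100):
        # Initialize action-value function (dictionary of arrays, zeros by default)
        Q = defaultdict(lambda: np.zeros(env.nA))
        # Initialize performance monitor: total reward collected in each episode
        episode_rewards = np.zeros(num_episodes)
        # Loop over episodes
        for i_episode in range(1, num_episodes + 1):
            # Reset the environment and pick the first action
            state = env.reset()
            action = epsilon_greedy(Q, state, env.nA, eps)
            while True:
                # Take a step
                next_state, reward, done, info = env.step(action)
                # Add the new reward to the performance monitor
                episode_rewards[i_episode - 1] += reward
                if done:
                    # Terminal transition: nothing to bootstrap from
                    Q[state][action] += alpha * (reward - Q[state][action])
                    break
                # Pick the next action with the same policy (on-policy)
                next_action = epsilon_greedy(Q, next_state, env.nA, eps)
                # Sarsa update
                target = reward + gamma * Q[next_state][next_action]
                Q[state][action] += alpha * (target - Q[state][action])
                state, action = next_state, next_action
            # Monitor progress: mean reward over the last plot_every episodes
            if i_episode % plot_every == 0:
                avg = episode_rewards[i_episode - plot_every:i_episode].mean()
                print("\rEpisode {}/{} avg reward {:.2f}".format(
                    i_episode, num_episodes, avg), end="")
                sys.stdout.flush()
        # The policy is improved implicitly by changing the Q dictionary
        return Q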
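The SARSA target and the Q-learning target are easy to mix up, so it may help to work through a single update step in isolation. The following is only a minimal sketch: the helper name td_update and all the numbers are illustrative assumptions, not part of any library or of a real training run.

```python
def td_update(q_sa, q_next, reward, alpha, gamma):
    """Return the updated estimate of Q(s, a) after one TD step.

    q_sa   : current estimate Q(s, a)
    q_next : value used to bootstrap (pass 0.0 if the next state is terminal)
    """
    target = reward + gamma * q_next
    return q_sa + alpha * (target - q_sa)


# Illustrative numbers (assumed for this example)
q_next_state = [0.5, 2.0]  # Q(s', a') for the two actions available in s'
next_action = 0            # action the epsilon-greedy policy actually picked in s'

# SARSA (on-policy): bootstrap from the action actually taken next
sarsa_value = td_update(1.0, q_next_state[next_action],
                        reward=1.0, alpha=0.5, gamma=0.9)

# Q-learning (off-policy): bootstrap from the greedy action instead
q_learning_value = td_update(1.0, max(q_next_state),
                             reward=1.0, alpha=0.5, gamma=0.9)

print(sarsa_value, q_learning_value)
```

With these numbers the SARSA target is 1.0 + 0.9 * 0.5 = 1.45 and the Q-learning target is 1.0 + 0.9 * 2.0 = 2.8, so the two updates move Q(s, a) to about 1.225 and 1.9 respectively. The only difference is which next-state value is passed in as q_next.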