Path: blob/main/ch19/gridworld/qlearning.py
1247 views
# coding: utf-8

# Python Machine Learning, PyTorch Edition by Sebastian Raschka (https://sebastianraschka.com), Yuxi (Hayden) Liu
# (https://www.mlexample.com/) & Vahid Mirjalili (http://vahidmirjalili.com), Packt Publishing Ltd. 2021
#
# Code Repository:
#
# Code License: MIT License (https://github.com/ /LICENSE.txt)

#################################################################################
# Chapter 19 - Reinforcement Learning for Decision Making in Complex Environments
#################################################################################

# Script: qlearning.py

from gridworld_env import GridWorldEnv
from agent import Agent
from collections import namedtuple
import matplotlib.pyplot as plt
import numpy as np

# Fix NumPy's global RNG so runs are reproducible.
np.random.seed(1)

# One step of experience handed to the agent's learning update.
Transition = namedtuple(
    'Transition', ('state', 'action', 'reward', 'next_state', 'done'))


def run_qlearning(agent, env, num_episodes=50, render=True):
    """Run ``num_episodes`` episodes of Q-learning and record per-episode stats.

    Each episode resets ``env`` and then repeatedly asks ``agent`` for an
    action, steps the environment, and passes the resulting ``Transition``
    to ``agent._learn``, until the environment reports ``done``. A summary
    line is printed after every episode.

    Args:
        agent: Object exposing ``choose_action(state)`` and
            ``_learn(transition)``.
        env: Environment exposing ``reset()``, ``step(action)`` returning a
            4-tuple ``(next_state, reward, done, info)``, and
            ``render(mode=..., done=...)``.
        num_episodes: Number of episodes to run.
        render: When False, every ``env.render`` call is skipped (useful for
            headless runs). Defaults to True, i.e. the original behavior.

    Returns:
        A list with one ``(n_moves, final_reward)`` tuple per episode.
    """
    history = []
    for episode in range(num_episodes):
        state = env.reset()
        if render:
            env.render(mode='human')
        final_reward, n_moves = 0.0, 0
        while True:
            action = agent.choose_action(state)
            next_s, reward, done, _ = env.step(action)
            agent._learn(Transition(state, action, reward,
                                    next_s, done))
            if render:
                env.render(mode='human', done=done)
            state = next_s
            n_moves += 1
            if done:
                break
            # Only non-terminal steps update final_reward: the reward of the
            # step that reports done=True is intentionally not recorded.
            # NOTE(review): presumably the env pays its outcome reward on the
            # step that enters a terminal cell and emits done one step later;
            # confirm against GridWorldEnv before moving this line.
            final_reward = reward
        history.append((n_moves, final_reward))
        print(f'Episode {episode}: Reward {final_reward:.2} #Moves {n_moves}')

    return history


def plot_learning_history(history, filename='q-learning-history.png'):
    """Plot moves and final reward per episode, save the figure, and show it.

    Args:
        history: Sequence of ``(n_moves, final_reward)`` pairs, e.g. the
            list returned by ``run_qlearning``.
        filename: Path the figure is written to (at 300 dpi). Defaults to
            the previously hard-coded ``'q-learning-history.png'``.
    """
    # A fresh figure each call; the old plt.figure(1) reused figure #1 and
    # stacked new axes on top of any earlier plot.
    fig, (ax_moves, ax_rewards) = plt.subplots(2, 1, figsize=(14, 10))
    episodes = np.arange(len(history))

    moves = np.array([h[0] for h in history])
    ax_moves.plot(episodes, moves, lw=4,
                  marker="o", markersize=10)
    ax_moves.tick_params(axis='both', which='major', labelsize=15)
    ax_moves.set_xlabel('Episodes', size=20)
    ax_moves.set_ylabel('# moves', size=20)

    rewards = np.array([h[1] for h in history])
    ax_rewards.step(episodes, rewards, lw=4)
    ax_rewards.tick_params(axis='both', which='major', labelsize=15)
    ax_rewards.set_xlabel('Episodes', size=20)
    ax_rewards.set_ylabel('Final rewards', size=20)

    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig(filename, dpi=300)
    plt.show()


if __name__ == '__main__':
    env = GridWorldEnv(num_rows=5, num_cols=6)
    agent = Agent(env)
    try:
        history = run_qlearning(agent, env)
    finally:
        # Release the environment (and its render window) even if training
        # is interrupted or raises.
        env.close()

    plot_learning_history(history)