Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
rasbt
GitHub Repository: rasbt/machine-learning-book
Path: blob/main/ch19/gridworld/qlearning.py
1247 views
1
# coding: utf-8
2
3
# Python Machine Learning, PyTorch Edition by Sebastian Raschka (https://sebastianraschka.com), Yuxi (Hayden) Liu
4
# (https://www.mlexample.com/) & Vahid Mirjalili (http://vahidmirjalili.com), Packt Publishing Ltd. 2021
5
#
6
# Code Repository: https://github.com/rasbt/machine-learning-book
7
#
8
# Code License: MIT License (https://github.com/rasbt/machine-learning-book/blob/main/LICENSE.txt)
9
10
#################################################################################
11
# Chapter 19 - Reinforcement Learning for Decision Making in Complex Environments
12
#################################################################################
13
14
# Script: qlearning.py
15
16
from gridworld_env import GridWorldEnv
17
from agent import Agent
18
from collections import namedtuple
19
import matplotlib.pyplot as plt
20
import numpy as np
21
22
# Seed NumPy's global random generator once at import time so that any
# NumPy-based randomness in this run (presumably the agent's exploration —
# confirm in agent.py) is reproducible across runs.
np.random.seed(1)

# One environment step, bundled as an immutable record and passed to
# ``agent._learn`` by the training loop.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'reward', 'next_state', 'done'],
)
26
27
28
def run_qlearning(agent, env, num_episodes=50, *, render=True):
    """Train ``agent`` on ``env`` with Q-learning and record per-episode stats.

    Each episode resets the environment and then loops until the environment
    reports ``done``. On each step it asks the agent for an action, steps the
    environment, and passes the resulting ``Transition`` to ``agent._learn``.

    Args:
        agent: Object providing ``choose_action(state)`` and
            ``_learn(transition)``.
        env: Environment providing ``reset()``, ``step(action)`` returning a
            4-tuple ``(next_state, reward, done, info)``, and
            ``render(mode=..., done=...)``.
        num_episodes: Number of training episodes to run.
        render: If False, all ``env.render`` calls are skipped (e.g. for
            faster or headless training). Defaults to True, the original
            behavior.

    Returns:
        A list with one ``(n_moves, final_reward)`` tuple per episode.
        ``n_moves`` is the number of steps taken, and ``final_reward`` is the
        reward returned by the step that ended the episode.
    """
    history = []
    for episode in range(num_episodes):
        state = env.reset()
        if render:
            env.render(mode='human')
        final_reward, n_moves = 0.0, 0
        while True:
            action = agent.choose_action(state)
            next_s, reward, done, _ = env.step(action)
            agent._learn(Transition(state, action, reward,
                                    next_s, done))
            if render:
                env.render(mode='human', done=done)
            state = next_s
            n_moves += 1
            # Record the reward *before* testing for termination. Updating
            # it after the break check would skip the terminal step and keep
            # the second-to-last step's reward instead of the one that
            # actually ended the episode.
            final_reward = reward
            if done:
                break
        history.append((n_moves, final_reward))
        print(f'Episode {episode}: Reward {final_reward:.2} #Moves {n_moves}')

    return history
49
50
51
def plot_learning_history(history):
    """Plot moves and final reward per episode, then save and display the figure.

    Args:
        history: Sequence of per-episode records. Element ``[0]`` of each
            record is plotted as the move count and element ``[1]`` as the
            final reward (``run_qlearning`` returns
            ``(n_moves, final_reward)`` tuples).

    Side effects:
        Draws into matplotlib figure number 1, writes
        ``q-learning-history.png`` at 300 dpi to the current working
        directory, and calls ``plt.show()``.
    """
    fig = plt.figure(1, figsize=(14, 10))
    episodes = np.arange(len(history))

    # Upper panel: how many moves each episode took.
    ax_moves = fig.add_subplot(2, 1, 1)
    move_counts = np.array([entry[0] for entry in history])
    ax_moves.plot(episodes, move_counts, lw=4,
                  marker="o", markersize=10)
    ax_moves.tick_params(axis='both', which='major', labelsize=15)
    ax_moves.set_xlabel('Episodes', size=20)
    ax_moves.set_ylabel('# moves', size=20)

    # Lower panel: the final reward of each episode, drawn as a step plot.
    ax_rewards = fig.add_subplot(2, 1, 2)
    final_rewards = np.array([entry[1] for entry in history])
    ax_rewards.step(episodes, final_rewards, lw=4)
    ax_rewards.tick_params(axis='both', which='major', labelsize=15)
    ax_rewards.set_xlabel('Episodes', size=20)
    ax_rewards.set_ylabel('Final rewards', size=20)

    plt.savefig('q-learning-history.png', dpi=300)
    plt.show()
70
71
72
if __name__ == '__main__':
    # Train on a 5x6 grid world, then plot the learning curve.
    env = GridWorldEnv(num_rows=5, num_cols=6)
    agent = Agent(env)
    try:
        history = run_qlearning(agent, env)
    finally:
        # Run env.close() on every exit path, including when training raises
        # or is interrupted (e.g. Ctrl+C), not only on normal completion.
        env.close()

    plot_learning_history(history)
79
80