Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
rasbt
GitHub Repository: rasbt/machine-learning-book
Path: blob/main/ch19/gridworld/gridworld_env.py
1247 views
1
# coding: utf-8
2
3
# Python Machine Learning, PyTorch Edition by Sebastian Raschka (https://sebastianraschka.com), Yuxi (Hayden) Liu
4
# (https://www.mlexample.com/) & Vahid Mirjalili (http://vahidmirjalili.com), Packt Publishing Ltd. 2021
5
#
6
# Code Repository: https://github.com/rasbt/machine-learning-book
7
#
8
# Code License: MIT License (https://github.com/rasbt/machine-learning-book/blob/main/LICENSE.txt)
9
10
#################################################################################
11
# Chapter 19 - Reinforcement Learning for Decision Making in Complex Environments
12
#################################################################################
13
14
# Script: gridworld_env.py
15
16
import numpy as np
17
from gym.envs.toy_text import discrete
18
from collections import defaultdict
19
import time
20
import pickle
21
import os
22
23
from gym.envs.classic_control import rendering
24
25
# Pixel geometry shared by the rendering helpers in this module.
MARGIN = 10  # inset of the agent square / outset of the outer border
CELL_SIZE = 100  # side length of one grid cell
27
28
29
def get_coords(row, col, loc='center'):
    """Return pixel coordinates for grid cell ``(row, col)``.

    The grid is drawn offset by one cell from the screen origin, so the
    center of cell ``(row, col)`` lies at
    ``((col + 1.5) * CELL_SIZE, (row + 1.5) * CELL_SIZE)``.

    Args:
        row: Row index of the cell (mapped to the y axis).
        col: Column index of the cell (mapped to the x axis).
        loc: Which coordinates to return:
            ``'center'`` -> the ``(x, y)`` center of the cell;
            ``'interior_corners'`` -> the four corners of a square centred
            in the cell, inset by ``MARGIN`` from the cell edges;
            ``'interior_triangle'`` -> the three vertices of a triangle
            centred in the cell (apex at ``yc + CELL_SIZE // 3``).

    Returns:
        An ``(x, y)`` tuple for ``'center'``; otherwise a list of ``(x, y)``
        vertex tuples.

    Raises:
        ValueError: If ``loc`` is not one of the values listed above.
    """
    xc = (col + 1.5) * CELL_SIZE
    yc = (row + 1.5) * CELL_SIZE

    if loc == 'center':
        return xc, yc

    if loc == 'interior_corners':
        half_size = CELL_SIZE // 2 - MARGIN
        xl, xr = xc - half_size, xc + half_size
        # The vertical extent must be centred on the row coordinate yc;
        # using xc here would misplace the square whenever row != col.
        yt, yb = yc - half_size, yc + half_size
        return [(xl, yt), (xr, yt), (xr, yb), (xl, yb)]

    if loc == 'interior_triangle':
        offset = CELL_SIZE // 3
        x1, y1 = xc, yc + offset
        x2, y2 = xc + offset, yc - offset
        x3, y3 = xc - offset, yc - offset
        return [(x1, y1), (x2, y2), (x3, y3)]

    raise ValueError(
        f"Unknown loc {loc!r}; expected 'center', 'interior_corners' "
        f"or 'interior_triangle'")
44
45
46
def draw_object(coords_list):
    """Create a colored ``rendering`` geometry from a list of points.

    The shape is chosen by the number of points:

    * 1 point -> filled circle of radius ``int(0.45 * CELL_SIZE)``
      translated onto that point (dark gray);
    * 3 points -> filled triangle (orange/yellow);
    * more than 3 points -> filled polygon (blue).

    Args:
        coords_list: Sequence of ``(x, y)`` pixel coordinates.

    Returns:
        The geometry object, ready to be added to a viewer.

    Raises:
        ValueError: If ``coords_list`` holds 0 or 2 points, which describe
            none of the supported shapes.
    """
    n_points = len(coords_list)
    if n_points == 1:  # -> circle
        obj = rendering.make_circle(int(0.45 * CELL_SIZE))
        # make_circle builds the circle around the origin; move it onto
        # the requested point with a transform.
        obj_transform = rendering.Transform()
        obj.add_attr(obj_transform)
        obj_transform.set_translation(*coords_list[0])
        obj.set_color(0.2, 0.2, 0.2)  # -> dark gray
    elif n_points == 3:  # -> triangle
        obj = rendering.FilledPolygon(coords_list)
        obj.set_color(0.9, 0.6, 0.2)  # -> orange/yellow
    elif n_points > 3:  # -> polygon
        obj = rendering.FilledPolygon(coords_list)
        obj.set_color(0.4, 0.4, 0.8)  # -> blue
    else:
        # Previously fell through and crashed with UnboundLocalError.
        raise ValueError(
            f'draw_object expects 1 point (circle), 3 points (triangle) '
            f'or more than 3 points (polygon); got {n_points}')
    return obj
60
61
62
class GridWorldEnv(discrete.DiscreteEnv):
    """Deterministic grid world with one gold cell and three trap cells.

    The grid has ``num_rows * num_cols`` cells, numbered row-major as states
    ``0 .. num_rows * num_cols - 1``. The agent always starts in state 0,
    i.e. cell ``(0, 0)``. The gold cell is ``(num_rows // 2, num_cols - 2)``.
    The traps are the cells directly at row - 1, row + 1 and col - 1 of it.

    Actions: 0 = up (row - 1), 1 = right (col + 1), 2 = down (row + 1),
    3 = left (col - 1). A move that would leave the grid keeps the agent on
    the border.

    Rewards and termination:

    * Entering the gold cell yields +1.0 and ends the episode (``done=True``).
    * Entering a trap cell yields -1.0 and ends the episode.
    * Every other move yields 0.0.
    * Gold and trap states are absorbing: any action there returns the same
      state with reward 0.0 and ``done=True``.

    Constructing the environment prints the terminal state indices and
    opens a ``rendering.Viewer`` window.

    Args:
        num_rows: Number of grid rows (at least 3).
        num_cols: Number of grid columns (at least 3).
        delay: Seconds ``render`` sleeps after drawing a non-final frame.

    Raises:
        ValueError: If the grid is too small to hold the gold cell and all
            three traps.
    """

    def __init__(self, num_rows=4, num_cols=6, delay=0.05):
        # The traps sit above, below and left of the gold cell; a grid
        # smaller than 3x3 would put one of them off the grid.
        if num_rows < 3 or num_cols < 3:
            raise ValueError(
                f'GridWorldEnv needs at least 3 rows and 3 columns, '
                f'got num_rows={num_rows}, num_cols={num_cols}')

        self.num_rows = num_rows
        self.num_cols = num_cols
        self.delay = delay

        # Each action maps (row, col) to the neighbouring cell, clamped to
        # the grid so that walking into a wall leaves the agent in place.
        move_up = lambda row, col: (max(row - 1, 0), col)
        move_down = lambda row, col: (min(row + 1, num_rows - 1), col)
        move_left = lambda row, col: (row, max(col - 1, 0))
        move_right = lambda row, col: (row, min(col + 1, num_cols - 1))

        self.action_defs = {0: move_up, 1: move_right,
                            2: move_down, 3: move_left}

        # Number of states/actions; states are numbered row-major.
        nS = num_cols * num_rows
        nA = len(self.action_defs)
        self.grid2state_dict = {(s // num_cols, s % num_cols): s
                                for s in range(nS)}
        self.state2grid_dict = {s: (s // num_cols, s % num_cols)
                                for s in range(nS)}

        # Gold cell
        gold_cell = (num_rows // 2, num_cols - 2)

        # Trap cells: below, left of, and above the gold cell
        trap_cells = [(gold_cell[0] + 1, gold_cell[1]),
                      (gold_cell[0], gold_cell[1] - 1),
                      (gold_cell[0] - 1, gold_cell[1])]

        gold_state = self.grid2state_dict[gold_cell]
        trap_states = [self.grid2state_dict[(r, c)]
                       for (r, c) in trap_cells]
        # Gold first, then traps.
        self.terminal_states = [gold_state] + trap_states
        print(self.terminal_states)

        # Transition table in the DiscreteEnv format:
        # P[s][a] = [(probability, next_state, reward, done)]
        P = defaultdict(dict)
        for s in range(nS):
            row, col = self.state2grid_dict[s]
            P[s] = defaultdict(list)
            for a in range(nA):
                if self.is_terminal(s):
                    # Absorbing terminal state: stay put, no reward.
                    P[s][a] = [(1.0, s, 0.0, True)]
                    continue

                next_s = self.grid2state_dict[self.action_defs[a](row, col)]
                # The episode ends on the step that *enters* the gold/trap
                # cell, and that step carries the terminal reward.
                done = self.is_terminal(next_s)
                if not done:
                    r = 0.0
                elif next_s == gold_state:
                    r = 1.0
                else:
                    r = -1.0
                P[s][a] = [(1.0, next_s, r, done)]

        # Initial state distribution: always start in state 0.
        isd = np.zeros(nS)
        isd[0] = 1.0

        super().__init__(nS, nA, P, isd)

        self.viewer = None
        self._build_display(gold_cell, trap_cells)

    def is_terminal(self, state):
        """Return True if ``state`` is the gold state or a trap state."""
        return state in self.terminal_states

    def _build_display(self, gold_cell, trap_cells):
        """Open the viewer and add the board, traps, gold and agent shapes.

        The agent's ``rendering.Transform`` is stored in ``self.agent_trans``
        so that ``render`` can move the agent between cells.
        """
        screen_width = (self.num_cols + 2) * CELL_SIZE
        screen_height = (self.num_rows + 2) * CELL_SIZE
        self.viewer = rendering.Viewer(screen_width, screen_height)

        all_objects = []

        # Outer border, drawn MARGIN pixels outside the grid on every side.
        bp_list = [
            (CELL_SIZE - MARGIN, CELL_SIZE - MARGIN),
            (screen_width - CELL_SIZE + MARGIN, CELL_SIZE - MARGIN),
            (screen_width - CELL_SIZE + MARGIN,
             screen_height - CELL_SIZE + MARGIN),
            (CELL_SIZE - MARGIN, screen_height - CELL_SIZE + MARGIN)
        ]
        border = rendering.PolyLine(bp_list, True)
        border.set_linewidth(5)
        all_objects.append(border)

        # Vertical grid lines
        for col in range(self.num_cols + 1):
            x = (col + 1) * CELL_SIZE
            line = rendering.PolyLine(
                [(x, CELL_SIZE), (x, (self.num_rows + 1) * CELL_SIZE)],
                False)
            all_objects.append(line)

        # Horizontal grid lines
        for row in range(self.num_rows + 1):
            y = (row + 1) * CELL_SIZE
            line = rendering.PolyLine(
                [(CELL_SIZE, y), ((self.num_cols + 1) * CELL_SIZE, y)],
                False)
            all_objects.append(line)

        # Traps: --> circles
        for cell in trap_cells:
            trap_coords = get_coords(*cell, loc='center')
            all_objects.append(draw_object([trap_coords]))

        # Gold: --> triangle
        gold_coords = get_coords(*gold_cell, loc='interior_triangle')
        all_objects.append(draw_object(gold_coords))

        # Agent --> robot outline if the coordinate file is present,
        # otherwise a square. The stored points are shifted by the center of
        # cell (0, 0), so they are presumably offsets around a cell center
        # designed for CELL_SIZE == 100.
        robot_file = 'robot-coordinates.pkl'
        if os.path.exists(robot_file) and CELL_SIZE == 100:
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # a trusted robot-coordinates.pkl from the working directory.
            with open(robot_file, 'rb') as f:
                robot_offsets = pickle.load(f)
            starting_coords = get_coords(0, 0, loc='center')
            agent_coords = (np.asarray(robot_offsets, dtype=float)
                            + np.asarray(starting_coords))
        else:
            agent_coords = get_coords(0, 0, loc='interior_corners')
        agent = draw_object(agent_coords)
        self.agent_trans = rendering.Transform()
        agent.add_attr(self.agent_trans)
        all_objects.append(agent)

        for obj in all_objects:
            self.viewer.add_geom(obj)

    def render(self, mode='human', done=False):
        """Draw the current state, then sleep.

        Args:
            mode: ``'rgb_array'`` asks the viewer to return the frame as an
                RGB array; any other value just draws to the window.
            done: If True, sleep 1 second (end of episode) instead of
                ``self.delay``.

        Returns:
            Whatever ``self.viewer.render`` returns for the chosen mode.
        """
        sleep_time = 1 if done else self.delay
        # The agent shape was built at cell (0, 0); shift it by whole cells
        # to the current state (self.s, presumably maintained by
        # DiscreteEnv.reset/step).
        row, col = divmod(self.s, self.num_cols)
        self.agent_trans.set_translation(col * CELL_SIZE, row * CELL_SIZE)
        rend = self.viewer.render(return_rgb_array=(mode == 'rgb_array'))
        time.sleep(sleep_time)
        return rend

    def close(self):
        """Close the viewer window, if one is open."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None
215
216
217
if __name__ == '__main__':
    # Demo: play one episode with uniformly random actions and render it.
    env = GridWorldEnv(5, 6)
    try:
        for _ in range(1):
            env.reset()
            env.render(mode='human', done=False)

            while True:
                action = np.random.choice(env.nA)
                res = env.step(action)  # (next_state, reward, done, info)
                print('Action ', env.s, action, ' -> ', res)
                env.render(mode='human', done=res[2])
                if res[2]:
                    break
    finally:
        # Release the viewer window even if stepping/rendering raises
        # (e.g. Ctrl-C while the demo is running).
        env.close()
232
233