# Path: blob/main/ch19/gridworld/gridworld_env.py
# 1247 views
# coding: utf-8

# Python Machine Learning, PyTorch Edition by Sebastian Raschka (https://sebastianraschka.com), Yuxi (Hayden) Liu
# (https://www.mlexample.com/) & Vahid Mirjalili (http://vahidmirjalili.com), Packt Publishing Ltd. 2021
#
# Code Repository:
#
# Code License: MIT License (https://github.com/ LICENSE.txt)

#################################################################################
# Chapter 19 - Reinforcement Learning for Decision Making in Complex Environments
#################################################################################

# Script: gridworld_env.py

import numpy as np
from gym.envs.toy_text import discrete
from collections import defaultdict
import time
import pickle
import os

from gym.envs.classic_control import rendering

# Side length of one grid cell and the inner padding used when drawing
# shapes inside a cell (both in viewer units).
CELL_SIZE = 100
MARGIN = 10


def get_coords(row, col, loc='center'):
    """Return drawing coordinates for the grid cell at (row, col).

    The grid is drawn with a one-cell border, so cell (0, 0) is centred at
    (1.5 * CELL_SIZE, 1.5 * CELL_SIZE).

    Parameters
    ----------
    row, col : int
        Zero-based cell indices.
    loc : str
        'center'            -> (x, y) tuple of the cell centre.
        'interior_corners'  -> list of 4 (x, y) corners of a square inset
                               by MARGIN from the cell edges.
        'interior_triangle' -> list of 3 (x, y) vertices of a triangle
                               centred in the cell.

    Raises
    ------
    ValueError
        If `loc` is not one of the three supported values.
    """
    xc = (col + 1.5) * CELL_SIZE
    yc = (row + 1.5) * CELL_SIZE
    if loc == 'center':
        return xc, yc
    elif loc == 'interior_corners':
        half_size = CELL_SIZE//2 - MARGIN
        xl, xr = xc - half_size, xc + half_size
        # The vertical extent must be derived from yc; using xc here only
        # happens to work for cells on the diagonal (row == col).
        yt, yb = yc - half_size, yc + half_size
        return [(xl, yt), (xr, yt), (xr, yb), (xl, yb)]
    elif loc == 'interior_triangle':
        x1, y1 = xc, yc + CELL_SIZE//3
        x2, y2 = xc + CELL_SIZE//3, yc - CELL_SIZE//3
        x3, y3 = xc - CELL_SIZE//3, yc - CELL_SIZE//3
        return [(x1, y1), (x2, y2), (x3, y3)]
    raise ValueError(
        "loc must be 'center', 'interior_corners' or "
        "'interior_triangle', got {!r}".format(loc))


def draw_object(coords_list):
    """Build a rendering geometry whose shape depends on len(coords_list).

    1 point   -> filled circle centred on that point,
    3 points  -> filled triangle,
    >3 points -> filled polygon.

    Raises
    ------
    ValueError
        If `coords_list` has 0 or 2 points (no shape is defined for them).
    """
    num_points = len(coords_list)
    if num_points == 1:  # -> circle
        obj = rendering.make_circle(int(0.45*CELL_SIZE))
        obj_transform = rendering.Transform()
        obj.add_attr(obj_transform)
        obj_transform.set_translation(*coords_list[0])
        obj.set_color(0.2, 0.2, 0.2)  # -> black
    elif num_points == 3:  # -> triangle
        obj = rendering.FilledPolygon(coords_list)
        obj.set_color(0.9, 0.6, 0.2)  # -> yellow
    elif num_points > 3:  # -> polygon
        obj = rendering.FilledPolygon(coords_list)
        obj.set_color(0.4, 0.4, 0.8)  # -> blue
    else:
        raise ValueError(
            'coords_list must contain 1, 3 or more than 3 points, '
            'got {}'.format(num_points))
    return obj


class GridWorldEnv(discrete.DiscreteEnv):
    """Deterministic grid world with one gold cell and three trap cells.

    States are numbered row-major (state = row * num_cols + col) and the
    agent always starts in state 0. Actions: 0 = up, 1 = right, 2 = down,
    3 = left; moves that would leave the grid keep the agent in place.
    """

    def __init__(self, num_rows=4, num_cols=6, delay=0.05):
        """Build the transition table and the display.

        Parameters
        ----------
        num_rows, num_cols : int
            Grid dimensions. The gold cell is (num_rows // 2, num_cols - 2)
            and the traps sit below, left of and above it, so the grid must
            be large enough for those cells to exist (otherwise the state
            lookup raises KeyError).
        delay : float
            Seconds to sleep after each non-terminal render() call.
        """
        self.num_rows = num_rows
        self.num_cols = num_cols

        self.delay = delay

        # Each action maps a (row, col) cell to the clamped neighbour cell.
        move_up = lambda row, col: (max(row - 1, 0), col)
        move_down = lambda row, col: (min(row + 1, num_rows - 1), col)
        move_left = lambda row, col: (row, max(col - 1, 0))
        move_right = lambda row, col: (row, min(col + 1, num_cols - 1))

        self.action_defs = {0: move_up, 1: move_right,
                            2: move_down, 3: move_left}

        # Number of states/actions
        nS = num_cols * num_rows
        nA = len(self.action_defs)
        self.grid2state_dict = {(s // num_cols, s % num_cols): s
                                for s in range(nS)}
        self.state2grid_dict = {s: (s // num_cols, s % num_cols)
                                for s in range(nS)}

        # Gold state
        gold_cell = (num_rows // 2, num_cols - 2)

        # Trap states
        trap_cells = [((gold_cell[0] + 1), gold_cell[1]),
                      (gold_cell[0], gold_cell[1] - 1),
                      ((gold_cell[0] - 1), gold_cell[1])]

        gold_state = self.grid2state_dict[gold_cell]
        trap_states = [self.grid2state_dict[(r, c)]
                       for (r, c) in trap_cells]
        # The gold state is kept at index 0; the reward logic below
        # relies on that ordering.
        self.terminal_states = [gold_state] + trap_states
        print(self.terminal_states)

        # Build the transition probability
        P = defaultdict(dict)
        for s in range(nS):
            row, col = self.state2grid_dict[s]
            P[s] = defaultdict(list)
            for a in range(nA):
                action = self.action_defs[a]
                next_s = self.grid2state_dict[action(row, col)]

                # Terminal state
                if self.is_terminal(next_s):
                    r = (1.0 if next_s == self.terminal_states[0]
                         else -1.0)
                else:
                    r = 0.0
                # `done` is reported on the step taken *from* a terminal
                # state (a self-loop), not on the step that enters it.
                # NOTE(review): r was computed from the attempted move
                # before next_s is overwritten, so this self-loop step can
                # carry a non-zero reward. Left unchanged because existing
                # training scripts may depend on these dynamics.
                if self.is_terminal(s):
                    done = True
                    next_s = s
                else:
                    done = False
                P[s][a] = [(1.0, next_s, r, done)]

        # Initial state distribution
        isd = np.zeros(nS)
        isd[0] = 1.0

        super().__init__(nS, nA, P, isd)

        self.viewer = None
        self._build_display(gold_cell, trap_cells)

    def is_terminal(self, state):
        """Return True if `state` is the gold state or one of the traps."""
        return state in self.terminal_states

    def _build_display(self, gold_cell, trap_cells):
        """Create the viewer and add the border, grid, traps, gold, agent."""

        screen_width = (self.num_cols + 2) * CELL_SIZE
        screen_height = (self.num_rows + 2) * CELL_SIZE
        self.viewer = rendering.Viewer(screen_width,
                                       screen_height)

        all_objects = []

        # List of border points' coordinates
        bp_list = [
            (CELL_SIZE - MARGIN, CELL_SIZE - MARGIN),
            (screen_width - CELL_SIZE + MARGIN, CELL_SIZE - MARGIN),
            (screen_width - CELL_SIZE + MARGIN,
             screen_height - CELL_SIZE + MARGIN),
            (CELL_SIZE - MARGIN, screen_height - CELL_SIZE + MARGIN)
        ]
        border = rendering.PolyLine(bp_list, True)
        border.set_linewidth(5)
        all_objects.append(border)

        # Vertical lines
        for col in range(self.num_cols + 1):
            x1, y1 = (col + 1) * CELL_SIZE, CELL_SIZE
            x2, y2 = (col + 1) * CELL_SIZE, \
                     (self.num_rows + 1) * CELL_SIZE
            line = rendering.PolyLine([(x1, y1), (x2, y2)], False)
            all_objects.append(line)

        # Horizontal lines
        for row in range(self.num_rows + 1):
            x1, y1 = CELL_SIZE, (row + 1) * CELL_SIZE
            x2, y2 = (self.num_cols + 1) * CELL_SIZE, \
                     (row + 1) * CELL_SIZE
            line = rendering.PolyLine([(x1, y1), (x2, y2)], False)
            all_objects.append(line)

        # Traps: --> circles
        for cell in trap_cells:
            trap_coords = get_coords(*cell, loc='center')
            all_objects.append(draw_object([trap_coords]))

        # Gold: --> triangle
        gold_coords = get_coords(*gold_cell,
                                 loc='interior_triangle')
        all_objects.append(draw_object(gold_coords))

        # Agent --> square or robot
        if (os.path.exists('robot-coordinates.pkl') and CELL_SIZE == 100):
            # SECURITY: unpickling executes arbitrary code; only keep a
            # trusted 'robot-coordinates.pkl' in the working directory.
            with open('robot-coordinates.pkl', 'rb') as coords_file:
                agent_coords = pickle.load(coords_file)
            # presumably the pickle holds an (N, 2) array of outline
            # points relative to the cell centre -- verify against the file
            starting_coords = get_coords(0, 0, loc='center')
            agent_coords += np.array(starting_coords)
        else:
            agent_coords = get_coords(0, 0, loc='interior_corners')
        agent = draw_object(agent_coords)
        # The agent is drawn at cell (0, 0); render() moves it by
        # translating relative to that starting position.
        self.agent_trans = rendering.Transform()
        agent.add_attr(self.agent_trans)
        all_objects.append(agent)

        for obj in all_objects:
            self.viewer.add_geom(obj)

    def render(self, mode='human', done=False):
        """Draw the current state, then sleep.

        Sleeps 1 second when `done` is true, otherwise `self.delay`
        seconds. Returns whatever the viewer's render() returns; an RGB
        array is requested from it when mode == 'rgb_array'.
        """
        if done:
            sleep_time = 1
        else:
            sleep_time = self.delay
        x_coord = self.s % self.num_cols
        y_coord = self.s // self.num_cols
        # Offsets are relative to cell (0, 0), where the agent was drawn.
        x_coord = (x_coord + 0) * CELL_SIZE
        y_coord = (y_coord + 0) * CELL_SIZE
        self.agent_trans.set_translation(x_coord, y_coord)
        rend = self.viewer.render(
            return_rgb_array=(mode == 'rgb_array'))
        time.sleep(sleep_time)
        return rend

    def close(self):
        """Close the viewer window, if one is open."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None


if __name__ == '__main__':
    # Demo: run one episode with uniformly random actions.
    env = GridWorldEnv(5, 6)
    try:
        for _ in range(1):
            env.reset()
            env.render(mode='human', done=False)

            while True:
                action = np.random.choice(env.nA)
                # res[2] is read as the `done` flag of the step result.
                res = env.step(action)
                print('Action ', env.s, action, ' -> ', res)
                env.render(mode='human', done=res[2])
                if res[2]:
                    break
    finally:
        # Release the viewer window even if a step or render call fails.
        env.close()