# Path: blob/master/labml_nn/rl/game.py
# 4921 views
"""1---2title: Atari wrapper with multi-processing3summary: This implements the Atari games with multi-processing.4---56# Atari wrapper with multi-processing7"""8import multiprocessing9import multiprocessing.connection1011import cv212import gym13import numpy as np141516class Game:17"""18<a id="GameEnvironment"></a>1920## Game environment2122This is a wrapper for OpenAI gym game environment.23We do a few things here:24251. Apply the same action on four frames and get the last frame262. Convert observation frames to gray and scale it to (84, 84)273. Stack four frames of the last four actions284. Add episode information (total reward for the entire episode) for monitoring295. Restrict an episode to a single life (game has 5 lives, we reset after every single life)3031#### Observation format32Observation is tensor of size (4, 84, 84). It is four frames33(images of the game screen) stacked on first axis.34i.e, each channel is a frame.35"""3637def __init__(self, seed: int):38# create environment39self.env = gym.make('BreakoutNoFrameskip-v4')40self.env.seed(seed)4142# tensor for a stack of 4 frames43self.obs_4 = np.zeros((4, 84, 84))4445# buffer to keep the maximum of last 2 frames46self.obs_2_max = np.zeros((2, 84, 84))4748# keep track of the episode rewards49self.rewards = []50# and number of lives left51self.lives = 05253def step(self, action):54"""55### Step56Executes `action` for 4 time steps and57returns a tuple of (observation, reward, done, episode_info).5859* `observation`: stacked 4 frames (this frame and frames for last 3 actions)60* `reward`: total reward while the action was executed61* `done`: whether the episode finished (a life lost)62* `episode_info`: episode information if completed63"""6465reward = 0.66done = None6768# run for 4 steps69for i in range(4):70# execute the action in the OpenAI Gym environment71obs, r, done, info = self.env.step(action)7273if i >= 2:74self.obs_2_max[i % 2] = self._process_obs(obs)7576reward += r7778# get number of lives 
left79lives = self.env.unwrapped.ale.lives()80# reset if a life is lost81if lives < self.lives:82done = True83break8485# maintain rewards for each step86self.rewards.append(reward)8788if done:89# if finished, set episode information if episode is over, and reset90episode_info = {"reward": sum(self.rewards), "length": len(self.rewards)}91self.reset()92else:93episode_info = None9495# get the max of last two frames96obs = self.obs_2_max.max(axis=0)9798# push it to the stack of 4 frames99self.obs_4 = np.roll(self.obs_4, shift=-1, axis=0)100self.obs_4[-1] = obs101102return self.obs_4, reward, done, episode_info103104def reset(self):105"""106### Reset environment107Clean up episode info and 4 frame stack108"""109110# reset OpenAI Gym environment111obs = self.env.reset()112113# reset caches114obs = self._process_obs(obs)115for i in range(4):116self.obs_4[i] = obs117self.rewards = []118119self.lives = self.env.unwrapped.ale.lives()120121return self.obs_4122123@staticmethod124def _process_obs(obs):125"""126#### Process game frames127Convert game frames to gray and rescale to 84x84128"""129obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)130obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)131return obs132133134def worker_process(remote: multiprocessing.connection.Connection, seed: int):135"""136##Worker Process137138Each worker process runs this method139"""140141# create game142game = Game(seed)143144# wait for instructions from the connection and execute them145while True:146cmd, data = remote.recv()147if cmd == "step":148remote.send(game.step(data))149elif cmd == "reset":150remote.send(game.reset())151elif cmd == "close":152remote.close()153break154else:155raise NotImplementedError156157158class Worker:159"""160Creates a new worker and runs it in a separate process.161"""162163def __init__(self, seed):164self.child, parent = multiprocessing.Pipe()165self.process = multiprocessing.Process(target=worker_process, args=(parent, 
seed))166self.process.start()167168169170171