Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/rl/game.py
4921 views
1
"""
2
---
3
title: Atari wrapper with multi-processing
4
summary: This implements the Atari games with multi-processing.
5
---
6
7
# Atari wrapper with multi-processing
8
"""
9
import multiprocessing
10
import multiprocessing.connection
11
12
import cv2
13
import gym
14
import numpy as np
15
16
17
class Game:
    """
    <a id="GameEnvironment"></a>

    ## Game environment

    This is a wrapper for OpenAI gym game environment.
    We do a few things here:

    1. Apply the same action on four frames and get the last frame
    2. Convert observation frames to gray and scale it to (84, 84)
    3. Stack four frames of the last four actions
    4. Add episode information (total reward for the entire episode) for monitoring
    5. Restrict an episode to a single life (game has 5 lives, we reset after every single life)

    #### Observation format
    Observation is tensor of size (4, 84, 84). It is four frames
    (images of the game screen) stacked on first axis.
    i.e, each channel is a frame.
    """

    def __init__(self, seed: int):
        # create environment
        self.env = gym.make('BreakoutNoFrameskip-v4')
        self.env.seed(seed)

        # tensor for a stack of 4 frames
        # (float64 by default; each entry is a processed 84x84 gray frame)
        self.obs_4 = np.zeros((4, 84, 84))

        # buffer to keep the maximum of last 2 frames
        # (max-pooling over consecutive frames removes Atari sprite flicker)
        self.obs_2_max = np.zeros((2, 84, 84))

        # keep track of the episode rewards
        self.rewards = []
        # and number of lives left
        self.lives = 0

    def step(self, action):
        """
        ### Step
        Executes `action` for 4 time steps and
        returns a tuple of (observation, reward, done, episode_info).

        * `observation`: stacked 4 frames (this frame and frames for last 3 actions)
        * `reward`: total reward while the action was executed
        * `done`: whether the episode finished (a life lost)
        * `episode_info`: episode information if completed
        """

        reward = 0.
        done = None

        # run for 4 steps
        for i in range(4):
            # execute the action in the OpenAI Gym environment
            obs, r, done, info = self.env.step(action)

            # only the last two of the four frames are processed;
            # the earlier two are skipped entirely (standard frame-skip trick)
            if i >= 2:
                self.obs_2_max[i % 2] = self._process_obs(obs)

            # accumulate reward over all skipped frames
            reward += r

            # get number of lives left
            lives = self.env.unwrapped.ale.lives()
            # reset if a life is lost
            if lives < self.lives:
                done = True
                break

        # maintain rewards for each step
        self.rewards.append(reward)

        if done:
            # if finished, set episode information if episode is over, and reset
            episode_info = {"reward": sum(self.rewards), "length": len(self.rewards)}
            self.reset()
        else:
            episode_info = None

        # get the max of last two frames
        # NOTE(review): if the life was lost at i < 2, `obs_2_max` still holds
        # frames from the previous call — the returned observation is stale in
        # that case; confirm this matches the intended behavior
        obs = self.obs_2_max.max(axis=0)

        # push it to the stack of 4 frames
        # NOTE(review): when `done`, this also overwrites the freshly reset
        # `obs_4` with a rolled-in frame from the finished episode
        self.obs_4 = np.roll(self.obs_4, shift=-1, axis=0)
        self.obs_4[-1] = obs

        return self.obs_4, reward, done, episode_info

    def reset(self):
        """
        ### Reset environment
        Clean up episode info and 4 frame stack
        """

        # reset OpenAI Gym environment
        obs = self.env.reset()

        # reset caches: fill all four stack slots with the first frame
        obs = self._process_obs(obs)
        for i in range(4):
            self.obs_4[i] = obs
        self.rewards = []

        # remember the starting life count so `step` can detect a lost life
        self.lives = self.env.unwrapped.ale.lives()

        return self.obs_4

    @staticmethod
    def _process_obs(obs):
        """
        #### Process game frames
        Convert game frames to gray and rescale to 84x84

        `obs` is a raw RGB frame from the environment; returns an
        (84, 84) gray-scale image (uint8, as produced by OpenCV).
        """
        obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        # INTER_AREA is the recommended interpolation for shrinking images
        obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
        return obs
133
134
135
def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    """
    ## Worker Process

    Entry point executed inside each worker process: owns one `Game`
    instance and serves `("step", action)`, `("reset", None)` and
    `("close", None)` commands arriving over `remote`, replying with the
    corresponding result. Unknown commands raise `NotImplementedError`.
    """

    # each worker process owns its own game instance
    game = Game(seed)

    # serve commands from the parent until asked to close
    while True:
        command, payload = remote.recv()
        if command == "close":
            remote.close()
            break
        if command == "step":
            remote.send(game.step(payload))
        elif command == "reset":
            remote.send(game.reset())
        else:
            raise NotImplementedError
157
158
159
class Worker:
    """
    Creates a new worker and runs it in a separate process.

    Communicate with the worker over `self.child`, a `multiprocessing`
    connection: send `("step", action)`, `("reset", None)` or
    `("close", None)` tuples and receive the matching results.
    """

    def __init__(self, seed):
        # pipe endpoints: the parent keeps `self.child`,
        # the worker process gets `parent`
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()
        # close our copy of the worker-side endpoint; the child now holds
        # the only handle, so EOF propagates if the worker dies unexpectedly
        parent.close()

    def close(self):
        """Ask the worker to shut down and wait for the process to exit."""
        self.child.send(("close", None))
        self.child.close()
        self.process.join()
168
169
170
171