"""
---
title: PPO Experiment with Atari Breakout
summary: Annotated implementation to train a PPO agent on the Atari Breakout game.
---

# PPO Experiment with Atari Breakout

This experiment trains a Proximal Policy Optimization (PPO) agent on the Atari Breakout game from OpenAI Gym.
It runs the [game environments on multiple processes](../game.html) to sample efficiently.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/ppo/experiment.ipynb)
"""

from typing import Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions import Categorical

from labml import monit, tracker, logger, experiment
from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam
from labml_nn.rl.game import Worker
from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
from labml_nn.rl.ppo.gae import GAE

# Select device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class Model(nn.Module):
    """
    ## Model
    """

    def __init__(self):
        super().__init__()

        # The first convolution layer takes an
        # 84x84 frame and produces a 20x20 frame
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)

        # The second convolution layer takes a
        # 20x20 frame and produces a 9x9 frame
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)

        # The third convolution layer takes a
        # 9x9 frame and produces a 7x7 frame
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
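        # (The frame sizes follow from the standard convolution arithmetic
        #  $\lfloor (n - k) / s \rfloor + 1$ with no padding:
        #  $(84 - 8)/4 + 1 = 20$, $(20 - 4)/2 + 1 = 9$, $(9 - 3)/1 + 1 = 7$.)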

        # A fully connected layer takes the flattened
        # frame from the third convolution layer, and outputs
        # 512 features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)

        # A fully connected layer to get logits for $\pi$
        self.pi_logits = nn.Linear(in_features=512, out_features=4)

        # A fully connected layer to get the value function
        self.value = nn.Linear(in_features=512, out_features=1)

        #
        self.activation = nn.ReLU()

    def forward(self, obs: torch.Tensor):
        h = self.activation(self.conv1(obs))
        h = self.activation(self.conv2(h))
        h = self.activation(self.conv3(h))
        h = h.reshape((-1, 7 * 7 * 64))

        h = self.activation(self.lin(h))

        pi = Categorical(logits=self.pi_logits(h))
        value = self.value(h).reshape(-1)

        return pi, value


def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    """Scale observations from `[0, 255]` to `[0, 1]`"""
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.


class Trainer:
    """
    ## Trainer
    """

    def __init__(self, *,
                 updates: int, epochs: IntDynamicHyperParam,
                 n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):
        # #### Configurations

        # number of updates
        self.updates = updates
        # number of epochs to train the model with sampled data
        self.epochs = epochs
        # number of worker processes
        self.n_workers = n_workers
        # number of steps to run on each process for a single update
        self.worker_steps = worker_steps
        # number of mini batches
        self.batches = batches
        # total number of samples for a single update
        self.batch_size = self.n_workers * self.worker_steps
        # size of a mini batch
        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)
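        # (With the default configuration in `main()` below, `n_workers = 8` and
        #  `worker_steps = 128` give `batch_size = 1024`, and `batches = 4` gives
        #  `mini_batch_size = 256`.)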

        # Value loss coefficient
        self.value_loss_coef = value_loss_coef
        # Entropy bonus coefficient
        self.entropy_bonus_coef = entropy_bonus_coef

        # Clipping range
        self.clip_range = clip_range
        # Learning rate
        self.learning_rate = learning_rate

        # #### Initialize

        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

        # initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # model
        self.model = Model().to(device)

        # optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

        # GAE with $\gamma = 0.99$ and $\lambda = 0.95$
        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)
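        # (GAE computes $\hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}$ with
        #  $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$, masking terms where the episode ended;
        #  see the imported `labml_nn.rl.ppo.gae` implementation for details.)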

        # PPO Loss
        self.ppo_loss = ClippedPPOLoss()

        # Value Loss
        self.value_loss = ClippedValueFunctionLoss()

    def sample(self) -> Dict[str, torch.Tensor]:
        """
        ### Sample data with current policy
        """

        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():
            # sample `worker_steps` from each worker
            for t in range(self.worker_steps):
                # `self.obs` keeps track of the last observation from each worker,
                # which is the input for the model to sample the next action
                obs[:, t] = self.obs
                # sample actions from $\pi_{\theta_{OLD}}$ for each worker;
                # this returns arrays of size `n_workers`
                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()

                # run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):
                    # get results after executing the actions
                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()

                    # collect episode info, which is available if an episode finished;
                    # this includes the total reward and length of the episode -
                    # look at `Game` to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

            # Get the value after the final step
            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()

        # calculate advantages
        advantages = self.gae(done, rewards, values)

        #
        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }

        # samples are currently in a `[workers, time_step]` table;
        # we should flatten them for training
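        # (e.g. with the defaults in `main()`, `obs` goes from `[8, 128, 4, 84, 84]`
        #  to `[1024, 4, 84, 84]`, and the per-step arrays from `[8, 128]` to `[1024]`.)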
        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat

    def train(self, samples: Dict[str, torch.Tensor]):
        """
        ### Train the model based on samples
        """

        # It learns faster with a higher number of epochs,
        # but becomes a little unstable; that is,
        # the average episode reward does not monotonically increase
        # over time.
        # Reducing the clipping range might solve it.
        for _ in range(self.epochs()):
            # shuffle for each epoch
            indexes = torch.randperm(self.batch_size)

            # for each mini batch
            for start in range(0, self.batch_size, self.mini_batch_size):
                # get mini batch
                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

                # train
                loss = self._calc_loss(mini_batch)

                # Set learning rate
                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()
                # Zero out the previously calculated gradients
                self.optimizer.zero_grad()
                # Calculate gradients
                loss.backward()
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
                # Update parameters based on gradients
                self.optimizer.step()

    @staticmethod
    def _normalize(adv: torch.Tensor):
        """#### Normalize advantage function"""
        return (adv - adv.mean()) / (adv.std() + 1e-8)

    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        ### Calculate total loss
        """

        # $R_t$, the returns sampled from $\pi_{\theta_{OLD}}$
        sampled_return = samples['values'] + samples['advantages']

        # $\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$,
        # where $\hat{A_t}$ are the advantages sampled from $\pi_{\theta_{OLD}}$.
        # Refer to `Trainer.sample` above
        # for the calculation of $\hat{A}_t$.
        sampled_normalized_advantage = self._normalize(samples['advantages'])

        # Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$;
        # we are treating observations as the state
        pi, value = self.model(samples['obs'])

        # $\log \pi_\theta (a_t|s_t)$, where $a_t$ are the actions sampled from $\pi_{\theta_{OLD}}$
        log_pi = pi.log_prob(samples['actions'])

        # Calculate policy loss
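        #
        # The clipped surrogate objective from the PPO paper,
        # $\mathcal{L}^{CLIP}(\theta) = \mathbb{E}_t \Bigl[ \min \bigl( r_t(\theta) \bar{A_t},
        # \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \bar{A_t} \bigr) \Bigr]$
        # with $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{OLD}}(a_t|s_t)}$,
        # is computed (negated, as a loss to minimize) by the imported `ClippedPPOLoss`;
        # hence `policy_reward` is tracked as `-policy_loss` below.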
        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())

        # Calculate Entropy Bonus
        #
        # $\mathcal{L}^{EB}(\theta) =
        # \mathbb{E}\Bigl[ S\bigl[\pi_\theta\bigr] (s_t) \Bigr]$
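        #
        # For a categorical policy, $S\bigl[\pi_\theta\bigr](s_t) = -\sum_a \pi_\theta(a|s_t) \log \pi_\theta(a|s_t)$,
        # which is what `Categorical.entropy()` returns per state.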
        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

        # Calculate value function loss
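        #
        # `ClippedValueFunctionLoss` clips the new value estimate around the sampled (old) one,
        # $V_{clipped}(s_t) = V_{old}(s_t) + \mathrm{clip}\bigl(V_\theta(s_t) - V_{old}(s_t), -\epsilon, \epsilon\bigr)$,
        # and takes the larger of the clipped and unclipped squared errors against $R_t$;
        # see the imported implementation in `labml_nn.rl.ppo` for the exact form.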
        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())

        # $\mathcal{L}^{CLIP+VF+EB} (\theta) =
        # \mathcal{L}^{CLIP} (\theta) +
        # c_1 \mathcal{L}^{VF} (\theta) - c_2 \mathcal{L}^{EB}(\theta)$
        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)

        # for monitoring
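        # ($\frac{1}{2} \mathbb{E}\bigl[(\log \pi_{\theta_{OLD}}(a_t|s_t) - \log \pi_\theta(a_t|s_t))^2\bigr]$
        #  is a cheap, low-variance approximation of the KL divergence between the old and
        #  current policies; it is used only as a diagnostic, not in the loss.)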
        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()

        # Add to tracker
        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss

    def run_training_loop(self):
        """
        ### Run training loop
        """

        # last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):
            # sample with current policy
            samples = self.sample()

            # train the model
            self.train(samples)

            # Save tracked indicators.
            tracker.save()
            # Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        """
        ### Destroy
        Stop the workers
        """
        for worker in self.workers:
            worker.child.send(("close", None))


def main():
    # Create the experiment
    experiment.create(name='ppo')
    # Configurations
    configs = {
        # Number of updates
        'updates': 10000,
        # ⚙️ Number of epochs to train the model with sampled data.
        # You can change this while the experiment is running.
        'epochs': IntDynamicHyperParam(8),
        # Number of worker processes
        'n_workers': 8,
        # Number of steps to run on each process for a single update
        'worker_steps': 128,
        # Number of mini batches
        'batches': 4,
        # ⚙️ Value loss coefficient.
        # You can change this while the experiment is running.
        'value_loss_coef': FloatDynamicHyperParam(0.5),
        # ⚙️ Entropy bonus coefficient.
        # You can change this while the experiment is running.
        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
        # ⚙️ Clip range.
        # You can change this while the experiment is running.
        'clip_range': FloatDynamicHyperParam(0.1),
        # ⚙️ Learning rate.
        'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
    }

    experiment.configs(configs)

    # Initialize the trainer
    m = Trainer(**configs)

    # Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
    # Stop the workers
    m.destroy()


# ## Run it
if __name__ == "__main__":
    main()