"""
---
title: DQN Experiment with Atari Breakout
summary: Implementation of DQN experiment with Atari Breakout
---

# DQN Experiment with Atari Breakout

This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym.
It runs the [game environments on multiple processes](../game.html) to sample efficiently.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/experiment.ipynb)
"""

import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
from labml_nn.helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

# Select device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    """Scale observations from `[0, 255]` to `[0, 1]`"""
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.


class Trainer:
    """
    ## Trainer
    """

    def __init__(self, *,
                 updates: int, epochs: int,
                 n_workers: int, worker_steps: int, mini_batch_size: int,
                 update_target_model: int,
                 learning_rate: FloatDynamicHyperParam,
                 ):
        # number of workers
        self.n_workers = n_workers
        # steps sampled on each update
        self.worker_steps = worker_steps
        # number of training epochs per update
        self.train_epochs = epochs

        # number of updates
        self.updates = updates
        # size of mini batch for training
        self.mini_batch_size = mini_batch_size

        # update target network every 250 updates
        self.update_target_model = update_target_model

        # learning rate
        self.learning_rate = learning_rate

        # exploration as a function of updates
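        # Assuming `Piecewise` interpolates linearly between the given points,
        # $\epsilon$ starts at 1.0, decays to 0.1 by update 25,000 (e.g. about 0.55 at
        # update 12,500), then to 0.01 by the half-way point of training, and stays
        # at 0.01 for the rest of the run.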
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)

        # $\beta$ for replay buffer as a function of updates
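        # $\beta$ grows from 0.4 at the start to 1 by the final update, so the
        # importance-sampling correction from prioritized replay is fully applied
        # by the end of training.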
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)

        # Replay buffer with $\alpha = 0.6$. The capacity of the replay buffer must be
        # a power of 2; here it is $2^{14} = 16,384$ transitions.
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)

        # Model for sampling and training
        self.model = Model().to(device)
        # target model to get $\textcolor{orange}Q(s';\textcolor{orange}{\theta_i^{-}})$
        self.target_model = Model().to(device)

        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

        # initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)

        # reset the workers
        for worker in self.workers:
            worker.child.send(("reset", None))

        # get the initial observations
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # loss function with discount factor $\gamma = 0.99$
        self.loss_func = QFuncLoss(0.99)
        # optimizer
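        # (the `lr` set here is overwritten by the dynamic `learning_rate` at every training step in `train`)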
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
        """
        #### $\epsilon$-greedy Sampling
        When sampling actions we use an $\epsilon$-greedy strategy, where we
        take a greedy action with probability $1 - \epsilon$ and
        take a random action with probability $\epsilon$.
        We refer to $\epsilon$ as `exploration_coefficient`.
        """

        # Sampling doesn't need gradients
        with torch.no_grad():
            # Sample the action with the highest Q-value. This is the greedy action.
            greedy_action = torch.argmax(q_value, dim=-1)
            # Uniformly sample an action
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
            # Whether to choose the random action instead of the greedy action
            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
            # Pick the action based on `is_choose_rand`
            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

    def sample(self, exploration_coefficient: float):
        """### Sample data"""

        # This doesn't need gradients
        with torch.no_grad():
            # Sample `worker_steps` steps from each worker
            for t in range(self.worker_steps):
                # Get Q-values for the current observation
                q_value = self.model(obs_to_torch(self.obs))
                # Sample actions
                actions = self._sample_action(q_value, exploration_coefficient)

                # Run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

                # Collect information from each worker
                for w, worker in enumerate(self.workers):
                    # Get results after executing the actions
                    next_obs, reward, done, info = worker.child.recv()

                    # Add transition to the replay buffer
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

                    # Update episode information.
                    # Episode info is available only when an episode finishes;
                    # it includes the total reward and the length of the episode.
                    # Look at `Game` to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

                    # update current observation
                    self.obs[w] = next_obs

    def train(self, beta: float):
        """
        ### Train the model
        """
        for _ in range(self.train_epochs):
            # Sample from the prioritized replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
            # Get the predicted Q-values
            q_value = self.model(obs_to_torch(samples['obs']))

            # Get the Q-values of the next state for [Double Q-learning](index.html).
            # Gradients shouldn't propagate for these
            with torch.no_grad():
                # Get $\textcolor{cyan}Q(s';\textcolor{cyan}{\theta_i})$
                double_q_value = self.model(obs_to_torch(samples['next_obs']))
                # Get $\textcolor{orange}Q(s';\textcolor{orange}{\theta_i^{-}})$
                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
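
            # These feed the double Q-learning target,
            # $r + \gamma \textcolor{orange}Q\big(s', \arg\max_{a'} \textcolor{cyan}Q(s', a'; \textcolor{cyan}{\theta_i}); \textcolor{orange}{\theta_i^{-}}\big)$
            # for non-terminal $s'$; see the [DQN loss](index.html) for the exact computation.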

            # Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))

            # Calculate priorities for the replay buffer, $p_i = |\delta_i| + \epsilon$
            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
            # Update replay buffer priorities
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

            # Set the learning rate
            for pg in self.optimizer.param_groups:
                pg['lr'] = self.learning_rate()
            # Zero out the previously calculated gradients
            self.optimizer.zero_grad()
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
            # Update parameters based on gradients
            self.optimizer.step()

    def run_training_loop(self):
        """
        ### Run training loop
        """

        # Track information about the last 100 episodes
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        # Copy the model weights to the target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):
            # $\epsilon$, the exploration coefficient
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
            # $\beta$ for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

            # Sample with the current policy
            self.sample(exploration)

            # Start training after the buffer is full
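            # (with the configuration in `main`, the $2^{14} = 16,384$ capacity buffer fills
            # after $16,384 / 32 = 512$ updates, since each update adds 32 transitions)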
            if self.replay_buffer.is_full():
                # Train the model
                self.train(beta)

            # Periodically update the target network
            if update % self.update_target_model == 0:
                self.target_model.load_state_dict(self.model.state_dict())

            # Save tracked indicators.
            tracker.save()
            # Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        """
        ### Destroy
        Stop the workers
        """
        for worker in self.workers:
            worker.child.send(("close", None))


def main():
    # Create the experiment
    experiment.create(name='dqn')

    # Configurations
    configs = {
        # Number of updates
        'updates': 1_000_000,
        # Number of epochs to train the model with sampled data
        'epochs': 8,
        # Number of worker processes
        'n_workers': 8,
        # Number of steps to run on each process for a single update
        'worker_steps': 4,
        # Mini batch size
        'mini_batch_size': 32,
        # Target model updating interval
        'update_target_model': 250,
        # Learning rate (a dynamic hyper-parameter that can be adjusted while the experiment runs)
        'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
    }

    # Set the configurations
    experiment.configs(configs)

    # Initialize the trainer
    m = Trainer(**configs)
    # Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
    # Stop the workers
    m.destroy()


# ## Run it
if __name__ == "__main__":
    main()