Path: blob/master/labml_nn/rl/dqn/experiment.py
"""1---2title: DQN Experiment with Atari Breakout3summary: Implementation of DQN experiment with Atari Breakout4---56# DQN Experiment with Atari Breakout78This experiment trains a Deep Q Network (DQN) to play Atari Breakout game on OpenAI Gym.9It runs the [game environments on multiple processes](../game.html) to sample efficiently.1011[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/experiment.ipynb)12"""1314import numpy as np15import torch1617from labml import tracker, experiment, logger, monit18from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam19from labml_nn.helpers.schedule import Piecewise20from labml_nn.rl.dqn import QFuncLoss21from labml_nn.rl.dqn.model import Model22from labml_nn.rl.dqn.replay_buffer import ReplayBuffer23from labml_nn.rl.game import Worker2425# Select device26if torch.cuda.is_available():27device = torch.device("cuda:0")28else:29device = torch.device("cpu")303132def obs_to_torch(obs: np.ndarray) -> torch.Tensor:33"""Scale observations from `[0, 255]` to `[0, 1]`"""34return torch.tensor(obs, dtype=torch.float32, device=device) / 255.353637class Trainer:38"""39## Trainer40"""4142def __init__(self, *,43updates: int, epochs: int,44n_workers: int, worker_steps: int, mini_batch_size: int,45update_target_model: int,46learning_rate: FloatDynamicHyperParam,47):48# number of workers49self.n_workers = n_workers50# steps sampled on each update51self.worker_steps = worker_steps52# number of training iterations53self.train_epochs = epochs5455# number of updates56self.updates = updates57# size of mini batch for training58self.mini_batch_size = mini_batch_size5960# update target network every 250 update61self.update_target_model = update_target_model6263# learning rate64self.learning_rate = learning_rate6566# exploration as a function of updates67self.exploration_coefficient = Piecewise(68[69(0, 1.0),70(25_000, 0.1),71(self.updates / 2, 0.01)72], outside_value=0.01)7374# $\beta$ for replay buffer as a function of updates75self.prioritized_replay_beta = Piecewise(76[77(0, 0.4),78(self.updates, 1)79], outside_value=1)8081# Replay buffer with $\alpha = 0.6$. Capacity of the replay buffer must be a power of 2.82self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)8384# Model for sampling and training85self.model = Model().to(device)86# target model to get $\textcolor{orange}Q(s';\textcolor{orange}{\theta_i^{-}})$87self.target_model = Model().to(device)8889# create workers90self.workers = [Worker(47 + i) for i in range(self.n_workers)]9192# initialize tensors for observations93self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)9495# reset the workers96for worker in self.workers:97worker.child.send(("reset", None))9899# get the initial observations100for i, worker in enumerate(self.workers):101self.obs[i] = worker.child.recv()102103# loss function104self.loss_func = QFuncLoss(0.99)105# optimizer106self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)107108def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):109"""110#### $\epsilon$-greedy Sampling111When sampling actions we use a $\epsilon$-greedy strategy, where we112take a greedy action with probabiliy $1 - \epsilon$ and113take a random action with probability $\epsilon$.114We refer to $\epsilon$ as `exploration_coefficient`.115"""116117# Sampling doesn't need gradients118with torch.no_grad():119# Sample the action with highest Q-value. 


class Trainer:
    """
    ## Trainer
    """

    def __init__(self, *,
                 updates: int, epochs: int,
                 n_workers: int, worker_steps: int, mini_batch_size: int,
                 update_target_model: int,
                 learning_rate: FloatDynamicHyperParam,
                 ):
        # number of workers
        self.n_workers = n_workers
        # steps sampled on each update
        self.worker_steps = worker_steps
        # number of training iterations
        self.train_epochs = epochs

        # number of updates
        self.updates = updates
        # size of mini batch for training
        self.mini_batch_size = mini_batch_size

        # interval (in updates) at which to update the target network; 250 in the default configuration
        self.update_target_model = update_target_model

        # learning rate
        self.learning_rate = learning_rate

        # exploration as a function of updates
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)

        # $\beta$ for the replay buffer as a function of updates
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)

        # Replay buffer with $\alpha = 0.6$. The capacity of the replay buffer must be a power of 2.
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)

        # Model for sampling and training
        self.model = Model().to(device)
        # Target model to get $\textcolor{orange}{Q}(s'; \textcolor{orange}{\theta_i^{-}})$
        self.target_model = Model().to(device)

        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

        # initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)

        # reset the workers
        for worker in self.workers:
            worker.child.send(("reset", None))

        # get the initial observations
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # loss function
        self.loss_func = QFuncLoss(0.99)
        # optimizer; the learning rate set here is overridden from `self.learning_rate` in `train`
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
        """
        #### $\epsilon$-greedy Sampling
        When sampling actions we use an $\epsilon$-greedy strategy, where we
        take a greedy action with probability $1 - \epsilon$ and
        take a random action with probability $\epsilon$.
        We refer to $\epsilon$ as `exploration_coefficient`.
        """

        # Sampling doesn't need gradients
        with torch.no_grad():
            # Sample the action with the highest Q-value. This is the greedy action.
            greedy_action = torch.argmax(q_value, dim=-1)
            # Uniformly sample an action
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
            # Whether to choose the greedy action or the random action
            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
            # Pick the action based on `is_choose_rand`
            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

    def sample(self, exploration_coefficient: float):
        """### Sample data"""

        # This doesn't need gradients
        with torch.no_grad():
            # Sample `worker_steps`
            for t in range(self.worker_steps):
                # Get Q-values for the current observation
                q_value = self.model(obs_to_torch(self.obs))
                # Sample actions
                actions = self._sample_action(q_value, exploration_coefficient)

                # Run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

                # Collect information from each worker
                for w, worker in enumerate(self.workers):
                    # Get results after executing the actions
                    next_obs, reward, done, info = worker.child.recv()

                    # Add the transition to the replay buffer
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

                    # Collect episode info, which is available when an episode finishes;
                    # it includes the total reward and the length of the episode -
                    # look at `Game` to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

                    # update the current observation
                    self.obs[w] = next_obs

    def train(self, beta: float):
        """
        ### Train the model
        """
        for _ in range(self.train_epochs):
            # Sample from the prioritized replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
            # Get the predicted Q-values
            q_value = self.model(obs_to_torch(samples['obs']))

            # Get the Q-values of the next state for [Double Q-learning](index.html).
            # Gradients shouldn't propagate for these
            with torch.no_grad():
                # Get $\textcolor{cyan}{Q}(s'; \textcolor{cyan}{\theta_i})$
                double_q_value = self.model(obs_to_torch(samples['next_obs']))
                # Get $\textcolor{orange}{Q}(s'; \textcolor{orange}{\theta_i^{-}})$
                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))

            # Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))

            # Calculate priorities for the replay buffer, $p_i = |\delta_i| + \epsilon$
            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
            # Update replay buffer priorities
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

            # Set the learning rate
            for pg in self.optimizer.param_groups:
                pg['lr'] = self.learning_rate()
            # Zero out the previously calculated gradients
            self.optimizer.zero_grad()
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
            # Update parameters based on gradients
            self.optimizer.step()

    def run_training_loop(self):
        """
        ### Run training loop
        """

        # Last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        # Copy to the target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):
            # $\epsilon$, exploration fraction
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
            # $\beta$ for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

            # Sample with the current policy
            self.sample(exploration)

            # Start training after the buffer is full
            if self.replay_buffer.is_full():
                # Train the model
                self.train(beta)

                # Periodically update the target network
                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

            # Save tracked indicators.
            tracker.save()
            # Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        """
        ### Destroy
        Stop the workers
        """
        for worker in self.workers:
            worker.child.send(("close", None))
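

# For reference, a minimal sketch of the Double Q-learning target that `QFuncLoss` is built around.
# The actual loss in `labml_nn.rl.dqn.QFuncLoss` also applies Huber smoothing and the
# importance-sampling `weights`; the function name and signature here are illustrative only.
def _double_q_td_error_sketch(q_value: torch.Tensor, action: torch.Tensor,
                              double_q_value: torch.Tensor, target_q_value: torch.Tensor,
                              done: torch.Tensor, reward: torch.Tensor,
                              gamma: float = 0.99) -> torch.Tensor:
    # $Q(s, a; \theta_i)$ for the actions that were actually taken
    q_sampled = q_value.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
    # $a' = \arg\max_a Q(s', a; \theta_i)$, chosen with the online network
    best_next_action = torch.argmax(double_q_value, dim=-1)
    # $Q(s', a'; \theta_i^{-})$, evaluated with the target network
    q_next = target_q_value.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
    # Bootstrap only from non-terminal next states
    q_target = reward + gamma * q_next * (1 - done)
    # TD error $\delta = Q(s, a; \theta_i) - \big(r + \gamma Q(s', a'; \theta_i^{-})\big)$
    return q_sampled - q_target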


def main():
    # Create the experiment
    experiment.create(name='dqn')

    # Configurations
    configs = {
        # Number of updates
        'updates': 1_000_000,
        # Number of epochs to train the model with sampled data
        'epochs': 8,
        # Number of worker processes
        'n_workers': 8,
        # Number of steps to run on each process for a single update
        'worker_steps': 4,
        # Mini batch size
        'mini_batch_size': 32,
        # Target model updating interval
        'update_target_model': 250,
        # Learning rate
        'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
    }

    # Set the experiment configurations
    experiment.configs(configs)

    # Initialize the trainer
    m = Trainer(**configs)
    # Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
    # Stop the workers
    m.destroy()


# ## Run it
if __name__ == "__main__":
    main()