Path: labml_nn/rl/ppo/experiment.py
"""1---2title: PPO Experiment with Atari Breakout3summary: Annotated implementation to train a PPO agent on Atari Breakout game.4---56# PPO Experiment with Atari Breakout78This experiment trains Proximal Policy Optimization (PPO) agent Atari Breakout game on OpenAI Gym.9It runs the [game environments on multiple processes](../game.html) to sample efficiently.1011[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/ppo/experiment.ipynb)12"""1314from typing import Dict1516import numpy as np17import torch18from torch import nn19from torch import optim20from torch.distributions import Categorical2122from labml import monit, tracker, logger, experiment23from labml.configs import FloatDynamicHyperParam, IntDynamicHyperParam24from labml_nn.rl.game import Worker25from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss26from labml_nn.rl.ppo.gae import GAE2728# Select device29if torch.cuda.is_available():30device = torch.device("cuda:0")31else:32device = torch.device("cpu")333435class Model(nn.Module):36"""37## Model38"""3940def __init__(self):41super().__init__()4243# The first convolution layer takes a44# 84x84 frame and produces a 20x20 frame45self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)4647# The second convolution layer takes a48# 20x20 frame and produces a 9x9 frame49self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)5051# The third convolution layer takes a52# 9x9 frame and produces a 7x7 frame53self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)5455# A fully connected layer takes the flattened56# frame from third convolution layer, and outputs57# 512 features58self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)5960# A fully connected layer to get logits for $\pi$61self.pi_logits = nn.Linear(in_features=512, out_features=4)6263# A fully connected layer to get value function64self.value = nn.Linear(in_features=512, out_features=1)6566#67self.activation = nn.ReLU()6869def forward(self, obs: torch.Tensor):70h = self.activation(self.conv1(obs))71h = self.activation(self.conv2(h))72h = self.activation(self.conv3(h))73h = h.reshape((-1, 7 * 7 * 64))7475h = self.activation(self.lin(h))7677pi = Categorical(logits=self.pi_logits(h))78value = self.value(h).reshape(-1)7980return pi, value818283def obs_to_torch(obs: np.ndarray) -> torch.Tensor:84"""Scale observations from `[0, 255]` to `[0, 1]`"""85return torch.tensor(obs, dtype=torch.float32, device=device) / 255.868788class Trainer:89"""90## Trainer91"""9293def __init__(self, *,94updates: int, epochs: IntDynamicHyperParam,95n_workers: int, worker_steps: int, batches: int,96value_loss_coef: FloatDynamicHyperParam,97entropy_bonus_coef: FloatDynamicHyperParam,98clip_range: FloatDynamicHyperParam,99learning_rate: FloatDynamicHyperParam,100):101# #### Configurations102103# number of updates104self.updates = updates105# number of epochs to train the model with sampled data106self.epochs = epochs107# number of worker processes108self.n_workers = n_workers109# number of steps to run on each process for a single update110self.worker_steps = worker_steps111# number of mini batches112self.batches = batches113# total number of samples for a single update114self.batch_size = self.n_workers * self.worker_steps115# size of a mini batch116self.mini_batch_size = self.batch_size // self.batches117assert (self.batch_size % self.batches == 0)118119# Value loss 


def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    """Scale observations from `[0, 255]` to `[0, 1]`"""
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.


class Trainer:
    """
    ## Trainer
    """

    def __init__(self, *,
                 updates: int, epochs: IntDynamicHyperParam,
                 n_workers: int, worker_steps: int, batches: int,
                 value_loss_coef: FloatDynamicHyperParam,
                 entropy_bonus_coef: FloatDynamicHyperParam,
                 clip_range: FloatDynamicHyperParam,
                 learning_rate: FloatDynamicHyperParam,
                 ):
        # #### Configurations

        # number of updates
        self.updates = updates
        # number of epochs to train the model with sampled data
        self.epochs = epochs
        # number of worker processes
        self.n_workers = n_workers
        # number of steps to run on each process for a single update
        self.worker_steps = worker_steps
        # number of mini batches
        self.batches = batches
        # total number of samples for a single update
        self.batch_size = self.n_workers * self.worker_steps
        # size of a mini batch
        self.mini_batch_size = self.batch_size // self.batches
        assert (self.batch_size % self.batches == 0)

        # Value loss coefficient
        self.value_loss_coef = value_loss_coef
        # Entropy bonus coefficient
        self.entropy_bonus_coef = entropy_bonus_coef

        # Clipping range
        self.clip_range = clip_range
        # Learning rate
        self.learning_rate = learning_rate

        # #### Initialize

        # create workers
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

        # initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # model
        self.model = Model().to(device)

        # optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=2.5e-4)

        # GAE with $\gamma = 0.99$ and $\lambda = 0.95$
        self.gae = GAE(self.n_workers, self.worker_steps, 0.99, 0.95)

        # PPO Loss
        self.ppo_loss = ClippedPPOLoss()

        # Value Loss
        self.value_loss = ClippedValueFunctionLoss()

    def sample(self) -> Dict[str, torch.Tensor]:
        """
        ### Sample data with current policy
        """

        rewards = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        actions = np.zeros((self.n_workers, self.worker_steps), dtype=np.int32)
        done = np.zeros((self.n_workers, self.worker_steps), dtype=bool)
        obs = np.zeros((self.n_workers, self.worker_steps, 4, 84, 84), dtype=np.uint8)
        log_pis = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        values = np.zeros((self.n_workers, self.worker_steps + 1), dtype=np.float32)

        with torch.no_grad():
            # sample `worker_steps` from each worker
            for t in range(self.worker_steps):
                # `self.obs` keeps track of the last observation from each worker,
                # which is the input for the model to sample the next action
                obs[:, t] = self.obs
                # sample actions from $\pi_{\theta_{OLD}}$ for each worker;
                # this returns arrays of size `n_workers`
                pi, v = self.model(obs_to_torch(self.obs))
                values[:, t] = v.cpu().numpy()
                a = pi.sample()
                actions[:, t] = a.cpu().numpy()
                log_pis[:, t] = pi.log_prob(a).cpu().numpy()

                # run sampled actions on each worker
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w, t]))

                for w, worker in enumerate(self.workers):
                    # get results after executing the actions
                    self.obs[w], rewards[w, t], done[w, t], info = worker.child.recv()

                    # collect episode info, which is available if an episode finished;
                    # this includes total reward and length of the episode -
                    # look at `Game` to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

            # Get value of the state after the final step
            _, v = self.model(obs_to_torch(self.obs))
            values[:, self.worker_steps] = v.cpu().numpy()
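
        # [GAE](gae.html) computes the generalized advantage estimates
        # $\hat{A}_t = \delta_t + \gamma \lambda (1 - d_t) \hat{A}_{t+1}$, where
        # $\delta_t = r_t + \gamma (1 - d_t) V(s_{t+1}) - V(s_t)$,
        # $d_t$ is the `done` flag, and $\gamma = 0.99$, $\lambda = 0.95$ as set in `__init__`.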

        # calculate advantages
        advantages = self.gae(done, rewards, values)

        #
        samples = {
            'obs': obs,
            'actions': actions,
            'values': values[:, :-1],
            'log_pis': log_pis,
            'advantages': advantages
        }

        # The samples are currently in `[workers, time_steps]` tables;
        # we flatten them for training
        samples_flat = {}
        for k, v in samples.items():
            v = v.reshape(v.shape[0] * v.shape[1], *v.shape[2:])
            if k == 'obs':
                samples_flat[k] = obs_to_torch(v)
            else:
                samples_flat[k] = torch.tensor(v, device=device)

        return samples_flat

    def train(self, samples: Dict[str, torch.Tensor]):
        """
        ### Train the model based on samples
        """

        # It learns faster with a higher number of epochs,
        # but becomes a little unstable; that is,
        # the average episode reward does not monotonically increase
        # over time.
        # Reducing the clipping range might solve it.
        for _ in range(self.epochs()):
            # shuffle for each epoch
            indexes = torch.randperm(self.batch_size)

            # for each mini batch
            for start in range(0, self.batch_size, self.mini_batch_size):
                # get mini batch
                end = start + self.mini_batch_size
                mini_batch_indexes = indexes[start: end]
                mini_batch = {}
                for k, v in samples.items():
                    mini_batch[k] = v[mini_batch_indexes]

                # train
                loss = self._calc_loss(mini_batch)

                # Set learning rate
                for pg in self.optimizer.param_groups:
                    pg['lr'] = self.learning_rate()
                # Zero out the previously calculated gradients
                self.optimizer.zero_grad()
                # Calculate gradients
                loss.backward()
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
                # Update parameters based on gradients
                self.optimizer.step()

    @staticmethod
    def _normalize(adv: torch.Tensor):
        """#### Normalize advantage function"""
        return (adv - adv.mean()) / (adv.std() + 1e-8)

    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        ### Calculate total loss
        """

        # $R_t$ returns sampled from $\pi_{\theta_{OLD}}$
        sampled_return = samples['values'] + samples['advantages']

        # $\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$,
        # where $\hat{A_t}$ is advantages sampled from $\pi_{\theta_{OLD}}$.
        # Refer to the sampling function above for the calculation of $\hat{A}_t$.
        sampled_normalized_advantage = self._normalize(samples['advantages'])

        # Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$;
        # we are treating observations as state
        pi, value = self.model(samples['obs'])

        # $\log \pi_\theta (a_t|s_t)$, where $a_t$ are the actions sampled from $\pi_{\theta_{OLD}}$
        log_pi = pi.log_prob(samples['actions'])
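
        # `ClippedPPOLoss` computes the clipped surrogate objective from the PPO paper
        # (negated so it can be minimized as a loss):
        # $\mathcal{L}^{CLIP}(\theta) = \mathbb{E} \Bigl[ \min \bigl( r_t(\theta) \bar{A_t},
        # \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \bar{A_t} \bigr) \Bigr]$,
        # where the ratio $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{OLD}}(a_t|s_t)}$
        # is computed from the two log-probabilities passed below, and $\epsilon$ is `clip_range`.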

        # Calculate policy loss
        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())

        # Calculate Entropy Bonus
        #
        # $\mathcal{L}^{EB}(\theta) =
        #  \mathbb{E}\Bigl[ S\bigl[\pi_\theta\bigr] (s_t) \Bigr]$
        entropy_bonus = pi.entropy()
        entropy_bonus = entropy_bonus.mean()

        # Calculate value function loss
        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())

        # $\mathcal{L}^{CLIP+VF+EB} (\theta) =
        #  \mathcal{L}^{CLIP} (\theta) +
        #  c_1 \mathcal{L}^{VF} (\theta) - c_2 \mathcal{L}^{EB}(\theta)$
        loss = (policy_loss
                + self.value_loss_coef() * value_loss
                - self.entropy_bonus_coef() * entropy_bonus)

        # Approximate KL divergence between $\pi_{\theta_{OLD}}$ and $\pi_\theta$, for monitoring
        approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()

        # Add to tracker
        tracker.add({'policy_reward': -policy_loss,
                     'value_loss': value_loss,
                     'entropy_bonus': entropy_bonus,
                     'kl_div': approx_kl_divergence,
                     'clip_fraction': self.ppo_loss.clip_fraction})

        return loss

    def run_training_loop(self):
        """
        ### Run training loop
        """

        # last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        for update in monit.loop(self.updates):
            # sample with current policy
            samples = self.sample()

            # train the model
            self.train(samples)

            # Save tracked indicators.
            tracker.save()
            # Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()

    def destroy(self):
        """
        ### Destroy
        Stop the workers
        """
        for worker in self.workers:
            worker.child.send(("close", None))


def main():
    # Create the experiment
    experiment.create(name='ppo')
    # Configurations
    configs = {
        # Number of updates
        'updates': 10000,
        # ⚙️ Number of epochs to train the model with sampled data.
        # You can change this while the experiment is running.
        'epochs': IntDynamicHyperParam(8),
        # Number of worker processes
        'n_workers': 8,
        # Number of steps to run on each process for a single update
        'worker_steps': 128,
        # Number of mini batches
        'batches': 4,
        # ⚙️ Value loss coefficient.
        # You can change this while the experiment is running.
        'value_loss_coef': FloatDynamicHyperParam(0.5),
        # ⚙️ Entropy bonus coefficient.
        # You can change this while the experiment is running.
        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
        # ⚙️ Clip range.
        # You can change this while the experiment is running.
        'clip_range': FloatDynamicHyperParam(0.1),
        # ⚙️ Learning rate.
        # You can change this while the experiment is running.
        'learning_rate': FloatDynamicHyperParam(1e-3, (0, 1e-3)),
    }

    experiment.configs(configs)

    # Initialize the trainer
    m = Trainer(**configs)

    # Run and monitor the experiment
    with experiment.start():
        m.run_training_loop()
    # Stop the workers
    m.destroy()


# ## Run it
if __name__ == "__main__":
    main()
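
# Assuming the `labml_nn` package is on your Python path, this experiment can be
# started with `python -m labml_nn.rl.ppo.experiment`.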