Path: labml_nn/rl/ppo/gae.py
"""1---2title: Generalized Advantage Estimation (GAE)3summary: A PyTorch implementation/tutorial of Generalized Advantage Estimation (GAE).4---56# Generalized Advantage Estimation (GAE)78This is a [PyTorch](https://pytorch.org) implementation of paper9[Generalized Advantage Estimation](https://arxiv.org/abs/1506.02438).1011You can find an experiment that uses it [here](experiment.html).12"""1314import numpy as np151617class GAE:18def __init__(self, n_workers: int, worker_steps: int, gamma: float, lambda_: float):19self.lambda_ = lambda_20self.gamma = gamma21self.worker_steps = worker_steps22self.n_workers = n_workers2324def __call__(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:25"""26### Calculate advantages2728\begin{align}29\hat{A_t^{(1)}} &= r_t + \gamma V(s_{t+1}) - V(s)30\\31\hat{A_t^{(2)}} &= r_t + \gamma r_{t+1} +\gamma^2 V(s_{t+2}) - V(s)32\\33...34\\35\hat{A_t^{(\infty)}} &= r_t + \gamma r_{t+1} +\gamma^2 r_{t+2} + ... - V(s)36\end{align}3738$\hat{A_t^{(1)}}$ is high bias, low variance, whilst39$\hat{A_t^{(\infty)}}$ is unbiased, high variance.4041We take a weighted average of $\hat{A_t^{(k)}}$ to balance bias and variance.42This is called Generalized Advantage Estimation.43$$\hat{A_t} = \hat{A_t^{GAE}} = \frac{\sum_k w_k \hat{A_t^{(k)}}}{\sum_k w_k}$$44We set $w_k = \lambda^{k-1}$, this gives clean calculation for45$\hat{A_t}$4647\begin{align}48\delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t)49\\50\hat{A_t} &= \delta_t + \gamma \lambda \delta_{t+1} + ... +51(\gamma \lambda)^{T - t + 1} \delta_{T - 1}52\\53&= \delta_t + \gamma \lambda \hat{A_{t+1}}54\end{align}55"""5657# advantages table58advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)59last_advantage = 06061# $V(s_{t+1})$62last_value = values[:, -1]6364for t in reversed(range(self.worker_steps)):65# mask if episode completed after step $t$66mask = 1.0 - done[:, t]67last_value = last_value * mask68last_advantage = last_advantage * mask69# $\delta_t$70delta = rewards[:, t] + self.gamma * last_value - values[:, t]7172# $\hat{A_t} = \delta_t + \gamma \lambda \hat{A_{t+1}}$73last_advantage = delta + self.gamma * self.lambda_ * last_advantage7475#76advantages[:, t] = last_advantage7778last_value = values[:, t]7980# $\hat{A_t}$81return advantages828384