# GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
# Path: blob/master/labml_nn/rl/ppo/gae.py
"""
---
title: Generalized Advantage Estimation (GAE)
summary: A PyTorch implementation/tutorial of Generalized Advantage Estimation (GAE).
---

# Generalized Advantage Estimation (GAE)

This is a [PyTorch](https://pytorch.org) implementation of the paper
[Generalized Advantage Estimation](https://arxiv.org/abs/1506.02438).

You can find an experiment that uses it [here](experiment.html).
"""

import numpy as np


class GAE:
    def __init__(self, n_workers: int, worker_steps: int, gamma: float, lambda_: float):
        # $\lambda$, which trades off bias against variance
        self.lambda_ = lambda_
        # discount factor $\gamma$
        self.gamma = gamma
        # number of steps sampled by each worker
        self.worker_steps = worker_steps
        # number of workers sampling in parallel
        self.n_workers = n_workers

    def __call__(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:
        r"""
### Calculate advantages

\begin{align}
\hat{A}_t^{(1)} &= r_t + \gamma V(s_{t+1}) - V(s_t)
\\
\hat{A}_t^{(2)} &= r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t)
\\
...
\\
\hat{A}_t^{(\infty)} &= r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + ... - V(s_t)
\end{align}

$\hat{A}_t^{(1)}$ has high bias (it leans heavily on the learned $V$) and low variance, whilst
$\hat{A}_t^{(\infty)}$ is unbiased but has high variance (it sums many random rewards).
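
A minimal NumPy sketch of the $k$-step estimates for a single trajectory. The helper `k_step_advantage` and all numbers are hypothetical, chosen only for illustration. `values` holds $V(s_0), \dots, V(s_T)$, so it has one more entry than `rewards`, and $k$ is capped at the number of remaining steps. With $V(s_3) = 0$ (as if the episode terminated there), the full-length estimate is the Monte Carlo return minus $V(s_0)$, which plays the role of $\hat{A}_0^{(\infty)}$.

```python
import numpy as np


def k_step_advantage(rewards: np.ndarray, values: np.ndarray, t: int, k: int, gamma: float) -> float:
    # Hypothetical helper: hat A_t^(k) = sum_{l<k} gamma^l r_{t+l} + gamma^k V(s_{t+k}) - V(s_t)
    # `values` has one more entry than `rewards`; values[-1] bootstraps past the last step
    T = len(rewards)
    k = min(k, T - t)  # cannot look past the end of the sampled trajectory
    discounts = gamma ** np.arange(k)
    return float(np.sum(discounts * rewards[t:t + k]) + gamma ** k * values[t + k] - values[t])


rewards = np.array([1.0, 0.0, 2.0])      # r_0, r_1, r_2 (made-up numbers)
values = np.array([0.5, 1.0, 1.5, 0.0])  # V(s_0) .. V(s_3) (made-up numbers)
gamma = 0.9

a1 = k_step_advantage(rewards, values, t=0, k=1, gamma=gamma)
a2 = k_step_advantage(rewards, values, t=0, k=2, gamma=gamma)
a_full = k_step_advantage(rewards, values, t=0, k=len(rewards), gamma=gamma)
print(a1, a2, a_full)
```

Each longer estimate replaces one bootstrapped value with an actual reward, which is where the shift from bias to variance comes from.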

We take a weighted average of the $\hat{A}_t^{(k)}$ to balance bias and variance.
This is called Generalized Advantage Estimation.
$$\hat{A}_t = \hat{A}_t^{GAE} = \frac{\sum_k w_k \hat{A}_t^{(k)}}{\sum_k w_k}$$
We set $w_k = \lambda^{k-1}$, which gives a clean recursive calculation for
$\hat{A}_t$:
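
\begin{align}
\delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t)
\\
\hat{A}_t &= \delta_t + \gamma \lambda \delta_{t+1} + ... +
(\gamma \lambda)^{T - t - 1} \delta_{T - 1}
\\
&= \delta_t + \gamma \lambda \hat{A}_{t+1}
\end{align}

Here $T$ is the number of sampled steps, $V(s_T)$ bootstraps the return beyond the sample,
and $\hat{A}_T = 0$. If the episode ends at step $t$, $V(s_{t+1})$ and $\hat{A}_{t+1}$
are treated as zero.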
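
A minimal single-trajectory sketch, with made-up numbers and hypothetical helper names, that checks numerically that the backward recursion, the explicit sum $\sum_s (\gamma \lambda)^s \delta_{t+s}$, and the normalized $\lambda$-weighted average of the $k$-step estimates agree. On a finite sample the geometric weights must be truncated; this sketch assumes the usual convention that $\hat{A}_t^{(k)} = \hat{A}_t^{(n)}$ for $k \ge n$, where $n = T - t$ steps remain, so after normalization the last estimate carries the tail weight $\lambda^{n-1}$. Episode boundaries are ignored here.

```python
import numpy as np

# Made-up single-trajectory data (no episode boundaries)
gamma, lam = 0.99, 0.95
rewards = np.array([1.0, 0.0, 0.5, 2.0, -1.0])      # r_0 .. r_{T-1}
values = np.array([0.3, 0.8, -0.2, 1.1, 0.4, 0.7])  # V(s_0) .. V(s_T); last entry bootstraps
T = len(rewards)

# delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), for t = 0 .. T-1
deltas = rewards + gamma * values[1:] - values[:-1]


def gae_recursive(deltas: np.ndarray) -> np.ndarray:
    # Backward recursion: hat A_t = delta_t + gamma * lam * hat A_{t+1}, starting from hat A_T = 0
    adv = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        adv[t] = running
    return adv


def gae_explicit(deltas: np.ndarray, t: int) -> float:
    # Explicit sum over s = 0 .. T-t-1 of (gamma * lam)^s * delta_{t+s}
    steps = np.arange(len(deltas) - t)
    return float(np.sum((gamma * lam) ** steps * deltas[t:]))


def gae_weighted_average(deltas: np.ndarray, t: int) -> float:
    # Normalized lambda-weighted average of the k-step estimates, k = 1 .. n with n = T - t
    n = len(deltas) - t
    # The k-step estimate telescopes to sum_{s<k} gamma^s delta_{t+s}; k_step[k-1] holds it
    k_step = np.cumsum(gamma ** np.arange(n) * deltas[t:])
    # Normalized weights (1 - lam) * lam^(k-1) for k < n ...
    weights = (1 - lam) * lam ** np.arange(n)
    # ... and the tail weight lam^(n-1) on the n-step estimate, standing in for all k >= n
    weights[-1] = lam ** (n - 1)
    return float(np.sum(weights * k_step))


recursive = gae_recursive(deltas)
explicit = np.array([gae_explicit(deltas, t) for t in range(T)])
averaged = np.array([gae_weighted_average(deltas, t) for t in range(T)])
print(np.round(recursive, 4))
```

The `for` loop in the code below applies the same backward recursion to all workers at once and additionally masks $V(s_{t+1})$ and $\hat{A}_{t+1}$ at episode ends.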
        """

        # advantages table, one row per worker
        advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        # $\hat{A}_{t+1}$, which is zero beyond the last sampled step
        last_advantage = 0

        # $V(s_{t+1})$; `values` is expected to have `worker_steps + 1` columns,
        # the last one being the value of the state after the final step
        last_value = values[:, -1]

        for t in reversed(range(self.worker_steps)):
            # if the episode ended at step $t$, zero out $V(s_{t+1})$ and $\hat{A}_{t+1}$
            mask = 1.0 - done[:, t]
            last_value = last_value * mask
            last_advantage = last_advantage * mask
            # $\delta_t$
            delta = rewards[:, t] + self.gamma * last_value - values[:, t]

            # $\hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}$
            last_advantage = delta + self.gamma * self.lambda_ * last_advantage

            # store $\hat{A}_t$
            advantages[:, t] = last_advantage

            # $V(s_t)$ becomes $V(s_{t+1})$ for the previous step
            last_value = values[:, t]

        # $\hat{A}_t$ for every worker and step
        return advantages