"""
2
---
3
title: Deep Q Network (DQN) Model
4
summary: Implementation of neural network model for Deep Q Network (DQN).
5
---
6
7
# Deep Q Network (DQN) Model
8
9
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/experiment.ipynb)
10
"""
11
import torch
from torch import nn


class Model(nn.Module):
    """
    ## Dueling Network ⚔️ Model for $Q$ Values

    We use a [dueling network](https://arxiv.org/abs/1511.06581)
    to calculate $Q$-values.
    The intuition behind the dueling network architecture is that in most states
    the action doesn't matter, while in some states the action is significant.
    The dueling network lets this be represented well.

    \begin{align}
    Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
    \\
    \mathop{\mathbb{E}}_{a \sim \pi(s)}
    \Big[
    A^\pi(s, a)
    \Big]
    &= 0
    \end{align}

    So we create two networks for $V$ and $A$ and get $Q$ from them:
    $$
    Q(s, a) = V(s) +
    \Big(
    A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
    \Big)
    $$
    Subtracting the mean advantage keeps $V$ and $A$ identifiable;
    otherwise a constant could be shifted between them without changing $Q$.
    We share the initial layers of the $V$ and $A$ networks.
    """
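    # For example, with $V(s) = 1$ and $A(s, \cdot) = (2, 0, -2, 0)$,
    # the mean advantage is $0$ and $Q(s, \cdot) = (3, 1, -1, 1)$;
    # adding a constant to every advantage leaves $Q$ unchanged after
    # the mean is subtracted.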
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            # The first convolution layer takes a
            # $84\times84$ frame and produces a $20\times20$ frame
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),

            # The second convolution layer takes a
            # $20\times20$ frame and produces a $9\times9$ frame
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),

            # The third convolution layer takes a
            # $9\times9$ frame and produces a $7\times7$ frame
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
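
        # Each output size above follows
        # $\lfloor (\text{input} - \text{kernel}) / \text{stride} \rfloor + 1$;
        # for instance, the first layer gives $(84 - 8) / 4 + 1 = 20$.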

        # A fully connected layer takes the flattened
        # frame from the third convolution layer and outputs
        # $512$ features
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
        self.activation = nn.ReLU()

        # This head gives the state value $V$
        self.state_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
        # This head gives the action value $A$, one value for each of the $4$ actions
        self.action_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )

    def forward(self, obs: torch.Tensor):
        # Convolution
        h = self.conv(obs)
        # Reshape for linear layers
        h = h.reshape((-1, 7 * 7 * 64))

        # Linear layer
        h = self.activation(self.lin(h))

        # $A$
        action_value = self.action_value(h)
        # $V$
        state_value = self.state_value(h)

        # $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')$
        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
        # $Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$
        q = state_value + action_score_centered

        return q
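

# A minimal sketch of how the model is used, assuming Atari-style
# observations of $4$ stacked $84\times84$ frames, matching the shapes
# the convolution stack above expects.
if __name__ == '__main__':
    model = Model()
    # A batch of two random observations
    obs = torch.rand(2, 4, 84, 84)
    q = model(obs)
    # One $Q$-value per action: prints `torch.Size([2, 4])`
    print(q.shape)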