Path: labml_nn/rl/dqn/model.py
"""1---2title: Deep Q Network (DQN) Model3summary: Implementation of neural network model for Deep Q Network (DQN).4---56# Deep Q Network (DQN) Model78[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/experiment.ipynb)9"""1011import torch12from torch import nn13141516class Model(nn.Module):17"""18## Dueling Network ⚔️ Model for $Q$ Values1920We are using a [dueling network](https://arxiv.org/abs/1511.06581)21to calculate Q-values.22Intuition behind dueling network architecture is that in most states23the action doesn't matter,24and in some states the action is significant. Dueling network allows25this to be represented very well.2627\begin{align}28Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)29\\30\mathop{\mathbb{E}}_{a \sim \pi(s)}31\Big[32A^\pi(s, a)33\Big]34&= 035\end{align}3637So we create two networks for $V$ and $A$ and get $Q$ from them.38$$39Q(s, a) = V(s) +40\Big(41A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')42\Big)43$$44We share the initial layers of the $V$ and $A$ networks.45"""4647def __init__(self):48super().__init__()49self.conv = nn.Sequential(50# The first convolution layer takes a51# $84\times84$ frame and produces a $20\times20$ frame52nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),53nn.ReLU(),5455# The second convolution layer takes a56# $20\times20$ frame and produces a $9\times9$ frame57nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),58nn.ReLU(),5960# The third convolution layer takes a61# $9\times9$ frame and produces a $7\times7$ frame62nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),63nn.ReLU(),64)6566# A fully connected layer takes the flattened67# frame from third convolution layer, and outputs68# $512$ features69self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)70self.activation = nn.ReLU()7172# This head gives the state value $V$73self.state_value = nn.Sequential(74nn.Linear(in_features=512, out_features=256),75nn.ReLU(),76nn.Linear(in_features=256, out_features=1),77)78# This head gives the action value $A$79self.action_value = nn.Sequential(80nn.Linear(in_features=512, out_features=256),81nn.ReLU(),82nn.Linear(in_features=256, out_features=4),83)8485def forward(self, obs: torch.Tensor):86# Convolution87h = self.conv(obs)88# Reshape for linear layers89h = h.reshape((-1, 7 * 7 * 64))9091# Linear layer92h = self.activation(self.lin(h))9394# $A$95action_value = self.action_value(h)96# $V$97state_value = self.state_value(h)9899# $A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')$100action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)101# $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$102q = state_value + action_score_centered103104return q105106107