Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/models/pixel_cnn.py
1192 views
1
import torch
2
from torch import nn
3
4
# PixelConvLayer is the basic building block of the PixelCNN: it builds on
# the 2D convolutional layer, but applies a mask to the kernel.
6
class PixelConvLayer(nn.Conv2d):
    """2D convolution whose kernel is multiplied by a fixed PixelCNN mask.

    Implementation of Masked CNN Class as explained in A Oord et. al.
    Taken from https://github.com/jzbontar/pixelcnn-pytorch

    The mask zeroes every kernel tap below the centre row and, within the
    centre row, every tap to the right of the centre. Mask type ``"A"`` also
    zeroes the centre tap itself, whereas type ``"B"`` keeps it.

    Args:
        mask_type: Either ``"A"`` or ``"B"``.
        in_channels: Forwarded to ``nn.Conv2d``.
        out_channels: Forwarded to ``nn.Conv2d``.
        kernel_size: Forwarded to ``nn.Conv2d``.
        padding: Forwarded to ``nn.Conv2d``.
        stride: Forwarded to ``nn.Conv2d``.
        padding_mode: Forwarded to ``nn.Conv2d``.
        dilation: Forwarded to ``nn.Conv2d``.
        groups: Forwarded to ``nn.Conv2d``.
        bias: Forwarded to ``nn.Conv2d``.
        device: Forwarded to ``nn.Conv2d``.
        dtype: Forwarded to ``nn.Conv2d``.

    Raises:
        ValueError: If ``mask_type`` is neither ``"A"`` nor ``"B"``.
    """

    def __init__(
        self,
        mask_type,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        padding_mode="zeros",
        dilation=1,
        groups=1,
        bias=True,
        device=None,
        dtype=None,
    ):
        # Validate explicitly (an ``assert`` would vanish under ``python -O``)
        # and do it first so no weights are allocated for a bad request.
        if mask_type not in ("A", "B"):
            raise ValueError(f'mask_type must be "A" or "B", got {mask_type!r}')

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

        # The mask has the same shape, dtype and device as the kernel.
        mask = torch.ones_like(self.weight.data)
        _, _, height, width = mask.shape
        center_row = height // 2
        # Type "A" blocks the centre tap as well; type "B" starts one column later.
        first_blocked_col = width // 2 + (1 if mask_type == "B" else 0)
        mask[:, :, center_row, first_blocked_col:] = 0
        mask[:, :, center_row + 1 :] = 0

        # Registered as a buffer: part of the state dict and moved together
        # with the module, but not a trainable parameter.
        self.register_buffer("mask", mask)

    def forward(self, x):
        """Zero the masked kernel taps in place, then apply the convolution."""
        # Re-applied on every call so that anything written into the masked
        # positions since the previous call is cleared before the weights are used.
        self.weight.data *= self.mask
        return super().forward(x)
51
52
53
class ResidualBlock(nn.Module):
    """Bottleneck residual block built around a type-"B" masked convolution.

    A 1x1 convolution reduces the input to ``in_channels // 2`` channels, a
    3x3 ``PixelConvLayer`` processes them, and a second 1x1 convolution
    expands back to ``in_channels``. Each convolution is followed by batch
    normalisation and ReLU, and the block input is added to the result.

    Args:
        in_channels: Channel count of the input (and therefore the output).
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, in_channels, **kwargs):
        super().__init__(**kwargs)

        bottleneck = in_channels // 2

        # Sub-modules are created in the same order as before so that
        # state-dict keys and parameter initialisation stay the same.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=bottleneck, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(bottleneck)
        self.relu = nn.ReLU()

        self.pixelconv = PixelConvLayer(
            in_channels=bottleneck,
            out_channels=bottleneck,
            kernel_size=3,
            mask_type="B",
            padding="same",
        )
        self.bn2 = nn.BatchNorm2d(bottleneck)
        self.conv2 = nn.Conv2d(in_channels=bottleneck, out_channels=in_channels, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(in_channels)

    def forward(self, inputs):
        """Return ``inputs`` plus the output of the three conv/BN/ReLU stages."""
        reduced = self.relu(self.bn1(self.conv1(inputs)))
        masked = self.relu(self.bn2(self.pixelconv(reduced)))
        expanded = self.relu(self.bn3(self.conv2(masked)))
        return inputs + expanded
74
75
76
class PixelCNN(nn.Module):
    """Stack of masked convolutions followed by a 1x1 output convolution.

    The network is, in order:

    1. a 7x7 type-"A" ``PixelConvLayer`` mapping ``K`` to ``channels``
       channels, followed by batch normalisation and ReLU;
    2. ``num_pixelcnn_layers`` 3x3 type-"B" ``PixelConvLayer`` stages, each
       followed by batch normalisation and ReLU;
    3. ``num_residual_blocks`` ``ResidualBlock`` modules;
    4. a 1x1 convolution mapping ``channels`` back to ``K`` channels.

    Args:
        channels: Number of feature channels used inside the network.
        num_residual_blocks: How many ``ResidualBlock`` modules to append.
        num_pixelcnn_layers: How many type-"B" masked stages to append.
        K: Channel count of both the input and the output (presumably the
            number of discrete codes being modelled -- confirm with caller).
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, channels, num_residual_blocks, num_pixelcnn_layers, K, **kwargs):
        super().__init__(**kwargs)

        # Layers are created in network order; this fixes the state-dict keys.
        modules = [
            nn.Sequential(
                PixelConvLayer(mask_type="A", in_channels=K, out_channels=channels, kernel_size=7, padding="same"),
                nn.BatchNorm2d(channels),
                nn.ReLU(),
            )
        ]

        for _ in range(num_pixelcnn_layers):
            modules.append(
                nn.Sequential(
                    PixelConvLayer(
                        mask_type="B", in_channels=channels, out_channels=channels, kernel_size=3, padding="same"
                    ),
                    nn.BatchNorm2d(channels),
                    nn.ReLU(),
                )
            )

        for _ in range(num_residual_blocks):
            modules.append(ResidualBlock(in_channels=channels))

        modules.append(nn.Conv2d(in_channels=channels, out_channels=K, kernel_size=1, padding="valid"))
        self.model = nn.Sequential(*modules)

    def forward(self, input):
        """Run ``input`` through the whole stack and return the result."""
        return self.model(input)

    def save(self, path="./pixelcnn_model.ckpt"):
        """Write this module's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(), path)

    def load(self, path, map_location=None):
        """Load a ``state_dict`` from ``path`` into this module.

        Args:
            path: Checkpoint file previously written by ``save``.
            map_location: Optional device remapping passed to ``torch.load``,
                e.g. ``"cpu"`` to restore a checkpoint that was saved from a
                GPU on a machine without one. ``None`` keeps the previous
                behaviour.
        """
        # NOTE(security): torch.load unpickles the file -- only load
        # checkpoints that come from a trusted source.
        state_dict = torch.load(path, map_location=map_location)
        self.load_state_dict(state_dict)
114
115