"""
Title: Vector-Quantized Variational Autoencoders
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/07/21
Last modified: 2022/06/27
Description: Training a VQ-VAE for image reconstruction and codebook sampling for generation.
Accelerator: GPU
"""

"""
In this example, we develop a Vector Quantized Variational Autoencoder (VQ-VAE).
VQ-VAE was proposed in
[Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937)
by van den Oord et al. In standard VAEs, the latent space is continuous and is sampled
from a Gaussian distribution. It is generally harder to learn such a continuous
distribution via gradient descent. VQ-VAEs, on the other hand,
operate on a discrete latent space, which makes the optimization problem simpler. They do so
by maintaining a discrete *codebook*. The codebook is learned along with the rest of the model,
and quantization amounts to snapping each continuous encoder output to its nearest codebook
vector. These discrete code words are then fed to the decoder, which is trained
to generate reconstructed samples.
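
To make the quantization step concrete, here is a minimal NumPy sketch of the nearest-code
lookup (the values and names are made up purely for illustration; the actual layer is
implemented later in this example):

```python
import numpy as np

# A toy codebook of 4 code words, each 3-dimensional.
codebook = np.array(
    [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.5, 0.0, 0.5], [2.0, 2.0, 2.0]]
)

# A single continuous encoder output.
z_e = np.array([0.9, 1.1, 0.8])

# Snap it to the code word with the smallest squared L2 distance.
distances = np.sum((codebook - z_e) ** 2, axis=1)
index = int(np.argmin(distances))  # -> 1
z_q = codebook[index]  # the quantized vector that is passed on to the decoder
```
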

For an overview of VQ-VAEs, please refer to the original paper and
[this video explanation](https://www.youtube.com/watch?v=VZFVUrYcig0).
If you need a refresher on VAEs, you can refer to
[this book chapter](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-12/).
VQ-VAEs are one of the main recipes behind [DALL-E](https://openai.com/blog/dall-e/)
and the idea of a codebook is used in [VQ-GANs](https://arxiv.org/abs/2012.09841).
This example uses implementation details from the
[official VQ-VAE tutorial](https://github.com/deepmind/sonnet/blob/master/sonnet/examples/vqvae_example.ipynb)
from DeepMind.

## Requirements

To run this example, you will need TensorFlow 2.5 or higher, as well as
TensorFlow Probability, which can be installed using the command below.
"""

"""shell
pip install -q tensorflow-probability
"""

"""
## Imports
"""

import numpy as np
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_probability as tfp
import tensorflow as tf

"""
## `VectorQuantizer` layer

First, we implement a custom layer for the vector quantizer, which is the layer in between
the encoder and the decoder. Consider an output from the encoder, with shape `(batch_size, height, width,
num_filters)`. The vector quantizer will first flatten this output, only keeping the
`num_filters` dimension intact. So, the shape would become `(batch_size * height * width,
num_filters)`. The rationale behind this is to treat the number of filters as the
dimensionality of the latent embeddings.

An embedding table is then initialized to learn a codebook. We measure the L2 distance
between the flattened encoder outputs and the code words of this codebook. We take the
code that yields the minimum distance, and we apply one-hot encoding to achieve quantization.
This way, the code yielding the minimum distance to the corresponding encoder output is
mapped as one and the remaining codes are mapped as zeros.

Since the quantization process is not differentiable, we apply a
[straight-through estimator](https://www.hassanaskary.com/python/pytorch/deep%20learning/2020/09/19/intuitive-explanation-of-straight-through-estimators.html)
between the decoder and the encoder, so that the decoder gradients are directly propagated
to the encoder. As the encoder and the decoder share the same channel space, the decoder gradients are
still meaningful to the encoder.
"""


class VectorQuantizer(layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # The `beta` parameter is best kept between [0.25, 2] as per the paper.
        self.beta = beta

        # Initialize the embeddings which we will quantize.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        # Calculate the input shape and then flatten the inputs,
        # keeping `embedding_dim` intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)

        # Reshape the quantized values back to the original input shape.
        quantized = tf.reshape(quantized, input_shape)

        # Calculate the vector quantization loss and add it to the layer. You can learn more
        # about adding losses to different layers here:
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
        # the original paper to get a handle on the formulation of the loss function.
        commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(self.beta * commitment_loss + codebook_loss)

        # Straight-through estimator.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Calculate the squared L2 distance between the inputs and the codes.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings**2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices


"""
**A note on straight-through estimation**:

This line of code does the straight-through estimation part: `quantized = x +
tf.stop_gradient(quantized - x)`. During backpropagation, `(quantized - x)` won't be
included in the computation graph and the gradients obtained for `quantized`
will be copied to `x`, the encoder output. Thanks to [this video](https://youtu.be/VZFVUrYcig0?t=1393)
for helping me understand this technique.
"""

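"""
To see the effect of the straight-through trick in isolation, here is a tiny sketch
(not part of the model code; the rounding op merely stands in for any non-differentiable
quantization step):

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0])  # stands in for a continuous encoder output
with tf.GradientTape() as tape:
    tape.watch(x)
    quantized = tf.round(x)  # a non-differentiable stand-in for quantization
    ste = x + tf.stop_gradient(quantized - x)  # forward pass still uses `quantized`
    loss = tf.reduce_sum(ste**2)

# The gradient w.r.t. `ste` is passed straight through to `x`: [2.0, 4.0].
print(tape.gradient(loss, x))
```
"""
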
"""
## Encoder and decoder

Now for the encoder and the decoder of the VQ-VAE. We will keep them small so
that their capacity is a good fit for the MNIST dataset. The implementation of the encoder
and the decoder comes from
[this example](https://keras.io/examples/generative/vae).

Note that activations _other than ReLU_ may not work for the encoder and decoder layers in the
quantization architecture: Leaky ReLU-activated layers, for example, have proven difficult to
train, resulting in intermittent loss spikes that the model has trouble recovering from.
"""


def get_encoder(latent_dim=16):
    encoder_inputs = keras.Input(shape=(28, 28, 1))
    x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(
        encoder_inputs
    )
    x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
    encoder_outputs = layers.Conv2D(latent_dim, 1, padding="same")(x)
    return keras.Model(encoder_inputs, encoder_outputs, name="encoder")


def get_decoder(latent_dim=16):
    latent_inputs = keras.Input(shape=get_encoder(latent_dim).output.shape[1:])
    x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(
        latent_inputs
    )
    x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
    decoder_outputs = layers.Conv2DTranspose(1, 3, padding="same")(x)
    return keras.Model(latent_inputs, decoder_outputs, name="decoder")


"""
## Standalone VQ-VAE model
"""


def get_vqvae(latent_dim=16, num_embeddings=64):
    vq_layer = VectorQuantizer(num_embeddings, latent_dim, name="vector_quantizer")
    encoder = get_encoder(latent_dim)
    decoder = get_decoder(latent_dim)
    inputs = keras.Input(shape=(28, 28, 1))
    encoder_outputs = encoder(inputs)
    quantized_latents = vq_layer(encoder_outputs)
    reconstructions = decoder(quantized_latents)
    return keras.Model(inputs, reconstructions, name="vq_vae")


get_vqvae().summary()

"""
Note that the output channels of the encoder should match the `latent_dim` for the vector
quantizer.
"""

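"""
If you change `latent_dim`, a quick sanity check along these lines (purely illustrative,
not part of the training pipeline) confirms that the encoder, the quantizer, and the decoder
agree on the embedding dimension:

```python
latent_dim = 16
assert get_encoder(latent_dim).output.shape[-1] == latent_dim
assert get_decoder(latent_dim).input.shape[-1] == latent_dim
```
"""
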
"""
## Wrapping up the training loop inside `VQVAETrainer`
"""


class VQVAETrainer(keras.models.Model):
    def __init__(self, train_variance, latent_dim=32, num_embeddings=128, **kwargs):
        super().__init__(**kwargs)
        self.train_variance = train_variance
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings

        self.vqvae = get_vqvae(self.latent_dim, self.num_embeddings)

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
        ]

    def train_step(self, x):
        with tf.GradientTape() as tape:
            # Outputs from the VQ-VAE.
            reconstructions = self.vqvae(x)

            # Calculate the losses.
            reconstruction_loss = (
                tf.reduce_mean((x - reconstructions) ** 2) / self.train_variance
            )
            total_loss = reconstruction_loss + sum(self.vqvae.losses)

        # Backpropagation.
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))

        # Loss tracking.
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))

        # Log results.
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
        }


"""
## Load and preprocess the MNIST dataset
"""

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()

x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
x_train_scaled = (x_train / 255.0) - 0.5
x_test_scaled = (x_test / 255.0) - 0.5

data_variance = np.var(x_train / 255.0)

"""
## Train the VQ-VAE model
"""

vqvae_trainer = VQVAETrainer(data_variance, latent_dim=16, num_embeddings=128)
vqvae_trainer.compile(optimizer=keras.optimizers.Adam())
vqvae_trainer.fit(x_train_scaled, epochs=30, batch_size=128)

"""
## Reconstruction results on the test set
"""


def show_subplot(original, reconstructed):
    plt.subplot(1, 2, 1)
    plt.imshow(original.squeeze() + 0.5)
    plt.title("Original")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reconstructed.squeeze() + 0.5)
    plt.title("Reconstructed")
    plt.axis("off")

    plt.show()


trained_vqvae_model = vqvae_trainer.vqvae
idx = np.random.choice(len(x_test_scaled), 10)
test_images = x_test_scaled[idx]
reconstructions_test = trained_vqvae_model.predict(test_images)

for test_image, reconstructed_image in zip(test_images, reconstructions_test):
    show_subplot(test_image, reconstructed_image)

"""
These results look decent. You are encouraged to play with different hyperparameters
(especially the number of embeddings and the dimensions of the embeddings) and observe how
they affect the results.
"""

"""
## Visualizing the discrete codes
"""

encoder = vqvae_trainer.vqvae.get_layer("encoder")
quantizer = vqvae_trainer.vqvae.get_layer("vector_quantizer")

encoded_outputs = encoder.predict(test_images)
flat_enc_outputs = encoded_outputs.reshape(-1, encoded_outputs.shape[-1])
codebook_indices = quantizer.get_code_indices(flat_enc_outputs)
codebook_indices = codebook_indices.numpy().reshape(encoded_outputs.shape[:-1])

for i in range(len(test_images)):
    plt.subplot(1, 2, 1)
    plt.imshow(test_images[i].squeeze() + 0.5)
    plt.title("Original")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(codebook_indices[i])
    plt.title("Code")
    plt.axis("off")
    plt.show()

"""
The figure above shows that the discrete codes have been able to capture some
regularities from the dataset. Now, how do we sample from this codebook to create
novel images? Since these codes are discrete and we imposed a categorical distribution
on them, we cannot use them to generate anything meaningful until we can produce likely
sequences of codes to feed to the decoder.

The authors train a PixelCNN on these codes so that it can be used as a powerful prior to
generate novel examples. PixelCNN was proposed in
[Conditional Image Generation with PixelCNN Decoders](https://arxiv.org/abs/1606.05328)
by van den Oord et al. We borrow the implementation from
[this PixelCNN example](https://keras.io/examples/generative/pixelcnn/). It's an autoregressive
generative model where the outputs are conditioned on the previous ones. In other words, a PixelCNN
generates an image on a pixel-by-pixel basis. In this example, however, its task
is to generate codebook indices instead of pixels directly. The trained VQ-VAE decoder is used
to map the indices generated by the PixelCNN back into the pixel space.
"""

"""
## PixelCNN hyperparameters
"""

num_residual_blocks = 2
num_pixelcnn_layers = 2
pixelcnn_input_shape = encoded_outputs.shape[1:-1]
print(f"Input shape of the PixelCNN: {pixelcnn_input_shape}")

"""
This input shape represents the reduction in resolution performed by the encoder. With "same" padding,
each stride-2 convolution layer exactly halves the "resolution" of its output. So, with these
two layers, we end up with an encoder output tensor of 7x7 on axes 2 and 3, with the first axis as the
batch size and the last axis being the codebook embedding size. Since the quantization layer in the
autoencoder maps these 7x7 tensors to codebook indices, these spatial sizes must be matched by the
PixelCNN as its input shape. The task of the PixelCNN for this architecture is to generate _likely_ 7x7
arrangements of codebook indices.

Note that this shape is something to optimize for in larger-sized image domains, along with the codebook
size. Since the PixelCNN is autoregressive, it needs to pass over each codebook index sequentially
in order to generate novel images from the codebook. Each stride-2 (or, more precisely,
stride `(2, 2)`) convolution layer divides the image generation time by four. Note, however, that there
is probably a lower bound on this part: when the number of codes used to reconstruct an image is too small,
the decoder has insufficient information to represent the level of detail in the image, so the
output quality will suffer. This can be remedied, at least to some extent, by using a larger codebook.
Since the autoregressive part of the image generation procedure operates on codebook indices, there is
far less of a performance penalty for using a larger codebook: looking up a code in a larger codebook
is much cheaper than iterating over a longer sequence of codebook indices, although the codebook size
does affect the batch size that can pass through the image generation procedure.
Finding the sweet spot for this trade-off can require some architecture tweaking and could very well differ
per dataset.
"""

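"""
As a quick check of the arithmetic above: with `"same"` padding, each stride-2 convolution maps a
spatial size of `n` to `ceil(n / 2)`. A small sketch (the helper function here is purely illustrative
and not used elsewhere in this example):

```python
import math


def downsampled_size(size, num_stride2_layers):
    # With "same" padding, each stride-2 convolution maps `size` to ceil(size / 2).
    for _ in range(num_stride2_layers):
        size = math.ceil(size / 2)
    return size


print(downsampled_size(28, 2))  # -> 7, matching the 7x7 PixelCNN input shape above
```
"""
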
"""
## PixelCNN model

The majority of this comes from
[this example](https://keras.io/examples/generative/pixelcnn/).

## Notes

Thanks to [Rein van 't Veer](https://github.com/reinvantveer) for improving this example with
copy-edits and minor code clean-ups.
"""


# The first layer is the PixelCNN layer. This layer simply
# builds on the 2D convolutional layer, but includes masking.
class PixelConvLayer(layers.Layer):
    def __init__(self, mask_type, **kwargs):
        super().__init__()
        self.mask_type = mask_type
        self.conv = layers.Conv2D(**kwargs)

    def build(self, input_shape):
        # Build the conv2d layer to initialize the kernel variables.
        self.conv.build(input_shape)
        # Use the initialized kernel to create the mask.
        kernel_shape = self.conv.kernel.get_shape()
        self.mask = np.zeros(shape=kernel_shape)
        self.mask[: kernel_shape[0] // 2, ...] = 1.0
        self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0
        if self.mask_type == "B":
            self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0

    def call(self, inputs):
        self.conv.kernel.assign(self.conv.kernel * self.mask)
        return self.conv(inputs)


# Next, we build our residual block layer.
# This is just a normal residual block, but based on the PixelConvLayer.
class ResidualBlock(keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )
        self.pixel_conv = PixelConvLayer(
            mask_type="B",
            filters=filters // 2,
            kernel_size=3,
            activation="relu",
            padding="same",
        )
        self.conv2 = keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.pixel_conv(x)
        x = self.conv2(x)
        return keras.layers.add([inputs, x])


pixelcnn_inputs = keras.Input(shape=pixelcnn_input_shape, dtype=tf.int32)
ohe = tf.one_hot(pixelcnn_inputs, vqvae_trainer.num_embeddings)
x = PixelConvLayer(
    mask_type="A", filters=128, kernel_size=7, activation="relu", padding="same"
)(ohe)

for _ in range(num_residual_blocks):
    x = ResidualBlock(filters=128)(x)

for _ in range(num_pixelcnn_layers):
    x = PixelConvLayer(
        mask_type="B",
        filters=128,
        kernel_size=1,
        strides=1,
        activation="relu",
        padding="valid",
    )(x)

out = keras.layers.Conv2D(
    filters=vqvae_trainer.num_embeddings, kernel_size=1, strides=1, padding="valid"
)(x)

pixel_cnn = keras.Model(pixelcnn_inputs, out, name="pixel_cnn")
pixel_cnn.summary()

"""
## Prepare data to train the PixelCNN

We will train the PixelCNN to learn a categorical distribution of the discrete codes.
First, we will generate code indices using the encoder and vector quantizer we just
trained. Our training objective will be to minimize the crossentropy loss between these
indices and the PixelCNN outputs. Here, the number of categories is equal to the number
of embeddings present in our codebook (128 in our case). The PixelCNN model is
trained to learn a distribution (as opposed to minimizing the L1/L2 loss), which is where
it gets its generative capabilities from.
"""

# Generate the codebook indices.
encoded_outputs = encoder.predict(x_train_scaled)
flat_enc_outputs = encoded_outputs.reshape(-1, encoded_outputs.shape[-1])
codebook_indices = quantizer.get_code_indices(flat_enc_outputs)

codebook_indices = codebook_indices.numpy().reshape(encoded_outputs.shape[:-1])
print(f"Shape of the training data for PixelCNN: {codebook_indices.shape}")

"""
## PixelCNN training
"""

pixel_cnn.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
pixel_cnn.fit(
    x=codebook_indices,
    y=codebook_indices,
    batch_size=128,
    epochs=30,
    validation_split=0.1,
)

"""
We can improve these scores with more training and hyperparameter tuning.
"""

"""
## Codebook sampling

Now that our PixelCNN is trained, we can sample codes from its output distribution and pass
them to our decoder to generate novel images.
"""

# Create a mini sampler model.
inputs = layers.Input(shape=pixel_cnn.input_shape[1:])
outputs = pixel_cnn(inputs, training=False)
categorical_layer = tfp.layers.DistributionLambda(tfp.distributions.Categorical)
outputs = categorical_layer(outputs)
sampler = keras.Model(inputs, outputs)

"""
We now construct a prior to generate images. Here, we will generate 10 images.
"""

# Create an empty array of priors.
batch = 10
priors = np.zeros(shape=(batch,) + (pixel_cnn.input_shape)[1:])
batch, rows, cols = priors.shape

# Iterate over the priors because generation has to be done sequentially, pixel by pixel.
for row in range(rows):
    for col in range(cols):
        # Feed the whole array and retrieve the sampled code values for the next position.
        probs = sampler.predict(priors)
        # Use the sampled values to fill in the current position of the priors.
        priors[:, row, col] = probs[:, row, col]

print(f"Prior shape: {priors.shape}")

"""
We can now use our decoder to generate the images.
"""

# Perform an embedding lookup.
pretrained_embeddings = quantizer.embeddings
priors_ohe = tf.one_hot(priors.astype("int32"), vqvae_trainer.num_embeddings).numpy()
quantized = tf.matmul(
    priors_ohe.astype("float32"), pretrained_embeddings, transpose_b=True
)
quantized = tf.reshape(quantized, (-1, *(encoded_outputs.shape[1:])))

# Generate novel images.
decoder = vqvae_trainer.vqvae.get_layer("decoder")
generated_samples = decoder.predict(quantized)

for i in range(batch):
    plt.subplot(1, 2, 1)
    plt.imshow(priors[i])
    plt.title("Code")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(generated_samples[i].squeeze() + 0.5)
    plt.title("Generated Sample")
    plt.axis("off")
    plt.show()

"""
We can enhance the quality of these generated samples by tweaking the PixelCNN.
"""

"""
## Additional notes

* After the VQ-VAE paper was initially released, the authors developed an exponential
moving average scheme to update the embeddings inside the quantizer. If you're
interested, you can check out
[this snippet](https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/nets/vqvae.py#L124);
a rough sketch of the update rule is also shown after this list.
* To further enhance the quality of the generated samples,
[VQ-VAE-2](https://arxiv.org/abs/1906.00446) was proposed, which follows a cascaded
approach to learn the codebook and to generate the images.
"""

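"""
For the curious, here is a rough NumPy sketch of that exponential moving average (EMA) update
rule, following the description in the paper's appendix. It is meant only to convey the idea
(the function and variable names are invented for this sketch) and is not the authors' exact
implementation:

```python
import numpy as np


def ema_codebook_update(
    codebook, ema_counts, ema_sums, flat_inputs, indices, decay=0.99, eps=1e-5
):
    # codebook: (embedding_dim, num_embeddings), matching `self.embeddings` above.
    # flat_inputs: (num_vectors, embedding_dim); indices: (num_vectors,).
    num_embeddings = codebook.shape[1]
    encodings = np.eye(num_embeddings)[indices]  # one-hot assignments

    # EMA of how many vectors were assigned to each code, and of their sum.
    ema_counts = decay * ema_counts + (1 - decay) * encodings.sum(axis=0)
    ema_sums = decay * ema_sums + (1 - decay) * flat_inputs.T @ encodings

    # Laplace smoothing so that rarely used codes do not cause division by zero.
    n = ema_counts.sum()
    smoothed_counts = (ema_counts + eps) / (n + num_embeddings * eps) * n

    # Each code word becomes the (smoothed) EMA mean of its assigned vectors.
    codebook = ema_sums / smoothed_counts
    return codebook, ema_counts, ema_sums
```

In that scheme, the codebook is updated from these EMA statistics instead of through the
`codebook_loss` gradient term.
"""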