Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/book1/20/ae_mnist_gdl_tf.ipynb
1192 views
Kernel: Python 3

Open In Colab

Autoencoder on MNIST using a 2D latent space

We fit the model to MNIST and use a 2d latent space. Code is based on chapter 3 of David Foster's book: https://github.com/davidADSP/GDL_code/. We have added all the necessary libraries into a single notebook. We have modified it to work with TF 2.0.

try: # %tensorflow_version only exists in Colab. %tensorflow_version 2.x IS_COLAB = True except Exception: IS_COLAB = False # TensorFlow ≥2.0 is required import tensorflow as tf from tensorflow import keras assert tf.__version__ >= "2.0" if not tf.config.list_physical_devices("GPU"): print("No GPU was detected. DNNs can be very slow without a GPU.") if IS_COLAB: print("Go to Runtime > Change runtime and select a GPU hardware accelerator.")
# Standard Python libraries
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
from IPython import display
import sklearn

# NOTE(review): the original cell also ran `from time import time`, which
# rebound the name `time` from the module (imported above) to the function,
# silently shadowing it. Nothing in this notebook uses `time` afterwards,
# so the redundant, shadowing import has been dropped.

# Fix the global NumPy seed for reproducibility of the random choices below.
np.random.seed(0)
# https://github.com/davidADSP/GDL_code/blob/master/utils/loaders.py
from tensorflow.keras.datasets import mnist


def load_mnist():
    """Load MNIST, scale pixels to [0, 1], and append a trailing channel axis."""

    def _prep(images):
        scaled = images.astype("float32") / 255.0
        # (N, 28, 28) -> (N, 28, 28, 1) so the conv layers see a channel dim.
        return scaled.reshape(scaled.shape + (1,))

    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    return (_prep(train_images), train_labels), (_prep(test_images), test_labels)


(x_train, y_train), (x_test, y_test) = load_mnist()
print(x_train.shape)
# Utility functions
# https://github.com/davidADSP/GDL_code/blob/master/utils/callbacks.py
from tensorflow.keras.callbacks import Callback, LearningRateScheduler
import numpy as np
import matplotlib.pyplot as plt
import os

#### CALLBACKS


class CustomCallback(Callback):
    """Periodically decodes a random latent point and saves the image to disk.

    Parameters
    ----------
    run_folder : str
        Run directory; images are written to ``<run_folder>/images``.
    print_every_n_batches : int
        Save one decoded sample every this many batches.
    initial_epoch : int
        Starting value for the epoch counter used in image file names.
    vae : object
        Model wrapper exposing ``z_dim`` and a Keras ``decoder`` model.
    """

    def __init__(self, run_folder, print_every_n_batches, initial_epoch, vae):
        self.epoch = initial_epoch
        self.run_folder = run_folder
        self.print_every_n_batches = print_every_n_batches
        self.vae = vae

    def on_batch_end(self, batch, logs=None):  # logs=None: avoid mutable default arg
        if batch % self.print_every_n_batches == 0:
            # Decode a single standard-normal latent sample and save it.
            z_new = np.random.normal(size=(1, self.vae.z_dim))
            reconst = self.vae.decoder.predict(np.array(z_new))[0].squeeze()
            filepath = os.path.join(
                self.run_folder, "images", "img_" + str(self.epoch).zfill(3) + "_" + str(batch) + ".jpg"
            )
            if len(reconst.shape) == 2:
                plt.imsave(filepath, reconst, cmap="gray_r")
            else:
                plt.imsave(filepath, reconst)

    def on_epoch_begin(self, epoch, logs=None):  # logs=None: avoid mutable default arg
        # The counter is advanced at the *beginning* of each epoch, so the first
        # saved images are labelled initial_epoch + 1.
        self.epoch += 1
        # (An unreachable `if False:` preview-plot branch from the original has
        # been removed as dead code.)


def step_decay_schedule(initial_lr, decay_factor=0.5, step_size=1):
    """
    Wrapper function to create a LearningRateScheduler with step decay schedule.
    """

    def schedule(epoch):
        # lr = initial_lr * decay_factor ** floor(epoch / step_size)
        new_lr = initial_lr * (decay_factor ** np.floor(epoch / step_size))
        return new_lr

    return LearningRateScheduler(schedule)
def load_model(model_class, folder):
    """Rebuild a saved model: constructor args from params.pkl, then weights."""
    params_path = os.path.join(folder, "params.pkl")
    with open(params_path, "rb") as f:
        saved_params = pickle.load(f)
    restored = model_class(*saved_params)
    restored.load_weights(os.path.join(folder, "weights/weights.h5"))
    return restored
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    Flatten,
    Dense,
    Conv2DTranspose,
    Reshape,
    Lambda,
    Activation,
    BatchNormalization,
    LeakyReLU,
    Dropout,
)
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.utils import plot_model

# from utils.callbacks import CustomCallback, step_decay_schedule

import numpy as np
import json
import os
import pickle


class Autoencoder:
    """Convolutional autoencoder (encoder -> z_dim latent code -> decoder).

    Based on https://github.com/davidADSP/GDL_code (chapter 3). The public
    interface (constructor arguments and method signatures) is unchanged.

    Parameters
    ----------
    input_dim : tuple
        Input image shape, e.g. (28, 28, 1).
    encoder_conv_filters, encoder_conv_kernel_size, encoder_conv_strides : list
        Per-layer Conv2D settings for the encoder (parallel lists).
    decoder_conv_t_filters, decoder_conv_t_kernel_size, decoder_conv_t_strides : list
        Per-layer Conv2DTranspose settings for the decoder (parallel lists).
    z_dim : int
        Dimensionality of the latent space.
    use_batch_norm, use_dropout : bool
        Optional regularisation after each intermediate activation.
    """

    def __init__(
        self,
        input_dim,
        encoder_conv_filters,
        encoder_conv_kernel_size,
        encoder_conv_strides,
        decoder_conv_t_filters,
        decoder_conv_t_kernel_size,
        decoder_conv_t_strides,
        z_dim,
        use_batch_norm=False,
        use_dropout=False,
    ):
        self.name = "autoencoder"
        self.input_dim = input_dim
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.z_dim = z_dim
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout
        self.n_layers_encoder = len(encoder_conv_filters)
        self.n_layers_decoder = len(decoder_conv_t_filters)
        self._build()

    def _build(self):
        """Build self.encoder, self.decoder, and the composed self.model."""
        ### THE ENCODER
        encoder_input = Input(shape=self.input_dim, name="encoder_input")
        x = encoder_input
        for i in range(self.n_layers_encoder):
            conv_layer = Conv2D(
                filters=self.encoder_conv_filters[i],
                kernel_size=self.encoder_conv_kernel_size[i],
                strides=self.encoder_conv_strides[i],
                padding="same",
                name="encoder_conv_" + str(i),
            )
            x = conv_layer(x)
            x = LeakyReLU()(x)
            if self.use_batch_norm:
                x = BatchNormalization()(x)
            if self.use_dropout:
                x = Dropout(rate=0.25)(x)
        # Remember the conv-stack output shape so the decoder can mirror it.
        shape_before_flattening = K.int_shape(x)[1:]
        x = Flatten()(x)
        encoder_output = Dense(self.z_dim, name="encoder_output")(x)
        self.encoder = Model(encoder_input, encoder_output)

        ### THE DECODER
        decoder_input = Input(shape=(self.z_dim,), name="decoder_input")
        x = Dense(np.prod(shape_before_flattening))(decoder_input)
        x = Reshape(shape_before_flattening)(x)
        for i in range(self.n_layers_decoder):
            conv_t_layer = Conv2DTranspose(
                filters=self.decoder_conv_t_filters[i],
                kernel_size=self.decoder_conv_t_kernel_size[i],
                strides=self.decoder_conv_t_strides[i],
                padding="same",
                name="decoder_conv_t_" + str(i),
            )
            x = conv_t_layer(x)
            if i < self.n_layers_decoder - 1:
                x = LeakyReLU()(x)
                if self.use_batch_norm:
                    x = BatchNormalization()(x)
                if self.use_dropout:
                    x = Dropout(rate=0.25)(x)
            else:
                # Final layer: sigmoid keeps outputs in [0, 1] like the inputs.
                x = Activation("sigmoid")(x)
        decoder_output = x
        self.decoder = Model(decoder_input, decoder_output)

        ### THE FULL AUTOENCODER
        model_input = encoder_input
        model_output = self.decoder(encoder_output)
        self.model = Model(model_input, model_output)

    def compile(self, learning_rate):
        """Compile with Adam and a per-image mean-squared-error loss."""
        self.learning_rate = learning_rate
        # `learning_rate=` replaces the deprecated `lr=` keyword (TF 2.x).
        optimizer = Adam(learning_rate=learning_rate)

        def r_loss(y_true, y_pred):
            # Mean over H, W, C -> one reconstruction error per image.
            return K.mean(K.square(y_true - y_pred), axis=[1, 2, 3])

        self.model.compile(optimizer=optimizer, loss=r_loss)

    def save(self, folder):
        """Persist constructor params (pickle) and architecture diagrams.

        Uses makedirs(exist_ok=True) so that missing subdirectories are created
        even when `folder` itself already exists; the original only created them
        when `folder` was absent, and plot_model / image saving would then fail.
        """
        for sub in ("viz", "weights", "images"):
            os.makedirs(os.path.join(folder, sub), exist_ok=True)
        with open(os.path.join(folder, "params.pkl"), "wb") as f:
            pickle.dump(
                [
                    self.input_dim,
                    self.encoder_conv_filters,
                    self.encoder_conv_kernel_size,
                    self.encoder_conv_strides,
                    self.decoder_conv_t_filters,
                    self.decoder_conv_t_kernel_size,
                    self.decoder_conv_t_strides,
                    self.z_dim,
                    self.use_batch_norm,
                    self.use_dropout,
                ],
                f,
            )
        self.plot_model(folder)

    def load_weights(self, filepath):
        """Restore model weights saved by ModelCheckpoint."""
        self.model.load_weights(filepath)

    def train(self, x_train, batch_size, epochs, run_folder, print_every_n_batches=100, initial_epoch=0, lr_decay=1):
        """Fit the autoencoder to reconstruct its input.

        Saves sample decodings every `print_every_n_batches` batches, checkpoints
        the weights each epoch, and applies a step learning-rate decay.
        """
        custom_callback = CustomCallback(run_folder, print_every_n_batches, initial_epoch, self)
        lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
        checkpoint2 = ModelCheckpoint(
            os.path.join(run_folder, "weights/weights.h5"), save_weights_only=True, verbose=0
        )
        callbacks_list = [checkpoint2, custom_callback, lr_sched]
        self.model.fit(
            x_train,
            x_train,  # targets == inputs: reconstruction objective
            batch_size=batch_size,
            shuffle=True,
            epochs=epochs,
            initial_epoch=initial_epoch,
            callbacks=callbacks_list,
        )

    def plot_model(self, run_folder):
        """Write PNG diagrams of the full model, encoder, and decoder.

        Note: this calls the *module-level* keras `plot_model` utility; the
        method merely shares its name.
        """
        plot_model(
            self.model, to_file=os.path.join(run_folder, "viz/model.png"), show_shapes=True, show_layer_names=True
        )
        plot_model(
            self.encoder, to_file=os.path.join(run_folder, "viz/encoder.png"), show_shapes=True, show_layer_names=True
        )
        plot_model(
            self.decoder, to_file=os.path.join(run_folder, "viz/decoder.png"), show_shapes=True, show_layer_names=True
        )
# https://github.com/davidADSP/GDL_code/blob/master/03_01_autoencoder_train.ipynb

import os

# from utils.loaders import load_mnist
# from models.AE import Autoencoder

RUN_FOLDER = "ae_digits"

# makedirs(..., exist_ok=True) also repairs the case where RUN_FOLDER exists
# but a subdirectory is missing: the original `if not exists: mkdir` chain
# skipped the subdirectories entirely whenever RUN_FOLDER already existed.
for sub in ("viz", "images", "weights"):
    os.makedirs(os.path.join(RUN_FOLDER, sub), exist_ok=True)
# Instantiate the architecture used in the book: four conv layers down to a
# 2-D latent code, mirrored by four transposed-conv layers back to 28x28x1.
AE = Autoencoder(
    input_dim=(28, 28, 1),
    encoder_conv_filters=[32, 64, 64, 64],
    encoder_conv_kernel_size=[3, 3, 3, 3],
    encoder_conv_strides=[1, 2, 2, 1],
    decoder_conv_t_filters=[64, 64, 32, 1],
    decoder_conv_t_kernel_size=[3, 3, 3, 3],
    decoder_conv_t_strides=[1, 2, 2, 1],
    z_dim=2,
)

# Either start a fresh run ("build") or resume from previously saved weights.
MODE = "build"  #'load' #
if MODE != "build":
    AE.load_weights(os.path.join(RUN_FOLDER, "weights/weights.h5"))
else:
    AE.save(RUN_FOLDER)
LEARNING_RATE = 0.001 # 0.0005 BATCH_SIZE = 256 # 32 INITIAL_EPOCH = 0 AE.compile(LEARNING_RATE)
# Inspect the encoder architecture (conv stack -> flatten -> 2-D latent code).
AE.encoder.summary()
Model: "functional_13" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= encoder_input (InputLayer) [(None, 28, 28, 1)] 0 _________________________________________________________________ encoder_conv_0 (Conv2D) (None, 28, 28, 32) 320 _________________________________________________________________ leaky_re_lu_14 (LeakyReLU) (None, 28, 28, 32) 0 _________________________________________________________________ encoder_conv_1 (Conv2D) (None, 14, 14, 64) 18496 _________________________________________________________________ leaky_re_lu_15 (LeakyReLU) (None, 14, 14, 64) 0 _________________________________________________________________ encoder_conv_2 (Conv2D) (None, 7, 7, 64) 36928 _________________________________________________________________ leaky_re_lu_16 (LeakyReLU) (None, 7, 7, 64) 0 _________________________________________________________________ encoder_conv_3 (Conv2D) (None, 7, 7, 64) 36928 _________________________________________________________________ leaky_re_lu_17 (LeakyReLU) (None, 7, 7, 64) 0 _________________________________________________________________ flatten_2 (Flatten) (None, 3136) 0 _________________________________________________________________ encoder_output (Dense) (None, 2) 6274 ================================================================= Total params: 98,946 Trainable params: 98,946 Non-trainable params: 0 _________________________________________________________________
# Inspect the decoder architecture (dense -> reshape -> transposed-conv stack).
AE.decoder.summary()
Model: "functional_15" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= decoder_input (InputLayer) [(None, 2)] 0 _________________________________________________________________ dense_2 (Dense) (None, 3136) 9408 _________________________________________________________________ reshape_2 (Reshape) (None, 7, 7, 64) 0 _________________________________________________________________ decoder_conv_t_0 (Conv2DTran (None, 7, 7, 64) 36928 _________________________________________________________________ leaky_re_lu_18 (LeakyReLU) (None, 7, 7, 64) 0 _________________________________________________________________ decoder_conv_t_1 (Conv2DTran (None, 14, 14, 64) 36928 _________________________________________________________________ leaky_re_lu_19 (LeakyReLU) (None, 14, 14, 64) 0 _________________________________________________________________ decoder_conv_t_2 (Conv2DTran (None, 28, 28, 32) 18464 _________________________________________________________________ leaky_re_lu_20 (LeakyReLU) (None, 28, 28, 32) 0 _________________________________________________________________ decoder_conv_t_3 (Conv2DTran (None, 28, 28, 1) 289 _________________________________________________________________ activation_2 (Activation) (None, 28, 28, 1) 0 ================================================================= Total params: 102,017 Trainable params: 102,017 Non-trainable params: 0 _________________________________________________________________
# Sanity-check the training-set shape: (60000, 28, 28, 1) per the cell output.
x_train.shape
(60000, 28, 28, 1)
N = 20000 # N = 5000 EPOCHS = 30 AE.train( x_train[:N], batch_size=BATCH_SIZE, epochs=EPOCHS, run_folder=RUN_FOLDER, initial_epoch=INITIAL_EPOCH, print_every_n_batches=10, ) # def train(self, x_train, batch_size, epochs, run_folder, print_every_n_batches = 10, initial_epoch = 0, lr_decay = 1):
Epoch 1/30 79/79 [==============================] - 3s 34ms/step - loss: 0.1026 Epoch 2/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0596 Epoch 3/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0557 Epoch 4/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0538 Epoch 5/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0524 Epoch 6/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0510 Epoch 7/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0500 Epoch 8/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0491 Epoch 9/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0485 Epoch 10/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0477 Epoch 11/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0471 Epoch 12/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0465 Epoch 13/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0461 Epoch 14/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0456 Epoch 15/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0452 Epoch 16/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0449 Epoch 17/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0447 Epoch 18/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0443 Epoch 19/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0440 Epoch 20/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0438 Epoch 21/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0435 Epoch 22/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0434 Epoch 23/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0431 Epoch 24/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0429 Epoch 25/30 79/79 [==============================] - 2s 
26ms/step - loss: 0.0426 Epoch 26/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0427 Epoch 27/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0424 Epoch 28/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0421 Epoch 29/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0420 Epoch 30/30 79/79 [==============================] - 2s 26ms/step - loss: 0.0419
# https://github.com/davidADSP/GDL_code/blob/master/03_02_autoencoder_analysis.ipynb
!ls ae_digits/weights
weights.h5
# Reload the trained autoencoder from its pickled constructor params + weights.
AE = load_model(Autoencoder, RUN_FOLDER)

Reconstruction

# Encode/decode ten random test digits and show originals above reconstructions.
n_to_show = 10
np.random.seed(42)
example_idx = np.random.choice(range(len(x_test)), n_to_show)
example_images = x_test[example_idx]

z_points = AE.encoder.predict(example_images)
reconst_images = AE.decoder.predict(z_points)

fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for i, (original, code, reconstruction) in enumerate(zip(example_images, z_points, reconst_images)):
    # Top row: the input digit, annotated with its 2-D latent code.
    top = fig.add_subplot(2, n_to_show, i + 1)
    top.axis("off")
    top.text(0.5, -0.35, str(np.round(code, 1)), fontsize=10, ha="center", transform=top.transAxes)
    top.imshow(original.squeeze(), cmap="gray_r")
    # Bottom row: the decoder's reconstruction of the same digit.
    bottom = fig.add_subplot(2, n_to_show, i + n_to_show + 1)
    bottom.axis("off")
    bottom.imshow(reconstruction.squeeze(), cmap="gray_r")
plt.show()
Image in a Jupyter notebook

Generation

# Show encodings of random images n_to_show = 5000 grid_size = 15 figsize = 12 np.random.seed(42) example_idx = np.random.choice(range(len(x_test)), n_to_show) example_images = x_test[example_idx] example_labels = y_test[example_idx] z_points = AE.encoder.predict(example_images) min_x = min(z_points[:, 0]) max_x = max(z_points[:, 0]) min_y = min(z_points[:, 1]) max_y = max(z_points[:, 1]) plt.figure(figsize=(figsize, figsize)) plt.scatter(z_points[:, 0], z_points[:, 1], c="black", alpha=0.5, s=2) plt.show()
Image in a Jupyter notebook
# Generate from random points in latent space figsize = 5 plt.figure(figsize=(figsize, figsize)) plt.scatter(z_points[:, 0], z_points[:, 1], c="black", alpha=0.5, s=2) grid_size = 10 grid_depth = 3 figsize = 15 x = np.random.uniform(min_x, max_x, size=grid_size * grid_depth) y = np.random.uniform(min_y, max_y, size=grid_size * grid_depth) z_grid = np.array(list(zip(x, y))) reconst = AE.decoder.predict(z_grid) plt.scatter(z_grid[:, 0], z_grid[:, 1], c="red", alpha=1, s=20) plt.show() fig = plt.figure(figsize=(figsize, grid_depth)) fig.subplots_adjust(hspace=0.4, wspace=0.4) for i in range(grid_size * grid_depth): ax = fig.add_subplot(grid_depth, grid_size, i + 1) ax.axis("off") ax.text(0.5, -0.35, str(np.round(z_grid[i], 1)), fontsize=10, ha="center", transform=ax.transAxes) ax.imshow(reconst[i, :, :, 0], cmap="Greys")
Image in a Jupyter notebookImage in a Jupyter notebook
# Color code latent points: scatter the encoded test digits, coloured by label.
n_to_show = 5000
grid_size = 15
figsize = 12

sample_idx = np.random.choice(range(len(x_test)), n_to_show)
sample_images = x_test[sample_idx]
sample_labels = y_test[sample_idx]

codes = AE.encoder.predict(sample_images)

plt.figure(figsize=(figsize, figsize))
plt.scatter(codes[:, 0], codes[:, 1], cmap="rainbow", c=sample_labels, alpha=0.5, s=2)
plt.colorbar()
plt.show()
Image in a Jupyter notebook
# Decode a regular grid of latent points spanning the encoded test digits.
n_to_show = 5000
grid_size = 20
figsize = 15

example_idx = np.random.choice(range(len(x_test)), n_to_show)
example_images = x_test[example_idx]
example_labels = y_test[example_idx]

z_points = AE.encoder.predict(example_images)

plt.figure(figsize=(figsize, figsize))
plt.scatter(z_points[:, 0], z_points[:, 1], cmap="rainbow", c=example_labels, alpha=0.5, s=2)
plt.colorbar()

# x = norm.ppf(np.linspace(0.05, 0.95, 10))
# y = norm.ppf(np.linspace(0.05, 0.95, 10))

# Regular grid over the latent bounding box; the y axis runs top-to-bottom so
# the image grid below matches the scatter-plot orientation.
xs = np.linspace(min(z_points[:, 0]), max(z_points[:, 0]), grid_size)
ys = np.linspace(max(z_points[:, 1]), min(z_points[:, 1]), grid_size)
xv, yv = np.meshgrid(xs, ys)
z_grid = np.array(list(zip(xv.flatten(), yv.flatten())))

reconst = AE.decoder.predict(z_grid)
plt.scatter(z_grid[:, 0], z_grid[:, 1], c="black", alpha=1, s=5)
plt.show()

fig = plt.figure(figsize=(figsize, figsize))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for cell, decoded in enumerate(reconst):
    ax = fig.add_subplot(grid_size, grid_size, cell + 1)
    ax.axis("off")
    ax.imshow(decoded[:, :, 0], cmap="Greys")
Image in a Jupyter notebookImage in a Jupyter notebook