# Path: blob/master/deprecated/scripts/ae_layerwise_fashion_tf.py
# 1192 views
# Greedy layerwise training of a 2 layer autoencoder (MLP) on Fashion MNIST
# Code is based on
# https://github.com/ageron/handson-ml2/blob/master/17_autoencoders_and_gans.ipynb

import superimport

import numpy as np
import matplotlib.pyplot as plt

import os

figdir = "../figures"


def save_fig(fname):
    """Save the current matplotlib figure as `fname` inside `figdir`."""
    plt.savefig(os.path.join(figdir, fname))


import tensorflow as tf
from tensorflow import keras

# Load Fashion MNIST and scale pixel intensities from [0, 255] to [0, 1].
(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()
X_train_full = X_train_full.astype(np.float32) / 255
X_test = X_test.astype(np.float32) / 255
# Hold out the last 5000 training images for validation.
X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:]
y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]


def rounded_accuracy(y_true, y_pred):
    """Binary accuracy after rounding targets and predictions to {0, 1}.

    Used when reconstructing images through a sigmoid output: pixels are
    compared after thresholding at 0.5.
    """
    return keras.metrics.binary_accuracy(tf.round(y_true), tf.round(y_pred))


def plot_image(image):
    """Render one grayscale image without axes."""
    plt.imshow(image, cmap="binary")
    plt.axis("off")


def show_reconstructions(model, images=None, n_images=5):
    """Plot originals (top row) and their reconstructions (bottom row).

    Args:
        model: Keras model mapping (28, 28) images to (28, 28) images.
        images: images to reconstruct; defaults to the validation set.
        n_images: number of columns to display.
    """
    # BUG FIX: the original default `images=X_valid` bound the array object at
    # function-definition time; a None sentinel is the safe, equivalent form.
    if images is None:
        images = X_valid
    reconstructions = model.predict(images[:n_images])
    plt.figure(figsize=(n_images * 1.5, 3))
    for image_index in range(n_images):
        plt.subplot(2, n_images, 1 + image_index)
        plot_image(images[image_index])
        plt.subplot(2, n_images, 1 + n_images + image_index)
        plot_image(reconstructions[image_index])


def train_autoencoder(n_neurons, X_train, X_valid, loss, optimizer,
                      n_epochs=10, output_activation=None, metrics=None):
    """Train a one-hidden-layer autoencoder on X_train reconstructing itself.

    Args:
        n_neurons: size of the single hidden (code) layer.
        X_train: 2-D training matrix (samples, features).
        X_valid: 2-D validation matrix with the same feature count.
        loss: Keras loss name or object.
        optimizer: Keras optimizer instance.
        n_epochs: number of training epochs.
        output_activation: activation for the reconstruction layer.
        metrics: optional list of metrics passed to `compile`.

    Returns:
        (encoder, decoder, codes for X_train, codes for X_valid).
    """
    n_inputs = X_train.shape[-1]
    encoder = keras.models.Sequential([
        keras.layers.Dense(n_neurons, activation="selu", input_shape=[n_inputs])
    ])
    decoder = keras.models.Sequential([
        keras.layers.Dense(n_inputs, activation=output_activation),
    ])
    autoencoder = keras.models.Sequential([encoder, decoder])
    autoencoder.compile(optimizer, loss, metrics=metrics)
    # BUG FIX: validation_data must be a tuple (x_val, y_val); recent Keras
    # versions reject a list of arrays here.
    autoencoder.fit(X_train, X_train, epochs=n_epochs,
                    validation_data=(X_valid, X_valid))
    return encoder, decoder, encoder(X_train), encoder(X_valid)


tf.random.set_seed(42)
np.random.seed(42)

K = keras.backend
X_train_flat = K.batch_flatten(X_train)  # equivalent to .reshape(-1, 28 * 28)
X_valid_flat = K.batch_flatten(X_valid)

# Phase 1: reconstruct the (near-binary) images with a sigmoid output.
# FIX: `lr` is a deprecated alias removed in modern Keras; use `learning_rate`.
enc1, dec1, X_train_enc1, X_valid_enc1 = train_autoencoder(
    100, X_train_flat, X_valid_flat, "binary_crossentropy",
    keras.optimizers.SGD(learning_rate=1.5), output_activation="sigmoid",
    metrics=[rounded_accuracy])

# Phase 2: reconstruct the real-valued codes produced by the first encoder.
enc2, dec2, _, _ = train_autoencoder(
    30, X_train_enc1, X_valid_enc1, "mse",
    keras.optimizers.SGD(learning_rate=0.05),
    output_activation="selu")

# Stack the greedily trained layers; no fine tuning yet.
stacked_ae = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28]),
    enc1, enc2, dec2, dec1,
    keras.layers.Reshape([28, 28])
])

show_reconstructions(stacked_ae)
plt.show()

# Fine tune the stacked model end to end.
stacked_ae.compile(loss="binary_crossentropy",
                   optimizer=keras.optimizers.SGD(learning_rate=0.1),
                   metrics=[rounded_accuracy])
history = stacked_ae.fit(X_train, X_train, epochs=10,
                         validation_data=(X_valid, X_valid))
show_reconstructions(stacked_ae)
plt.show()