# Greedy layerwise training of a 2 layer autoencoder (MLP) on Fashion MNIST

# Code is based on
# https://github.com/ageron/handson-ml2/blob/master/17_autoencoders_and_gans.ipynb

import superimport  # pyprobml helper that installs missing packages on import

import numpy as np
import matplotlib.pyplot as plt

import os
figdir = "../figures"
def save_fig(fname): plt.savefig(os.path.join(figdir, fname))

import tensorflow as tf
from tensorflow import keras

(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()
X_train_full = X_train_full.astype(np.float32) / 255
X_test = X_test.astype(np.float32) / 255
X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:]
y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]
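
# After the split, X_train has shape (55000, 28, 28) and X_valid (5000, 28, 28),
# with pixel values scaled to [0, 1].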

# Reconstruction accuracy after rounding pixels to {0, 1}: treats the
# reconstruction as per-pixel binary classification.
def rounded_accuracy(y_true, y_pred):
    return keras.metrics.binary_accuracy(tf.round(y_true), tf.round(y_pred))

def plot_image(image):
    plt.imshow(image, cmap="binary")
    plt.axis("off")


def show_reconstructions(model, images=X_valid, n_images=5):
    reconstructions = model.predict(images[:n_images])
    plt.figure(figsize=(n_images * 1.5, 3))
    for image_index in range(n_images):
        plt.subplot(2, n_images, 1 + image_index)
        plot_image(images[image_index])
        plt.subplot(2, n_images, 1 + n_images + image_index)
        plot_image(reconstructions[image_index])

def train_autoencoder(n_neurons, X_train, X_valid, loss, optimizer,
                      n_epochs=10, output_activation=None, metrics=None):
    n_inputs = X_train.shape[-1]
    encoder = keras.models.Sequential([
        keras.layers.Dense(n_neurons, activation="selu", input_shape=[n_inputs])
    ])
    decoder = keras.models.Sequential([
        keras.layers.Dense(n_inputs, activation=output_activation),
    ])
    autoencoder = keras.models.Sequential([encoder, decoder])
    autoencoder.compile(optimizer, loss, metrics=metrics)
    autoencoder.fit(X_train, X_train, epochs=n_epochs,
                    validation_data=(X_valid, X_valid))
    return encoder, decoder, encoder(X_train), encoder(X_valid)
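
# Greedy layerwise scheme: each call to train_autoencoder fits a one-hidden-layer
# autoencoder in isolation and returns the codes its encoder produces for the
# train and validation sets, so the next stage can be trained on those codes.
# The stages are only stacked (and optionally fine-tuned end to end) afterwards.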

tf.random.set_seed(42)
np.random.seed(42)

K = keras.backend
X_train_flat = K.batch_flatten(X_train)  # equivalent to .reshape(-1, 28 * 28)
X_valid_flat = K.batch_flatten(X_valid)
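
# NumPy equivalent of the two flattening lines above (a sketch; reshape returns
# an ndarray rather than the tf.Tensor produced by K.batch_flatten):
#   X_train_flat = X_train.reshape(-1, 28 * 28)
#   X_valid_flat = X_valid.reshape(-1, 28 * 28)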

# Stage 1: reconstruct the images, treating pixels as binary
# (sigmoid output + cross-entropy loss)
enc1, dec1, X_train_enc1, X_valid_enc1 = train_autoencoder(
    100, X_train_flat, X_valid_flat, "binary_crossentropy",
    keras.optimizers.SGD(learning_rate=1.5), output_activation="sigmoid",
    metrics=[rounded_accuracy])
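
# X_train_enc1 and X_valid_enc1 are the 100-dimensional codes produced by enc1
# (shapes (55000, 100) and (5000, 100)); stage 2 trains on these codes, not pixels.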

# Stage 2: reconstruct the real-valued codes with a 100 -> 30 autoencoder
enc2, dec2, _, _ = train_autoencoder(
    30, X_train_enc1, X_valid_enc1, "mse", keras.optimizers.SGD(learning_rate=0.05),
    output_activation="selu")

# Stack models, no fine tuning
stacked_ae = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28]),
    enc1, enc2, dec2, dec1,
    keras.layers.Reshape([28, 28])
])
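
# The stacked network is 784 -> 100 -> 30 -> 100 -> 784: enc1 and enc2 form the
# encoder, and dec2, dec1 mirror them as the decoder. Nothing has been trained
# jointly yet, so these reconstructions come purely from the greedy stages.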

show_reconstructions(stacked_ae)
plt.show()

# Fine-tune the stacked model end to end
stacked_ae.compile(loss="binary_crossentropy",
                   optimizer=keras.optimizers.SGD(learning_rate=0.1),
                   metrics=[rounded_accuracy])
history = stacked_ae.fit(X_train, X_train, epochs=10,
                         validation_data=(X_valid, X_valid))
show_reconstructions(stacked_ae)
plt.show()
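
# To also write the comparison figure to ../figures, the save_fig helper defined
# above could be called before plt.show(); the filename here is just a placeholder:
#   save_fig("ae_layerwise_fashion_recon.pdf")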