Path: blob/master/Conditional-GAN-PyTorch-TensorFlow/TensorFlow/cgan_fashionmnist_tensorflow.py
import imageio
import glob
import os
import time
import cv2
import tensorflow as tf
from tensorflow.keras import layers
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import backend as K
from sklearn.manifold import TSNE
from tensorflow import keras
from matplotlib import pyplot
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from numpy import linspace

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
x_test = x_test.astype('float32')
# scale images to [-1, 1] to match the generator's tanh output
x_train = (x_train / 127.5) - 1

# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).\
    shuffle(60000).batch(128)

plt.figure(figsize=(10, 10))
for images, _ in train_dataset.take(1):
    for i in range(100):
        ax = plt.subplot(10, 10, i + 1)
        # rescale from [-1, 1] back to [0, 255] so the samples display correctly
        plt.imshow(((images[i, :, :, 0] + 1.0) * 127.5).numpy().astype("uint8"), cmap='gray')
        plt.axis("off")

BATCH_SIZE = 128
latent_dim = 100

# label input
con_label = layers.Input(shape=(1,))

# image generator input
latent_vector = layers.Input(shape=(100,))

def label_conditioned_gen(n_classes=10, embedding_dim=100):
    # embedding for categorical input
    label_embedding = layers.Embedding(n_classes, embedding_dim)(con_label)
    # linear multiplication
    n_nodes = 7 * 7
    label_dense = layers.Dense(n_nodes)(label_embedding)
    # reshape to additional channel
    label_reshape_layer = layers.Reshape((7, 7, 1))(label_dense)
    return label_reshape_layer

def latent_gen(latent_dim=100):
    # image generator input (uses the globally defined latent_vector Input)
    n_nodes = 128 * 7 * 7
    latent_dense = layers.Dense(n_nodes)(latent_vector)
    latent_dense = layers.LeakyReLU(alpha=0.2)(latent_dense)
    latent_reshape = layers.Reshape((7, 7, 128))(latent_dense)
    return latent_reshape

def con_generator():
    label_output = label_conditioned_gen()
    latent_vector_output = latent_gen()
    # merge the latent branch and the label branch
    merge = layers.Concatenate()([latent_vector_output, label_output])
    # upsample to 14x14
    x = layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(merge)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # upsample to 28x28
    x = layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # output
    out_layer = layers.Conv2D(1, (7, 7), activation='tanh', padding='same')(x)
    # define model
    model = tf.keras.Model([latent_vector, con_label], out_layer)
    return model

conditional_gen = con_generator()

conditional_gen.summary()
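# --- Optional sanity check (an illustrative sketch; these variables are not reused below) ---
# Feeds a small random batch through the freshly built generator to confirm that a
# (batch, 100) latent vector plus a (batch, 1) integer label yields (batch, 28, 28, 1)
# tanh-scaled images. The batch size of 4 and the example labels are arbitrary.
_sanity_noise = tf.random.normal([4, latent_dim])
_sanity_labels = tf.constant([[0], [1], [2], [3]], dtype=tf.int32)
_sanity_images = conditional_gen([_sanity_noise, _sanity_labels], training=False)
print("Generator sanity check, output shape:", _sanity_images.shape)  # expect (4, 28, 28, 1)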

def label_condition_disc(in_shape=(28, 28, 1), n_classes=10, embedding_dim=100):
    # label input
    con_label = layers.Input(shape=(1,))
    # embedding for categorical input
    label_embedding = layers.Embedding(n_classes, embedding_dim)(con_label)
    # scale up to image dimensions with linear activation
    nodes = in_shape[0] * in_shape[1] * in_shape[2]
    label_dense = layers.Dense(nodes)(label_embedding)
    # reshape to additional channel
    label_reshape_layer = layers.Reshape((in_shape[0], in_shape[1], 1))(label_dense)
    return con_label, label_reshape_layer


def image_disc(in_shape=(28, 28, 1)):
    # image input
    inp_image = layers.Input(shape=in_shape)
    return inp_image

def con_discriminator():
    con_label, label_condition_output = label_condition_disc()
    inp_image_output = image_disc()
    # concat label as a channel
    merge = layers.Concatenate()([inp_image_output, label_condition_output])
    # downsample
    x = layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same')(merge)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # downsample
    x = layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # flatten feature maps
    flattened_out = layers.Flatten()(x)
    # dropout
    dropout = layers.Dropout(0.4)(flattened_out)
    # output
    dense_out = layers.Dense(1, activation='sigmoid')(dropout)
    # define model
    model = tf.keras.Model([inp_image_output, con_label], dense_out)
    return model

conditional_discriminator = con_discriminator()

conditional_discriminator.summary()
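# --- Optional sanity check (an illustrative sketch; these variables are not reused below) ---
# Runs a random image batch with integer labels through the conditional discriminator.
# Each output is a sigmoid score in [0, 1], one per (image, label) pair.
_check_images = tf.random.normal([4, 28, 28, 1])
_check_labels = tf.constant([[0], [1], [2], [3]], dtype=tf.int32)
_check_scores = conditional_discriminator([_check_images, _check_labels], training=False)
print("Discriminator sanity check, output shape:", _check_scores.shape)  # expect (4, 1)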

binary_cross_entropy = tf.keras.losses.BinaryCrossentropy()

def generator_loss(label, fake_output):
    gen_loss = binary_cross_entropy(label, fake_output)
    #print(gen_loss)
    return gen_loss

def discriminator_loss(label, output):
    disc_loss = binary_cross_entropy(label, output)
    #print(total_loss)
    return disc_loss

learning_rate = 0.0002
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5, beta_2=0.999)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5, beta_2=0.999)

num_examples_to_generate = 25
latent_dim = 100
# We will reuse this seed over time to visualize progress
seed = tf.random.normal([num_examples_to_generate, latent_dim])

print(seed.dtype)

print(conditional_gen.input)

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images, target):
    # noise vector sampled from a normal distribution
    noise = tf.random.normal([target.shape[0], latent_dim])
    # Train the discriminator on real images
    with tf.GradientTape() as disc_tape1:
        generated_images = conditional_gen([noise, target], training=True)

        real_output = conditional_discriminator([images, target], training=True)
        real_targets = tf.ones_like(real_output)
        disc_loss1 = discriminator_loss(real_targets, real_output)

    # gradient calculation for the discriminator on real images
    gradients_of_disc1 = disc_tape1.gradient(disc_loss1, conditional_discriminator.trainable_variables)

    # parameter update for the discriminator on real images
    discriminator_optimizer.apply_gradients(zip(gradients_of_disc1,
                                                conditional_discriminator.trainable_variables))

    # Train the discriminator on fake (generated) images
    with tf.GradientTape() as disc_tape2:
        fake_output = conditional_discriminator([generated_images, target], training=True)
        fake_targets = tf.zeros_like(fake_output)
        disc_loss2 = discriminator_loss(fake_targets, fake_output)

    # gradient calculation for the discriminator on fake images
    gradients_of_disc2 = disc_tape2.gradient(disc_loss2, conditional_discriminator.trainable_variables)

    # parameter update for the discriminator on fake images
    discriminator_optimizer.apply_gradients(zip(gradients_of_disc2,
                                                conditional_discriminator.trainable_variables))

    # Train the generator with real targets (non-saturating loss)
    with tf.GradientTape() as gen_tape:
        generated_images = conditional_gen([noise, target], training=True)
        fake_output = conditional_discriminator([generated_images, target], training=True)
        real_targets = tf.ones_like(fake_output)
        gen_loss = generator_loss(real_targets, fake_output)

    # gradient calculation for the generator
    gradients_of_gen = gen_tape.gradient(gen_loss, conditional_gen.trainable_variables)

    # parameter update for the generator
    generator_optimizer.apply_gradients(zip(gradients_of_gen,
                                            conditional_gen.trainable_variables))

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()
        i = 0
        D_loss_list, G_loss_list = [], []
        for image_batch, target in dataset:
            i += 1
            train_step(image_batch, target)
        print(epoch)
        display.clear_output(wait=True)
        generate_and_save_images(conditional_gen,
                                 epoch + 1,
                                 seed)

        # # Save the model every 15 epochs
        # if (epoch + 1) % 15 == 0:
        #     checkpoint.save(file_prefix = checkpoint_prefix)

        conditional_gen.save_weights('fashion/training_weights/gen_' + str(epoch) + '.h5')
        conditional_discriminator.save_weights('fashion/training_weights/disc_' + str(epoch) + '.h5')
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(conditional_gen,
                             epochs,
                             seed)

def label_gen(n_classes=10):
    # draw one random class index and repeat it for all 25 preview images
    lab = tf.random.uniform((1,), minval=0, maxval=n_classes, dtype=tf.dtypes.int32, seed=None, name=None)
    return tf.repeat(lab, [25], axis=None, name=None)

# Create dictionary of target classes
label_dict = {
    0: 'T-shirt/top',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle boot',
}

def generate_and_save_images(model, epoch, test_input):
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    labels = label_gen()
    predictions = model([test_input, labels], training=False)
    print(predictions.shape)
    fig = plt.figure(figsize=(4, 4))

    print("Generated Images are Conditioned on Label:", label_dict[np.array(labels)[0]])
    for i in range(predictions.shape[0]):
        # rescale tanh output from [-1, 1] to [0, 255]
        pred = (predictions[i, :, :, 0] + 1) * 127.5
        pred = np.array(pred)
        plt.subplot(5, 5, i + 1)
        plt.imshow(pred.astype(np.uint8), cmap='gray')
        plt.axis('off')

    plt.savefig('fashion/images/image_at_epoch_{:d}.png'.format(epoch))
    plt.show()

# make sure the output directories used above exist
os.makedirs('fashion/training_weights', exist_ok=True)
os.makedirs('fashion/images', exist_ok=True)

train(train_dataset, 2)

conditional_gen.load_weights('fashion/training_weights/gen_1.h5')

# example of interpolating between generated images

fig = plt.figure(figsize=(10, 10))

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples, n_classes=10):
    # generate points in the latent space
    x_input = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    z_input = x_input.reshape(n_samples, latent_dim)
    return z_input

# uniform interpolation between two points in latent space
def interpolate_points(p1, p2, n_steps=10):
    # interpolate ratios between the points
    ratios = linspace(0, 1, num=n_steps)
    # linearly interpolate the two latent vectors
    vectors = list()
    for ratio in ratios:
        v = (1.0 - ratio) * p1 + ratio * p2
        vectors.append(v)
    return asarray(vectors)


# sample two random points in the latent space
pts = generate_latent_points(100, 2)
# interpolate points in latent space
interpolated = interpolate_points(pts[0], pts[1])
# generate images
from matplotlib import gridspec

output = None
for label in range(10):
    labels = tf.ones(10) * label
    predictions = conditional_gen([interpolated, labels], training=False)
    if output is None:
        output = predictions
    else:
        output = np.concatenate((output, predictions))
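# Quick check (illustrative only): interpolate_points() uses linspace(0, 1, 10), so it
# returns 10 latent vectors whose first and last rows equal the two sampled endpoints,
# and the 10x10 grid below walks smoothly from pts[0] to pts[1] for every class.
print("interpolated shape:", interpolated.shape)  # expect (10, 100)
print("endpoints preserved:",
      np.allclose(interpolated[0], pts[0]),
      np.allclose(interpolated[-1], pts[1]))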
k = 0
nrow = 10
ncol = 10
fig = plt.figure(figsize=(15, 15))
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                       wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845)

for i in range(10):
    for j in range(10):
        # take the single grayscale channel so imshow receives a 2-D array
        pred = (output[k, :, :, 0] + 1) * 127.5
        ax = plt.subplot(gs[i, j])
        pred = np.array(pred)
        ax.imshow(pred.astype(np.uint8), cmap='gray')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.axis('off')
        k += 1

plt.savefig('result.png', dpi=300)
plt.show()

print(pred.shape)
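# --- Optional follow-up (an added sketch; 'chosen_class' and the other names here are
# illustrative, not part of the pipeline above) --- generate a fresh 5x5 grid conditioned
# on a single class to eyeball the label conditioning, e.g. class 9 ('Ankle boot' in label_dict).
chosen_class = 9
demo_noise = tf.random.normal([25, latent_dim])
demo_labels = tf.fill([25, 1], chosen_class)
demo_images = conditional_gen([demo_noise, demo_labels], training=False)
plt.figure(figsize=(4, 4))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.imshow(((demo_images[i, :, :, 0] + 1) * 127.5).numpy().astype(np.uint8), cmap='gray')
    plt.axis('off')
plt.show()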