# GitHub Repository: hackassin/learnopencv
# Path: blob/master/Conditional-GAN-PyTorch-TensorFlow/TensorFlow/cgan_fashionmnist_tensorflow.py

import os
import glob
import time

import imageio
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from numpy import asarray, linspace
from numpy.random import randn

import tensorflow as tf
from tensorflow.keras import layers
from IPython import display

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
x_test = x_test.astype('float32')
# scale images to [-1, 1] to match the generator's tanh output
x_train = (x_train / 127.5) - 1
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(60000).batch(128)

# Preview the first 100 training images, rescaled from [-1, 1] back to [0, 255]
plt.figure(figsize=(10, 10))
for images, _ in train_dataset.take(1):
    for i in range(100):
        ax = plt.subplot(10, 10, i + 1)
        plt.imshow(((images[i, :, :, 0] + 1) * 127.5).numpy().astype("uint8"), cmap='gray')
        plt.axis("off")

BATCH_SIZE = 128
latent_dim = 100

# label input (shared Input tensor consumed by the generator's label branch)
con_label = layers.Input(shape=(1,))

# image generator input (shared latent Input tensor)
latent_vector = layers.Input(shape=(100,))

def label_conditioned_gen(n_classes=10, embedding_dim=100):
    # embedding for the categorical input (consumes the global con_label tensor)
    label_embedding = layers.Embedding(n_classes, embedding_dim)(con_label)
    # linear projection to 7*7 nodes
    n_nodes = 7 * 7
    label_dense = layers.Dense(n_nodes)(label_embedding)
    # reshape to an additional 7x7x1 channel
    label_reshape_layer = layers.Reshape((7, 7, 1))(label_dense)
    return label_reshape_layer

def latent_gen(latent_dim=100):
    # project the global latent_vector input into 7x7x128 feature maps
    n_nodes = 128 * 7 * 7
    latent_dense = layers.Dense(n_nodes)(latent_vector)
    latent_dense = layers.LeakyReLU(alpha=0.2)(latent_dense)
    latent_reshape = layers.Reshape((7, 7, 128))(latent_dense)
    return latent_reshape

def con_generator():
    label_output = label_conditioned_gen()
    latent_output = latent_gen()
    # merge the label condition channel and the latent feature maps: 7x7x129
    merge = layers.Concatenate()([label_output, latent_output])
    # upsample to 14x14
    x = layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(merge)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # upsample to 28x28
    x = layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # output: a single-channel 28x28 image in [-1, 1]
    out_layer = layers.Conv2D(1, (7, 7), activation='tanh', padding='same')(x)
    # define model
    model = tf.keras.Model([latent_vector, con_label], out_layer)
    return model

conditional_gen = con_generator()

conditional_gen.summary()

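# Optional sanity check (not part of the original script): one forward pass
# should map a (batch, 100) noise vector plus a (batch, 1) label to a
# (batch, 28, 28, 1) image with tanh-range values.
_noise = tf.random.normal([4, latent_dim])
_lbls = tf.constant([[0], [1], [2], [3]])
print(conditional_gen([_noise, _lbls], training=False).shape)  # (4, 28, 28, 1)
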
def label_condition_disc(in_shape=(28, 28, 1), n_classes=10, embedding_dim=100):
    # label input (local to the discriminator)
    con_label = layers.Input(shape=(1,))
    # embedding for categorical input
    label_embedding = layers.Embedding(n_classes, embedding_dim)(con_label)
    # scale up to image dimensions with linear activation
    nodes = in_shape[0] * in_shape[1] * in_shape[2]
    label_dense = layers.Dense(nodes)(label_embedding)
    # reshape to an additional 28x28x1 channel
    label_reshape_layer = layers.Reshape((in_shape[0], in_shape[1], 1))(label_dense)
    return con_label, label_reshape_layer

def image_disc(in_shape=(28, 28, 1)):
    # image input
    inp_image = layers.Input(shape=in_shape)
    return inp_image

def con_discriminator():
    con_label, label_condition_output = label_condition_disc()
    inp_image_output = image_disc()
    # concat label as an extra channel: 28x28x2
    merge = layers.Concatenate()([inp_image_output, label_condition_output])
    # downsample to 14x14
    x = layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same')(merge)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # downsample to 7x7
    x = layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same')(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # flatten feature maps
    flattened_out = layers.Flatten()(x)
    # dropout
    dropout = layers.Dropout(0.4)(flattened_out)
    # output: probability that the (image, label) pair is real
    dense_out = layers.Dense(1, activation='sigmoid')(dropout)
    # define model
    model = tf.keras.Model([inp_image_output, con_label], dense_out)
    return model

conditional_discriminator = con_discriminator()

conditional_discriminator.summary()

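# Optional sanity check (not part of the original script): the discriminator
# should map a (batch, 28, 28, 1) image plus a (batch, 1) label to a
# (batch, 1) real/fake probability.
_imgs = tf.random.normal([4, 28, 28, 1])
_lbls = tf.constant([[0], [1], [2], [3]])
print(conditional_discriminator([_imgs, _lbls], training=False).shape)  # (4, 1)
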
binary_cross_entropy = tf.keras.losses.BinaryCrossentropy()

def generator_loss(label, fake_output):
    # the generator wants the discriminator to score its fakes as real
    gen_loss = binary_cross_entropy(label, fake_output)
    return gen_loss

def discriminator_loss(label, output):
    disc_loss = binary_cross_entropy(label, output)
    return disc_loss

learning_rate = 0.0002
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5, beta_2=0.999)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5, beta_2=0.999)

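# The commented-out block in train() below calls checkpoint.save() but the
# original script never defines a checkpoint. A minimal tf.train.Checkpoint
# setup (an assumption, following the TF-2 DCGAN tutorial's layout):
checkpoint_prefix = os.path.join('fashion/training_checkpoints', 'ckpt')
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=conditional_gen,
                                 discriminator=conditional_discriminator)
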
num_examples_to_generate = 25
latent_dim = 100
# We will reuse this seed over time to visualize progress
seed = tf.random.normal([num_examples_to_generate, latent_dim])

print(seed.dtype)

print(conditional_gen.input)

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images, target):
    # noise vector sampled from a normal distribution
    noise = tf.random.normal([target.shape[0], latent_dim])
    # generate fake images conditioned on the real batch's class labels
    generated_images = conditional_gen([noise, target], training=True)

    # Train discriminator on real images with real targets
    with tf.GradientTape() as disc_tape1:
        real_output = conditional_discriminator([images, target], training=True)
        real_targets = tf.ones_like(real_output)
        disc_loss1 = discriminator_loss(real_targets, real_output)

    # gradient calculation for the discriminator on real images
    gradients_of_disc1 = disc_tape1.gradient(disc_loss1, conditional_discriminator.trainable_variables)

    # parameter update for the discriminator on real images
    discriminator_optimizer.apply_gradients(zip(gradients_of_disc1,
                                                conditional_discriminator.trainable_variables))

    # Train discriminator on fake images with fake targets
    with tf.GradientTape() as disc_tape2:
        fake_output = conditional_discriminator([generated_images, target], training=True)
        fake_targets = tf.zeros_like(fake_output)
        disc_loss2 = discriminator_loss(fake_targets, fake_output)

    # gradient calculation for the discriminator on fake images
    gradients_of_disc2 = disc_tape2.gradient(disc_loss2, conditional_discriminator.trainable_variables)

    # parameter update for the discriminator on fake images
    discriminator_optimizer.apply_gradients(zip(gradients_of_disc2,
                                                conditional_discriminator.trainable_variables))

    # Train generator: its fakes should be scored as real
    with tf.GradientTape() as gen_tape:
        generated_images = conditional_gen([noise, target], training=True)
        fake_output = conditional_discriminator([generated_images, target], training=True)
        real_targets = tf.ones_like(fake_output)
        gen_loss = generator_loss(real_targets, fake_output)

    # gradient calculation for the generator
    gradients_of_gen = gen_tape.gradient(gen_loss, conditional_gen.trainable_variables)

    # parameter update for the generator
    generator_optimizer.apply_gradients(zip(gradients_of_gen,
                                            conditional_gen.trainable_variables))

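# Not used by train() below; a more compact single-update variant in the
# style of the TF-2 DCGAN tutorial, shown for comparison: score real and fake
# batches under one tape and update each network once per step.
@tf.function
def train_step_combined(images, target):
    noise = tf.random.normal([target.shape[0], latent_dim])
    with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
        generated_images = conditional_gen([noise, target], training=True)
        real_output = conditional_discriminator([images, target], training=True)
        fake_output = conditional_discriminator([generated_images, target], training=True)
        disc_loss = (discriminator_loss(tf.ones_like(real_output), real_output) +
                     discriminator_loss(tf.zeros_like(fake_output), fake_output))
        gen_loss = generator_loss(tf.ones_like(fake_output), fake_output)
    discriminator_optimizer.apply_gradients(zip(
        disc_tape.gradient(disc_loss, conditional_discriminator.trainable_variables),
        conditional_discriminator.trainable_variables))
    generator_optimizer.apply_gradients(zip(
        gen_tape.gradient(gen_loss, conditional_gen.trainable_variables),
        conditional_gen.trainable_variables))
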
def train(dataset, epochs):
    # make sure the output directories exist before saving weights and images
    os.makedirs('fashion/training_weights', exist_ok=True)
    os.makedirs('fashion/images', exist_ok=True)
    for epoch in range(epochs):
        start = time.time()
        for image_batch, target in dataset:
            train_step(image_batch, target)

        display.clear_output(wait=True)
        generate_and_save_images(conditional_gen,
                                 epoch + 1,
                                 seed)

        # # Save the model every 15 epochs
        # if (epoch + 1) % 15 == 0:
        #     checkpoint.save(file_prefix=checkpoint_prefix)

        conditional_gen.save_weights('fashion/training_weights/gen_' + str(epoch) + '.h5')
        conditional_discriminator.save_weights('fashion/training_weights/disc_' + str(epoch) + '.h5')
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(conditional_gen,
                             epochs,
                             seed)

def label_gen(n_classes=10):
    # draw one random class and repeat it for all 25 preview images
    lab = tf.random.uniform((1,), minval=0, maxval=n_classes, dtype=tf.dtypes.int32)
    return tf.repeat(lab, [25], axis=None)

# Create dictionary of target classes
label_dict = {
    0: 'T-shirt/top',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle boot',
}

def generate_and_save_images(model, epoch, test_input):
    # `training` is set to False so all layers run in inference mode
    labels = label_gen()
    predictions = model([test_input, labels], training=False)
    print(predictions.shape)
    fig = plt.figure(figsize=(4, 4))

    print("Generated Images are Conditioned on Label:", label_dict[np.array(labels)[0]])
    for i in range(predictions.shape[0]):
        # rescale from [-1, 1] back to [0, 255]
        pred = (predictions[i, :, :, 0] + 1) * 127.5
        pred = np.array(pred)
        plt.subplot(5, 5, i + 1)
        plt.imshow(pred.astype(np.uint8), cmap='gray')
        plt.axis('off')

    plt.savefig('fashion/images/image_at_epoch_{:d}.png'.format(epoch))
    plt.show()

train(train_dataset, 2)

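# Not part of the original script: the unused imageio and glob imports suggest
# the usual follow-up from the TF-2 DCGAN tutorial, stitching the per-epoch
# PNGs saved above into an animated GIF of training progress.
anim_file = 'fashion/cgan_fashionmnist.gif'
with imageio.get_writer(anim_file, mode='I') as writer:
    # lexicographic sort is fine here since epoch numbers stay single-digit
    for filename in sorted(glob.glob('fashion/images/image_at_epoch_*.png')):
        writer.append_data(imageio.imread(filename))
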
# reload the generator weights saved after the last epoch (epoch index 1 of 2)
conditional_gen.load_weights('fashion/training_weights/gen_1.h5')

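# Optional sketch (not in the original script): with the trained weights
# loaded, condition the same noise vector on each of the 10 classes to see
# how the label input steers the generator.
_z = tf.random.normal([1, latent_dim])
plt.figure(figsize=(15, 2))
for c in range(10):
    img = conditional_gen([_z, tf.constant([[c]])], training=False)
    plt.subplot(1, 10, c + 1)
    plt.imshow(((np.array(img[0, :, :, 0]) + 1) * 127.5).astype(np.uint8), cmap='gray')
    plt.title(label_dict[c], fontsize=8)
    plt.axis('off')
plt.show()
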
# example of interpolating between generated images

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples, n_classes=10):
    # generate points in the latent space
    x_input = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    z_input = x_input.reshape(n_samples, latent_dim)
    return z_input

# uniform (linear) interpolation between two points in latent space
def interpolate_points(p1, p2, n_steps=10):
    # interpolation ratios between the two points
    ratios = linspace(0, 1, num=n_steps)
    # linearly interpolate the vectors
    vectors = list()
    for ratio in ratios:
        v = (1.0 - ratio) * p1 + ratio * p2
        vectors.append(v)
    return asarray(vectors)

# sample two latent points and interpolate between them
pts = generate_latent_points(100, 2)
# interpolate points in latent space
interpolated = interpolate_points(pts[0], pts[1])

# generate images for every class along the interpolation path
output = None
for label in range(10):
    # condition all 10 interpolation steps on the same class label
    labels = tf.ones(10, dtype=tf.int32) * label
    predictions = conditional_gen([interpolated, labels], training=False)
    if output is None:
        output = predictions
    else:
        output = np.concatenate((output, predictions))

k = 0
nrow = 10
ncol = 10
fig = plt.figure(figsize=(15, 15))
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                       wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845)

# plot the 10x10 grid: one row per class, one column per interpolation step
for i in range(10):
    for j in range(10):
        # drop the channel axis and rescale from [-1, 1] back to [0, 255]
        pred = (output[k, :, :, 0] + 1) * 127.5
        ax = plt.subplot(gs[i, j])
        pred = np.array(pred)
        ax.imshow(pred.astype(np.uint8), cmap='gray')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.axis('off')
        k += 1

plt.savefig('result.png', dpi=300)
plt.show()

print(pred.shape)