Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/es-419/tutorials/generative/dcgan.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

Redes generativas adversarias convolucionales profundas

En este tutorial se enseña como generar imágenes de cifras escritas a mano con la red generativa adversaria convolucional profunda (DCGAN, por sus siglas en inglés). El código se escribe con la API secuencial de Keras con un bucle de entrenamiento tf.GradientTape.

¿Qué son las redes GAN?

Actualmente, las redes generativas adversarias (GAN, por sus siglas en inglés) son una de las ideas más interesantes en ciencias de la computación. Se entrenan dos modelos en simultáneo mediante un proceso adversario. Un generador ("el artista") aprende a crear imágenes que parecen reales, mientras que un discriminador ("el crítico de arte") aprende a diferenciar entre las imágenes reales y las falsas.

A diagram of a generator and discriminator

Durante el entrenamiento, el generador es cada vez más bueno en crear imágenes que parecen reales, mientras que el discriminador es cada vez más bueno en notar la diferencia. El proceso encuentra un equilibrio cuando el discriminador ya no puede diferenciar entre las imágenes reales y las falsas.

A second diagram of a generator and discriminator

En este cuaderno, se explica este proceso en un conjunto de datos MNIST. La siguiente animación muestra una serie de imágenes producidas por el generador durante un entrenamiento para 50 épocas. Al principio, las imágenes son ruido aleatorio y, con el tiempo, comienzan a lucir más como cifras escritas a mano.

sample output

Para más información sobre las redes GAN, vea el curso de introducción de aprendizaje profundo de MIT.

Preparación

import tensorflow as tf
tf.__version__
# To generate GIFs !pip install imageio !pip install git+https://github.com/tensorflow/docs
import glob import imageio import matplotlib.pyplot as plt import numpy as np import os import PIL from tensorflow.keras import layers import time from IPython import display

Cargar y preparar el conjunto de datos

Usará el conjunto de datos para entrenar el generador y el discriminador. El generador generará cifras escritas a mano que se parezcan a los datos de MNIST.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000 BATCH_SIZE = 256
# Batch and shuffle the data train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Crear los modelos

El generador y el discriminador se definen mediante el uso de la API secuencial de Keras.

El generador

El generador usa las capas tf.keras.layers.Conv2DTranspose (aumento de la resolución) para producir una imagen a partir de un valor de inicialización (ruido aleatorio). Comience con una capa Dense que tome este valor de inicialización como entrada, luego aumente la resolución varias veces hasta alcanzar el tamaño de imagen deseado de 28x28x1. Note la activación de tf.keras.layers.LeakyReLU para cada capa, excepto en la capa externa que usa tanh.

def make_generator_model(): model = tf.keras.Sequential() model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU()) model.add(layers.Reshape((7, 7, 256))) assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) assert model.output_shape == (None, 7, 7, 128) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU()) model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) assert model.output_shape == (None, 14, 14, 64) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU()) model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) assert model.output_shape == (None, 28, 28, 1) return model

Use el generador (sin entrenamiento aún) para crear una imagen.

generator = make_generator_model() noise = tf.random.normal([1, 100]) generated_image = generator(noise, training=False) plt.imshow(generated_image[0, :, :, 0], cmap='gray')

El discriminador

El discriminador es un clasificador de imágenes basadas en la red neuronal convolucional (CNN, por sus siglas en inglés).

def make_discriminator_model(): model = tf.keras.Sequential() model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1])) model.add(layers.LeakyReLU()) model.add(layers.Dropout(0.3)) model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) model.add(layers.LeakyReLU()) model.add(layers.Dropout(0.3)) model.add(layers.Flatten()) model.add(layers.Dense(1)) return model

Use el discriminador (sin entrenar aún) para clasificar las imágenes generadas en reales o falsas. El modelo será entrenado para que tenga como salida valores positivos para las imágenes reales y valores negativos para las imágenes falsas.

discriminator = make_discriminator_model() decision = discriminator(generated_image) print (decision)

Definir la pérdida y los optimizadores

Defina las funciones de pérdida y los optimizadores para ambos modelos.

# This method returns a helper function to compute cross entropy loss cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Pérdida del discriminador

Este método cuantifica qué tan bueno es el discriminador para distinguir entre imágenes reales y falsas. Compara las predicciones del discriminador sobre las imágenes reales en un arreglo de números 1 y las predicciones del discriminador sobre las imágenes falsas (generadas) en un arreglo de números 0.

def discriminator_loss(real_output, fake_output): real_loss = cross_entropy(tf.ones_like(real_output), real_output) fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) total_loss = real_loss + fake_loss return total_loss

Pérdida del generador

La pérdida del generador cuantifica qué tan bueno es para engañar al discriminador. De forma intuitiva, si el generador funciona bien, el discriminador clasificará a las imágenes reales en falsas (o con el número 1). Ahora, compare las decisiones del discriminador sobre las imágenes generadas en un arreglo de números 1.

def generator_loss(fake_output): return cross_entropy(tf.ones_like(fake_output), fake_output)

Los optimizadores del discriminador y del generador son diferentes ya que se entrenarán dos redes por separado.

generator_optimizer = tf.keras.optimizers.Adam(1e-4) discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Guardar puntos de verificación

En este cuaderno también se enseñará cómo guardar y recuperar modelos, lo cual es útil en caso de que se interrumpa una tarea de ejecución larga.

checkpoint_dir = './training_checkpoints' checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, discriminator_optimizer=discriminator_optimizer, generator=generator, discriminator=discriminator)

Definir el bucle de entrenamiento

EPOCHS = 50 noise_dim = 100 num_examples_to_generate = 16 # You will reuse this seed overtime (so it's easier) # to visualize progress in the animated GIF) seed = tf.random.normal([num_examples_to_generate, noise_dim])

El bucle de entrenamiento comienza cuando el generador recibe un valor de inicialización aleatorio como entrada. Ese valor se usa para producir una imagen. Luego, se usa el discriminador para clasificar imágenes reales (extraídas del conjunto de datos de entrenamiento) e imágenes falsas (producidas por el generador). Se calcula la pérdida para cada uno de estos modelos y se usan los gradientes para actualizar el generador y el discriminador.

# Notice the use of `tf.function` # This annotation causes the function to be "compiled". @tf.function def train_step(images): noise = tf.random.normal([BATCH_SIZE, noise_dim]) with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: generated_images = generator(noise, training=True) real_output = discriminator(images, training=True) fake_output = discriminator(generated_images, training=True) gen_loss = generator_loss(fake_output) disc_loss = discriminator_loss(real_output, fake_output) gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs): for epoch in range(epochs): start = time.time() for image_batch in dataset: train_step(image_batch) # Produce images for the GIF as you go display.clear_output(wait=True) generate_and_save_images(generator, epoch + 1, seed) # Save the model every 15 epochs if (epoch + 1) % 15 == 0: checkpoint.save(file_prefix = checkpoint_prefix) print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) # Generate after the final epoch display.clear_output(wait=True) generate_and_save_images(generator, epochs, seed)

Generar y guardar imágenes

def generate_and_save_images(model, epoch, test_input): # Notice `training` is set to False. # This is so all layers run in inference mode (batchnorm). predictions = model(test_input, training=False) fig = plt.figure(figsize=(4, 4)) for i in range(predictions.shape[0]): plt.subplot(4, 4, i+1) plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray') plt.axis('off') plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) plt.show()

Entrenar el modelo

Llame al método train() que se definió anteriormente para entrenar el generador y el discriminador de forma simultánea. Nota: entrenar las redes GAN puede ser complicado. Es importante que el generador y el discriminador no se superpongan (por ejemplo, que sean entrenados al mismo ritmo).

Al principio del entrenamiento, las imágenes generadas lucirán como ruido aleatorio. A medida que el entrenamiento avanza, las cifras generadas se verán cada vez más reales. Luego de 50 épocas, se verán como las cifras del MNIST. Esto puede tardar un minuto/época con la configuración predeterminada en Colab.

train(train_dataset, EPOCHS)

Restaurar el último punto de verificación

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

Crear un GIF

# Display a single image using the epoch number def display_image(epoch_no): return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

Use imageio para crear un gif animado para usar imágenes que se guardaron durante el entrenamiento.

anim_file = 'dcgan.gif' with imageio.get_writer(anim_file, mode='I') as writer: filenames = glob.glob('image*.png') filenames = sorted(filenames) for filename in filenames: image = imageio.imread(filename) writer.append_data(image) image = imageio.imread(filename) writer.append_data(image)
import tensorflow_docs.vis.embed as embed embed.embed_file(anim_file)

Próximos pasos

En este tutorial se mostró el código completo necesario para escribir y entrenar una red GAN. El siguiente paso que tome puede ser experimentar con diferentes conjuntos de datos, por ejemplo, el conjunto de datos de atributos de las caras de celebridades a gran escala (CelebA) disponible en Kaggle. Para más información sobre las redes GAN, vea el Tutorial de NIPS 2016: Redes generativas adversarias.