Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/es-419/guide/checkpoint.ipynb
25115 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

Entrenar puntos de verificación

La frase "Guardar un modelo de TensorFlow" normalmente significa una de dos cosas:

  1. Puntos de verificación, O

  2. SavedModel.

Los puntos de verificación capturan el valor exacto de todos los parámetros (tf.objetos tf.Variable) que usa un modelo. Los puntos de verificación no contienen ninguna descripción del cálculo definido por el modelo y, por lo tanto, solo suelen ser útiles cuando el código fuente que usará los valores de los parámetros guardados está disponible.

Por otro lado, el formato SavedModel incluye una descripción serializada del cálculo definido por el modelo además de los valores de los parámetros (punto de verificación). Los modelos en este formato son independientes del código fuente que creó el modelo. Por lo tanto, son adecuados para implementarse a través de TensorFlow Serving, TensorFlow Lite, TensorFlow.js o mediante programas en otros lenguajes de programación (las API de TensorFlow C, C++, Java, Go, Rust, C#, etc.).

Esta guía cubre las API para escribir y leer puntos de verificación.

Preparación

import tensorflow as tf
class Net(tf.keras.Model): """A simple linear model.""" def __init__(self): super(Net, self).__init__() self.l1 = tf.keras.layers.Dense(5) def call(self, x): return self.l1(x)
net = Net()

Guardar desde las API de entrenamiento tf.keras

Consulte la guía de tf.keras sobre cómo guardar y restaurar.

tf.keras.Model.save_weights guarda un punto de verificación de TensorFlow.

net.save_weights('easy_checkpoint')

Escribir puntos de verificación

El estado persistente de un modelo de TensorFlow se almacena en objetos tf.Variable. Estos se pueden construir directamente, pero a menudo se crean a través de una API de alto nivel como tf.keras.layers o tf.keras.Model.

La forma más fácil de gestionar variables es adjuntarlas a objetos de Python y luego hacer referencia a esos objetos.

Las subclases de tf.train.Checkpoint, tf.keras.layers.Layer y tf.keras.Model trazan automáticamente las variables asignadas a sus atributos. En el siguiente ejemplo, se construye un modelo lineal simple y luego se escriben puntos de verificación que contienen valores para todas las variables del modelo.

Puede guardar un punto de verificación del modelo fácilmente con Model.save_weights.

Puntos de verificación manuales

Preparación

Para ayudar a demostrar todas las características de tf.train.Checkpoint, defina un conjunto de datos de juguete y un paso de optimización:

def toy_dataset(): inputs = tf.range(10.)[:, None] labels = inputs * 5. + tf.range(5.)[None, :] return tf.data.Dataset.from_tensor_slices( dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer): """Trains `net` on `example` using `optimizer`.""" with tf.GradientTape() as tape: output = net(example['x']) loss = tf.reduce_mean(tf.abs(output - example['y'])) variables = net.trainable_variables gradients = tape.gradient(loss, variables) optimizer.apply_gradients(zip(gradients, variables)) return loss

Crear los objetos del punto de verificación

Use un objeto tf.train.Checkpoint para crear un punto de verificación de forma manual, donde los objetos de los que se quiera guardar un punto de verificación se establezcan como atributos en el objeto.

Un tf.train.CheckpointManager también puede resultar útil para gestionar varios puntos de verificación.

opt = tf.keras.optimizers.Adam(0.1) dataset = toy_dataset() iterator = iter(dataset) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator) manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Entrenar y guardar puntos de verificación del modelo

En el siguiente ciclo de entrenamiento, se crea una instancia del modelo y de un optimizador, luego se reúnen en un objeto tf.train.Checkpoint. Se llama al paso de entrenamiento en un bucle en cada lote de datos y se escriben puntos de verificación en el disco de manera periódica.

def train_and_checkpoint(net, manager): ckpt.restore(manager.latest_checkpoint) if manager.latest_checkpoint: print("Restored from {}".format(manager.latest_checkpoint)) else: print("Initializing from scratch.") for _ in range(50): example = next(iterator) loss = train_step(net, example, opt) ckpt.step.assign_add(1) if int(ckpt.step) % 10 == 0: save_path = manager.save() print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path)) print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)

Recuperar y continuar el entrenamiento

Después del primer ciclo de entrenamiento, se puede aprobar un modelo y gestor nuevos, pero se puede continuar con el entrenamiento exactamente desde donde se dejó:

opt = tf.keras.optimizers.Adam(0.1) net = Net() dataset = toy_dataset() iterator = iter(dataset) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator) manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3) train_and_checkpoint(net, manager)

El objeto tf.train.CheckpointManager elimina puntos de verificación antiguos. En el ejemplo anterior, está configurado para conservar solo los tres puntos de verificación más recientes.

print(manager.checkpoints) # List the three remaining checkpoints

Estas rutas, por ejemplo, './tf_ckpts/ckpt-10', no son archivos en el disco. En realidad, son prefijos para un archivo index y uno o más archivos de datos que contienen los valores de las variables. Estos prefijos se agrupan en un único archivo checkpoint ('./tf_ckpts/checkpoint') donde CheckpointManager guarda su estado.

!ls ./tf_ckpts

Mecánica de carga

TensorFlow une a las variables con los valores de los que se guardaron los puntos de verificación al recorrer un gráfico dirigido con bordes con nombre, comenzando desde el objeto que se está cargando. Los nombres de los bordes normalmente provienen de nombres de atributos en objetos, por ejemplo "l1" en self.l1 = tf.keras.layers.Dense(5). tf.train.Checkpoint usa los nombres de sus argumentos de palabras clave, como en el "step" en tf.train.Checkpoint(step=...).

El gráfico de dependencia del ejemplo anterior se ve así:

Visualización del gráfico de dependencia para el bucle de entrenamiento de ejemplo.

El optimizador es rojo, las variables regulares son azules y las variables de ranura del optimizador son naranja. Los otros nodos (por ejemplo, los que representan tf.train.Checkpoint) están en negro.

Las variables de ranura son parte del estado del optimizador, pero se crean para una variable específica. Por ejemplo, los bordes 'm' anteriores corresponden al impulso, que el optimizador Adam rastrea para cada variable. Las variables de ranura solo se guardan en un punto de verificación solo si se guardan la variable y el optimizador, por eso se muestran con bordes discontinuos.

Cuando se llama restore en un objeto tf.train.Checkpoint se ponen en cola las restauraciones solicitadas y se restauran los valores de las variables tan pronto como haya una ruta de unión desde el objeto Checkpoint. Por ejemplo, solo se puede cargar el sesgo del modelo que se definió anteriormente al reconstruir una ruta hacia él a través de la red y la capa.

to_restore = tf.Variable(tf.zeros([5])) print(to_restore.numpy()) # All zeros fake_layer = tf.train.Checkpoint(bias=to_restore) fake_net = tf.train.Checkpoint(l1=fake_layer) new_root = tf.train.Checkpoint(net=fake_net) status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/')) print(to_restore.numpy()) # This gets the restored value.

El gráfico de dependencia para estos objetos nuevos es un subgráfico mucho más pequeño del punto de verificación más grande que se escribió anteriormente. Incluye solo el sesgo y un contador de guardado que tf.train.Checkpoint usa para enumerar los puntos de verificación.

Visualización de un subgráfico para la variable de sesgo.

restore devuelve un objeto de estado, que tiene aserciones opcionales. Todos los objetos creados en el nuevo Checkpoint han sido restauradas, por eso status.assert_existing_objects_matched pasa.

status.assert_existing_objects_matched()

Hay muchos objetos en el punto de verificación que no se corresponden, incluido el núcleo de la capa y las variables del optimizador. status.assert_consumed solo pasa si el punto de verificación y el programa se corresponden exactamente, y se generaría una excepción en este caso.

Restauraciones diferidas

Los objetos Layer en TensorFlow pueden diferir la creación de variables hasta su primera llamada, cuando las formas de entrada están disponibles. Por ejemplo, la forma del núcleo de una capa Dense depende tanto de las formas de entrada como de la salida de la capa,. Por lo tanto, la forma de salida requerida como argumento del constructor no es información suficiente para crear la variable por sí sola. Dado que al llamar a una Layer también se lee el valor de la variable, se debe realizar una restauración entre la creación de la variable y su primer uso.

Para admitir este modismo, tf.train.Checkpoint pospone las restauraciones que aún no tienen una variable correspondiente.

deferred_restore = tf.Variable(tf.zeros([1, 5])) print(deferred_restore.numpy()) # Not restored; still zeros fake_layer.kernel = deferred_restore print(deferred_restore.numpy()) # Restored

Inspeccionar manualmente los puntos de verificación

tf.train.load_checkpoint devuelve un CheckpointReader que brinda acceso de nivel inferior al contenido del punto de verificación. Contiene asignaciones de la clave de cada variable, la forma y el dtype de cada variable en el punto de verificación. La clave de una variable es la ruta de su objeto, como en los gráficos anteriores.

Nota: No existe una estructura de nivel más superior que el punto de verificación. Solo conoce las rutas y los valores de las variables y no entiende de models, layers ni cómo están conectados.

reader = tf.train.load_checkpoint('./tf_ckpts/') shape_from_key = reader.get_variable_to_shape_map() dtype_from_key = reader.get_variable_to_dtype_map() sorted(shape_from_key.keys())

Entonces, si quiere saber el valor de net.l1.kernel, se puede obtener con el siguiente código:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE' print("Shape:", shape_from_key[key]) print("Dtype:", dtype_from_key[key].name)

También proporciona un método get_tensor que le permite inspeccionar el valor de una variable:

reader.get_tensor(key)

Seguimiento de objetos

Los puntos de verificación guardan y restauran los valores de los objetos tf.Variable al "hacer el seguimiento" de cualquier variable u objeto rastreable establecido en uno de sus atributos. Al ejecutar un guardado, las variables se recopilan de forma recursiva de todos los objetos rastreados accesibles.

Al igual que con las asignaciones directas de atributos como self.l1 = tf.keras.layers.Dense(5), la asignación de listas y diccionarios a atributos hará el seguimiento de su contenido.

save = tf.train.Checkpoint() save.listed = [tf.Variable(1.)] save.listed.append(tf.Variable(2.)) save.mapped = {'one': save.listed[0]} save.mapped['two'] = save.listed[1] save_path = save.save('./tf_list_example') restore = tf.train.Checkpoint() v2 = tf.Variable(0.) assert 0. == v2.numpy() # Not restored yet restore.mapped = {'two': v2} restore.restore(save_path) assert 2. == v2.numpy()

Es posible que observe objetos empaquetadores de listas y diccionarios. Estos empaquetadores son versiones que pueden giardarse como puntos de verificación de las estructuras de datos subyacentes. Al igual que la carga basada en atributos, estos empaquetadores restauran el valor de una variable tan pronto como se agrega al contenedor.

restore.listed = [] print(restore.listed) # ListWrapper([]) v1 = tf.Variable(0.) restore.listed.append(v1) # Restores v1, from restore() in the previous cell assert 1. == v1.numpy()

Los objetos rastreables incluyen tf.train.Checkpoint, tf.Module y sus subclases (por ejemplo keras.layers.Layer y keras.Model) y contenedores de Python reconocidos:

  • dict (y collections.OrderedDict)

  • list

  • tuple (y collections.namedtuple, typing.NamedTuple)

Entre los tipos de contenedores que no se admiten están:

  • collections.defaultdict

  • set

Todos los demás objetos de Python se ignoran, entre ellos:

  • int

  • string

  • float

Resumen

Los objetos de TensorFlow proporcionan un mecanismo automático simple para guardar y restaurar los valores de las variables que usan.