Path: blob/master/site/es-419/guide/checkpoint.ipynb
25115 views
Copyright 2018 The TensorFlow Authors.
Entrenar puntos de verificación
La frase "Guardar un modelo de TensorFlow" normalmente significa una de dos cosas:
Puntos de verificación, O
SavedModel.
Los puntos de verificación capturan el valor exacto de todos los parámetros (tf.objetos tf.Variable
) que usa un modelo. Los puntos de verificación no contienen ninguna descripción del cálculo definido por el modelo y, por lo tanto, solo suelen ser útiles cuando el código fuente que usará los valores de los parámetros guardados está disponible.
Por otro lado, el formato SavedModel incluye una descripción serializada del cálculo definido por el modelo además de los valores de los parámetros (punto de verificación). Los modelos en este formato son independientes del código fuente que creó el modelo. Por lo tanto, son adecuados para implementarse a través de TensorFlow Serving, TensorFlow Lite, TensorFlow.js o mediante programas en otros lenguajes de programación (las API de TensorFlow C, C++, Java, Go, Rust, C#, etc.).
Esta guía cubre las API para escribir y leer puntos de verificación.
Preparación
Guardar desde las API de entrenamiento tf.keras
Consulte la guía de tf.keras
sobre cómo guardar y restaurar.
tf.keras.Model.save_weights
guarda un punto de verificación de TensorFlow.
Escribir puntos de verificación
El estado persistente de un modelo de TensorFlow se almacena en objetos tf.Variable
. Estos se pueden construir directamente, pero a menudo se crean a través de una API de alto nivel como tf.keras.layers
o tf.keras.Model
.
La forma más fácil de gestionar variables es adjuntarlas a objetos de Python y luego hacer referencia a esos objetos.
Las subclases de tf.train.Checkpoint
, tf.keras.layers.Layer
y tf.keras.Model
trazan automáticamente las variables asignadas a sus atributos. En el siguiente ejemplo, se construye un modelo lineal simple y luego se escriben puntos de verificación que contienen valores para todas las variables del modelo.
Puede guardar un punto de verificación del modelo fácilmente con Model.save_weights
.
Puntos de verificación manuales
Preparación
Para ayudar a demostrar todas las características de tf.train.Checkpoint
, defina un conjunto de datos de juguete y un paso de optimización:
Crear los objetos del punto de verificación
Use un objeto tf.train.Checkpoint
para crear un punto de verificación de forma manual, donde los objetos de los que se quiera guardar un punto de verificación se establezcan como atributos en el objeto.
Un tf.train.CheckpointManager
también puede resultar útil para gestionar varios puntos de verificación.
Entrenar y guardar puntos de verificación del modelo
En el siguiente ciclo de entrenamiento, se crea una instancia del modelo y de un optimizador, luego se reúnen en un objeto tf.train.Checkpoint
. Se llama al paso de entrenamiento en un bucle en cada lote de datos y se escriben puntos de verificación en el disco de manera periódica.
Recuperar y continuar el entrenamiento
Después del primer ciclo de entrenamiento, se puede aprobar un modelo y gestor nuevos, pero se puede continuar con el entrenamiento exactamente desde donde se dejó:
El objeto tf.train.CheckpointManager
elimina puntos de verificación antiguos. En el ejemplo anterior, está configurado para conservar solo los tres puntos de verificación más recientes.
Estas rutas, por ejemplo, './tf_ckpts/ckpt-10'
, no son archivos en el disco. En realidad, son prefijos para un archivo index
y uno o más archivos de datos que contienen los valores de las variables. Estos prefijos se agrupan en un único archivo checkpoint
('./tf_ckpts/checkpoint'
) donde CheckpointManager
guarda su estado.
Mecánica de carga
TensorFlow une a las variables con los valores de los que se guardaron los puntos de verificación al recorrer un gráfico dirigido con bordes con nombre, comenzando desde el objeto que se está cargando. Los nombres de los bordes normalmente provienen de nombres de atributos en objetos, por ejemplo "l1"
en self.l1 = tf.keras.layers.Dense(5)
. tf.train.Checkpoint
usa los nombres de sus argumentos de palabras clave, como en el "step"
en tf.train.Checkpoint(step=...)
.
El gráfico de dependencia del ejemplo anterior se ve así:
El optimizador es rojo, las variables regulares son azules y las variables de ranura del optimizador son naranja. Los otros nodos (por ejemplo, los que representan tf.train.Checkpoint
) están en negro.
Las variables de ranura son parte del estado del optimizador, pero se crean para una variable específica. Por ejemplo, los bordes 'm'
anteriores corresponden al impulso, que el optimizador Adam rastrea para cada variable. Las variables de ranura solo se guardan en un punto de verificación solo si se guardan la variable y el optimizador, por eso se muestran con bordes discontinuos.
Cuando se llama restore
en un objeto tf.train.Checkpoint
se ponen en cola las restauraciones solicitadas y se restauran los valores de las variables tan pronto como haya una ruta de unión desde el objeto Checkpoint
. Por ejemplo, solo se puede cargar el sesgo del modelo que se definió anteriormente al reconstruir una ruta hacia él a través de la red y la capa.
El gráfico de dependencia para estos objetos nuevos es un subgráfico mucho más pequeño del punto de verificación más grande que se escribió anteriormente. Incluye solo el sesgo y un contador de guardado que tf.train.Checkpoint
usa para enumerar los puntos de verificación.
restore
devuelve un objeto de estado, que tiene aserciones opcionales. Todos los objetos creados en el nuevo Checkpoint
han sido restauradas, por eso status.assert_existing_objects_matched
pasa.
Hay muchos objetos en el punto de verificación que no se corresponden, incluido el núcleo de la capa y las variables del optimizador. status.assert_consumed
solo pasa si el punto de verificación y el programa se corresponden exactamente, y se generaría una excepción en este caso.
Restauraciones diferidas
Los objetos Layer
en TensorFlow pueden diferir la creación de variables hasta su primera llamada, cuando las formas de entrada están disponibles. Por ejemplo, la forma del núcleo de una capa Dense
depende tanto de las formas de entrada como de la salida de la capa,. Por lo tanto, la forma de salida requerida como argumento del constructor no es información suficiente para crear la variable por sí sola. Dado que al llamar a una Layer
también se lee el valor de la variable, se debe realizar una restauración entre la creación de la variable y su primer uso.
Para admitir este modismo, tf.train.Checkpoint
pospone las restauraciones que aún no tienen una variable correspondiente.
Inspeccionar manualmente los puntos de verificación
tf.train.load_checkpoint
devuelve un CheckpointReader
que brinda acceso de nivel inferior al contenido del punto de verificación. Contiene asignaciones de la clave de cada variable, la forma y el dtype de cada variable en el punto de verificación. La clave de una variable es la ruta de su objeto, como en los gráficos anteriores.
Nota: No existe una estructura de nivel más superior que el punto de verificación. Solo conoce las rutas y los valores de las variables y no entiende de models
, layers
ni cómo están conectados.
Entonces, si quiere saber el valor de net.l1.kernel
, se puede obtener con el siguiente código:
También proporciona un método get_tensor
que le permite inspeccionar el valor de una variable:
Seguimiento de objetos
Los puntos de verificación guardan y restauran los valores de los objetos tf.Variable
al "hacer el seguimiento" de cualquier variable u objeto rastreable establecido en uno de sus atributos. Al ejecutar un guardado, las variables se recopilan de forma recursiva de todos los objetos rastreados accesibles.
Al igual que con las asignaciones directas de atributos como self.l1 = tf.keras.layers.Dense(5)
, la asignación de listas y diccionarios a atributos hará el seguimiento de su contenido.
Es posible que observe objetos empaquetadores de listas y diccionarios. Estos empaquetadores son versiones que pueden giardarse como puntos de verificación de las estructuras de datos subyacentes. Al igual que la carga basada en atributos, estos empaquetadores restauran el valor de una variable tan pronto como se agrega al contenedor.
Los objetos rastreables incluyen tf.train.Checkpoint
, tf.Module
y sus subclases (por ejemplo keras.layers.Layer
y keras.Model
) y contenedores de Python reconocidos:
dict
(ycollections.OrderedDict
)list
tuple
(ycollections.namedtuple
,typing.NamedTuple
)
Entre los tipos de contenedores que no se admiten están:
collections.defaultdict
set
Todos los demás objetos de Python se ignoran, entre ellos:
int
string
float
Resumen
Los objetos de TensorFlow proporcionan un mecanismo automático simple para guardar y restaurar los valores de las variables que usan.