Path: blob/master/site/es-419/guide/migrate/checkpoint_saver.ipynb
25118 views
Copyright 2021 The TensorFlow Authors.
Migrar el almacenamiento del punto de verificación
Guardar continuamente el "mejor" modelo o los pesos/parámetros del modelo tiene muchas ventajas. Por ejemplo, permite realizar un seguimiento del progreso del entrenamiento y cargar modelos guardados a partir de diferentes estados guardados.
En TensorFlow 1, para configurar guardar puntos de verificación durante el entrenamiento/validación con las APIs tf.estimator.Estimator
se especifica una programación en tf.estimator.RunConfig
o se utiliza tf.estimator.CheckpointSaverHook
. Esta guía muestra cómo migrar de este flujo de trabajo a las APIs de Keras de TensorFlow 2.
En TensorFlow 2, puede configurar tf.keras.callbacks.ModelCheckpoint
de varias maneras:
Guarda la "mejor" versión según una métrica monitorizada mediante el parámetro
save_best_only=True
, dondemonitor
puede ser, por ejemplo,'loss'
,'val_loss'
,'accuracy', or
'val_accuracy'`.Guardar continuamente a una frecuencia determinada (utilizando el argumento
save_freq
).Guarde solo los pesos/parámetros en vez de todo el modelo definiendo
save_weights_only
comoTrue
.
Para obtener más información, consulte los documentos de la API tf.keras.callbacks.ModelCheckpoint
y la sección Guardar puntos de verificación durante el entrenamiento del tutorial Guardar y cargar modelos. Aprenda más sobre el formato de puntos de verificación en la sección Formato de puntos de verificación de la guía Guardar y cargar modelos de Keras. Además, para agregar tolerancia ante errores, puede utilizar tf.keras.callbacks.BackupAndRestore
o tf.train.Checkpoint
para el punto de verificación manual. Obtenga más información en la Guía de migración de tolerancia ante errores.
Las retrollamadas de Keras son objetos que se llaman en diferentes puntos durante el entrenamiento/evaluación/predicción en las APIs incorporadas de Keras Model.fit
/Model.evaluate
/Model.predict
de las API. Obtenga más información en la sección Siguientes pasos al final de la guía.
Preparación
Empiece con imports y un conjunto de datos sencillo a modo de demostración:
TensorFlow 1: Guardar puntos de verificación con el tf.estimator de las API
Este ejemplo de TensorFlow 1 muestra cómo configurar tf.estimator.RunConfig
para guardar puntos de verificación en cada paso durante el entrenamiento/evaluación con las tf.estimator.Estimator
de las API:
TensorFlow 2: Guardar puntos de verificación con una retrollamada de Keras para Model.fit
En TensorFlow 2, cuando utilice el método incorporado de Keras Model.fit
(o Model.evaluate
) para entrenamiento/evaluación, puede configurar tf.keras.callbacks.ModelCheckpoint
y luego pasarlo al parámetro callbacks
de Model.fit
(o Model.evaluate
). (Obtenga más información en los documentos de la API y en la sección Uso de retrollamadas de la guía Entrenamiento y evaluación con los métodos incorporados).
En el siguiente ejemplo, se utilizará una retrollamada tf.keras.callbacks.ModelCheckpoint
para almacenar los puntos de verificación en un directorio temporal:
Siguientes pasos
Obtenga más información sobre los puntos de verificación en:
Documentos de la API:
tf.keras.callbacks.ModelCheckpoint
Tutorial: Guardar y cargar modelos (la sección Guardar puntos de verificación durante el entrenamiento)
Guía: Guardar y cargar modelos de Keras (la sección Formato de punto de verificación TF)
Obtenga más información sobre las retrollamadas en:
Documentos de la API:
tf.keras.callbacks.Callback
Guía: Entrenamiento y evaluación con los métodos incorporados (la sección Usar retrollamadas)
También le pueden resultar útiles los siguientes recursos relacionados con la migración:
La Guía de migración de tolerancia ante errores:
tf.keras.callbacks.BackupAndRestore
paraModel.fit
, otf.train.Checkpoint
ytf.train.CheckpointManager
de las API para un bucle de entrenamiento personalizadoLa Guía de migración de parada anticipada:
tf.keras.callbacks.EarlyStopping
es una retrollamada de parada anticipada incorporada.La Guía de migración de TensorBoard: TensorBoard permite el seguimiento y la visualización de métricas
La Guía de migración de retrollamadas LoggingTensorHook y StopAtStepHook a Keras