Copyright 2023 The TensorFlow Authors.
Cómo importar un modelo JAX mediante JAX2TF
Este bloc de notas proporciona un ejemplo completo y ejecutable para crear un modelo usando JAX y trasladarlo a TensorFlow con el fin de continuar el entrenamiento. Esto es posible gracias a JAX2TF, una API ligera que proporciona un medio para pasar del ecosistema JAX al ecosistema de TensorFlow.
JAX es una biblioteca de computación de arreglos de alto rendimiento. Para crear el modelo, este bloc de notas utiliza Flax, una biblioteca de redes neuronales para JAX. Para entrenarlo, utiliza Optax, una biblioteca de optimización para JAX.
Si eres un investigador que utiliza JAX, JAX2TF te ofrece un camino hacia la producción utilizando las herramientas ya demostradas de TensorFlow.
Esto puede ser útil de muchas maneras, aquí le presentamos algunas:
Inferencia: Tomar un modelo escrito para JAX e implementarlo en un servidor mediante TF Serving, en un dispositivo mediante TFLite o en la web mediante TensorFlow.js.
Ajuste fino: A partir de un modelo entrenado con JAX, puede llevar sus componentes a TF con JAX2TF y seguir entrenándolo en TensorFlow con los datos de entrenamiento y la configuración actuales.
Fusión: La combinación de partes de modelos que fueron entrenados usando JAX con los entrenados usando TensorFlow, para obtener la máxima flexibilidad.
La clave para permitir este tipo de interoperación entre JAX y TensorFlow es jax2tf.convert
, que toma componentes del modelo creados sobre JAX (su función de pérdida, función de predicción, etc) y crea representaciones equivalentes de ellos como funciones de TensorFlow, que luego se pueden exportar como un TensorFlow SavedModel.
Preparación
Descargue y prepare el conjunto de datos MNIST
Configurar el entrenamiento
Este bloc de notas creará y entrenará un modelo sencillo con fines de demostración.
Crear el modelo utilizando Flax
Escriba la función de escalón del entrenamiento
Escriba el bucle del entrenamiento
Cree el modelo y el optimizador (con Optax)
Entrenar al modelo
Entrene parcialmente al modelo
Continuará entrenando el modelo en TensorFlow enseguida.
Guarde lo justo para realizar inferencias
Si su objetivo es implementar su modelo JAX (para poder ejecutar la inferencia utilizando model.predict()
), basta con exportarlo a SavedModel. Esta sección muestra cómo hacerlo.
Guarde todo
Si tu objetivo es una llevar a cabo una exportación completa (útil si planea introducir el modelo en TensorFlow para su ajuste, fusión, etc.), esta sección muestra cómo guardar el modelo para que pueda acceder a métodos como:
model.predict
model.accuracy
model.loss (incluye bool train=True/False, RNG para realizar actualizaciones de estado en dropout y BatchNorm)
Volver a cargar el modelo
Continúe entrenando el modelo JAX convertido en TensorFlow
Siguientes pasos
Puede obtener más información sobre JAX y Flax en sus sitios web de la documentación que contienen guías detalladas y ejemplos. Si es nuevo en JAX, asegúrese de explorar los tutoriales JAX 101, y consulte Flax quickstart. Para obtener más información sobre la conversión de modelos JAX a formato TensorFlow, consulte la utilidad jax2tf en GitHub. Si está interesado en convertir modelos JAX para ejecutarlos en el navegador con TensorFlow.js, visite JAX en la web con TensorFlow.js. Si desea preparar modelos JAX para ejecutarlos en TensorFLow Lite, visite la guía Conversión de modelos JAX para TFLite.