Path: blob/master/site/es-419/lite/examples/jax_conversion/overview.ipynb
25118 views
Copyright 2021 The TensorFlow Authors.
Conversión del modelo Jax para TFLite
Visión general
Nota: Esta API es nueva y sólo está disponible a través de pip install tf-nightly. Estará disponible en la versión 2.7 de TensorFlow. Además, la API es aún experimental y está sujeta a cambios.
Este CodeLab demuestra cómo construir un modelo para el reconocimiento MNIST usando Jax, y cómo convertirlo a TensorFlow Lite. Este codelab también demostrará cómo optimizar el modelo TFLite convertido a Jax con cuantización post-entrenamiento.
Requisitos previos
Se recomienda probar esta característica con la más reciente compilación de TensorFlow nightly pip.
Preparación de datos
Descargue los datos MNIST con el conjunto de datos Keras y preprocese.
Genere el modelo MNIST con Jax
Entrene y evalúe el modelo
Convertir a modelo TFLite.
Tenga en cuenta que
Aplicamos los parámetros en línea a la func
predict
de Jax confunctools.partial
.Generamos un
jnp.zeros
, se trata de un tensor "marcador de posición" usado para que Jax trace el modelo.Llamamos a
experimental_from_jax
:
La
serving_func
se encapsula en una lista.La entrada se asocia a un nombre determinado y se pasa como un arreglo encapsulado en una lista.
Compruebe el modelo TFLite convertido
Compare los resultados del modelo convertido con el modelo Jax.
Optimizar el modelo
Daremos un representative_dataset
para hacer la cuantización postentrenamiento para optimizar el modelo.
Evaluar el modelo optimizado
Compare el tamaño del modelo cuantizado
Deberíamos poder ver que el modelo cuantizado es cuatro veces más pequeño que el modelo original.