Path: blob/master/site/pt-br/lite/examples/jax_conversion/overview.ipynb
25118 views
Copyright 2021 The TensorFlow Authors.
Conversão de modelos Jax para TF Lite
Visão geral
Observação: esta API é nova e só está disponível via pip install tf-nightly. Ela estará disponível no TensorFlow versão 2.7. Além disso, esta API ainda é experimental e está sujeita a mudanças.
Este CodeLab demonstra como criar um modelo para reconhecimento MNIST usando o Jax e como convertê-lo para o TensorFlow Lite. Além disso, demonstra como otimizar o modelo Jax convertido para TF Lite com quantização pós-treinamento.
Pré-requisitos
Recomenda-se usar este recurso com a build noturna mais recente do TensorFlow via pip.
Preparação dos dados
Baixe os dados MNIST com o dataset do Keras e faça o pré-processamento.
Compile o modelo MNIST com o Jax
Treine e avalie o modelo
Converta para um modelo do TF Lite
Note que nós:
Embutimos os parâmetros na função
predict
do Jax comfunctools.partial
.Criamos
jnp.zeros
, que é um tensor "temporário" usado para o Jax fazer o tracing do modelo.Chamamos
experimental_from_jax
:
A função
serving_func
é encapsulada em uma lista.A entrada é associada a um determinado nome e passada como um array encapsulado em uma lista.
Verifique o modelo convertido para TF Lite
Compare os resultados do modelo convertido com o modelo Jax.
Otimize o modelo
Forneceremos um representative_dataset
para fazer a quantização pós-treinamento a fim de otimizar o modelo.
Avalie o modelo otimizado
Compare o tamanho do modelo otimizado
Devemos observar que o modelo quantizado é quatro vezes menor do que o original.