Copyright 2023 The TensorFlow Authors.
Importe um modelo JAX usando JAX2TF
Este notebook fornece um exemplo completo e executável de como criar um modelo usando JAX e trazê-lo para o TensorFlow para continuar o treinamento. Isso é possível graças ao JAX2TF, uma API leve que fornece um caminho do ecossistema JAX para o ecossistema TensorFlow.
JAX é uma biblioteca de computação de arrays de alto desempenho. Para criar o modelo, este notebook usa o Flax, uma biblioteca de rede neural para JAX. Para treiná-lo, utiliza o Optax, uma biblioteca de otimização para JAX.
Se você é um pesquisador que usa JAX, o JAX2TF oferece um caminho para a produção usando as ferramentas testadas do TensorFlow.
Há muitas maneiras pelas quais isto pode ser útil, eis aqui algumas:
Inferência: pegar um modelo escrito para JAX e implantá-lo em servidor usando o TF Serving, em dispositivos usando TFLite ou na web usando TensorFlow.js.
Ajuste fino: tendo um modelo que foi treinado usando JAX, você pode levar seus componentes para o TF usando o JAX2TF e continuar treinando-o no TensorFlow com seus dados de treinamento e configuração existentes.
Fusão: combinando partes de modelos que foram treinados usando JAX com outros treinados usando TensorFlow, para máxima flexibilidade.
A chave para permitir esse tipo de interoperação entre JAX e TensorFlow é jax2tf.convert
, que utiliza componentes de modelo criados sobre JAX (sua função de perda, função de previsão, etc.) e cria representações equivalentes delas como funções do TensorFlow, que podem então ser exportadas como um SavedModel do TensorFlow.
Configuração
Baixe e prepare o dataset MNIST
Configure o treinamento
Este notebook criará e treinará um modelo simples para fins de demonstração.
Crie o modelo usando Flax
Escreva a função de passo de treinamento
Escreva o loop de treinamento
Crie o modelo e o otimizador (com Optax)
Treine o modelo
Treine parcialmente o modelo
Você continuará treinando o modelo no TensorFlow em breve.
Salve apenas o suficiente para inferência
Se seu objetivo é implantar seu modelo JAX (para que você possa executar inferência usando model.predict()
), simplesmente exportá-lo para SavedModel é suficiente. Esta seção demonstra como fazer isso.
Salve tudo
Se seu objetivo é uma exportação abrangente (útil se você planeja colocar o modelo no TensorFlow para ajuste fino, fusão etc.), esta seção demonstra como salvar o modelo para que você possa acessar métodos, incluindo:
model.predict
model.accuracy
model.loss (incluindo train=True/False bool, RNG para dropout e atualizações de estado BatchNorm)
Recarregue o modelo
Continue treinando o modelo JAX convertido no TensorFlow
Próximos passos
Você pode aprender mais sobre JAX e Flax em seus sites de documentação que contêm guias detalhados e exemplos. Se o JAX é novo para você, explore os tutoriais do JAX 101 e veja o início rápido do Flax. Para saber mais sobre como converter modelos JAX para o formato TensorFlow, veja o utilitário jax2tf no GitHub. Se você tiver interesse em converter modelos JAX para execução no navegador com TensorFlow.js, veja JAX na Web com TensorFlow.js. Se quiser preparar modelos JAX para execução no TensorFLow Lite, leia o guia Conversão de modelos JAX para o TFLite.