GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/es-419/guide/jax2tf.ipynb
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Import a JAX model using JAX2TF

This notebook provides a complete, runnable example of creating a model with JAX and bringing it into TensorFlow to continue training. This is made possible by JAX2TF, a lightweight API that provides a pathway from the JAX ecosystem to the TensorFlow ecosystem.

JAX is a high-performance array computing library. To create the model, this notebook uses Flax, a neural network library for JAX. To train it, it uses Optax, an optimization library for JAX.

If you're a researcher using JAX, JAX2TF gives you a path to production using TensorFlow's proven tools.

This can be useful in many ways; here are a few:

  • Inference: Taking a model written for JAX and deploying it on a server using TF Serving, on-device using TFLite, or on the web using TensorFlow.js.

  • Fine-tuning: Taking a model that was trained using JAX, you can bring its components to TF using JAX2TF and continue training it in TensorFlow with your existing training data and setup.

  • Fusion: Combining parts of models that were trained using JAX with parts trained using TensorFlow, for maximum flexibility.

The key to enabling this kind of interoperation between JAX and TensorFlow is jax2tf.convert, which takes model components created on top of JAX (your loss function, prediction function, etc.) and creates equivalent representations of them as TensorFlow functions, which can then be exported as a TensorFlow SavedModel.

Setup

import tensorflow as tf
import numpy as np
import jax
import jax.numpy as jnp
import flax
import optax
import os
from matplotlib import pyplot as plt
from jax.experimental import jax2tf
from threading import Lock # Only used in the visualization utility.
from functools import partial
# Needed for TensorFlow and JAX to coexist in GPU memory.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized.
    print(e)
#@title Visualization utilities

plt.rcParams["figure.figsize"] = (20,8)

# The utility for displaying training and validation curves.
def display_train_curves(loss, avg_loss, eval_loss, eval_accuracy, epochs, steps_per_epochs, ignore_first_n=10):

  ignore_first_n_epochs = int(ignore_first_n/steps_per_epochs)

  # The losses.
  ax = plt.subplot(121)
  if loss is not None:
    x = np.arange(len(loss)) / steps_per_epochs #* epochs
    ax.plot(x, loss)
  ax.plot(range(1, epochs+1), avg_loss, "-o", linewidth=3)
  ax.plot(range(1, epochs+1), eval_loss, "-o", linewidth=3)
  ax.set_title('Loss')
  ax.set_ylabel('loss')
  ax.set_xlabel('epoch')
  if loss is not None:
    ax.set_ylim(0, np.max(loss[ignore_first_n:]))
    ax.legend(['train', 'avg train', 'eval'])
  else:
    ymin = np.min(avg_loss[ignore_first_n_epochs:])
    ymax = np.max(avg_loss[ignore_first_n_epochs:])
    ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)
    ax.legend(['avg train', 'eval'])

  # The accuracy.
  ax = plt.subplot(122)
  ax.set_title('Eval Accuracy')
  ax.set_ylabel('accuracy')
  ax.set_xlabel('epoch')
  ymin = np.min(eval_accuracy[ignore_first_n_epochs:])
  ymax = np.max(eval_accuracy[ignore_first_n_epochs:])
  ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)
  ax.plot(range(1, epochs+1), eval_accuracy, "-o", linewidth=3)

class Progress:
  """Text mode progress bar.
  Usage:
          p = Progress(30)
          p.step()
          p.step()
          p.step(reset=True) # to restart from 0%
  The progress bar displays a new header at each restart."""
  def __init__(self, maxi, size=100, msg=""):
    """
    :param maxi: the number of steps required to reach 100%
    :param size: the number of characters taken on the screen by the progress bar
    :param msg: the message displayed in the header of the progress bar
    """
    self.maxi = maxi
    self.p = self.__start_progress(maxi)() # `()`: to get the iterator from the generator.
    self.header_printed = False
    self.msg = msg
    self.size = size
    self.lock = Lock()

  def step(self, reset=False):
    with self.lock:
      if reset:
        self.__init__(self.maxi, self.size, self.msg)
      if not self.header_printed:
        self.__print_header()
      next(self.p)

  def __print_header(self):
    print()
    format_string = "0%{: ^" + str(self.size - 6) + "}100%"
    print(format_string.format(self.msg))
    self.header_printed = True

  def __start_progress(self, maxi):
    def print_progress():
      # Bresenham's algorithm. Yields the number of dots printed.
      # This will always print `size` dots (100 by default) in `maxi` invocations.
      dx = maxi
      dy = self.size
      d = dy - dx
      for x in range(maxi):
        k = 0
        while d >= 0:
          print('=', end="", flush=True)
          k += 1
          d -= dx
        d += dy
        yield k
      # Keep yielding the last result if there are too many steps.
      while True:
        yield k

    return print_progress
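The progress bar spreads its `size` characters over `maxi` calls using Bresenham-style integer error accumulation, so no floating-point arithmetic is needed. The sketch below restates that inner loop as a plain function so you can inspect how the characters are distributed. `dots_per_step` is a hypothetical name that is not part of the notebook. With the 234 steps per epoch used in this notebook and the default 100-character bar, each step prints at most one character, and the total is exactly `size`.

```python
def dots_per_step(maxi, size=100):
  """Return how many '=' characters each of `maxi` steps would print.

  Standalone restatement of the error-accumulation loop in
  `Progress.__start_progress` (Bresenham's line algorithm).
  """
  counts = []
  dx, dy = maxi, size
  d = dy - dx  # initial error term, as in the notebook
  for _ in range(maxi):
    k = 0
    while d >= 0:  # emit characters while the error term is non-negative
      k += 1
      d -= dx
    d += dy
    counts.append(k)
  return counts

# 234 steps per epoch (60000 // 256) spread over a 100-character bar.
counts = dots_per_step(234, 100)
print(sum(counts), max(counts))
```

When `maxi` is smaller than `size`, each step prints several characters instead. For example, `dots_per_step(10, 100)` gives ten characters per step.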

Download and prepare the MNIST dataset

(x_train, train_labels), (x_test, test_labels) = tf.keras.datasets.mnist.load_data()

train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))
train_data = train_data.map(lambda x,y: (tf.expand_dims(tf.cast(x, tf.float32)/255.0, axis=-1),
                                         tf.one_hot(y, depth=10)))

BATCH_SIZE = 256
train_data = train_data.batch(BATCH_SIZE, drop_remainder=True)
train_data = train_data.cache()
train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)

test_data = tf.data.Dataset.from_tensor_slices((x_test, test_labels))
test_data = test_data.map(lambda x,y: (tf.expand_dims(tf.cast(x, tf.float32)/255.0, axis=-1),
                                       tf.one_hot(y, depth=10)))
test_data = test_data.batch(10000)
test_data = test_data.cache()

(one_batch, one_batch_labels) = next(iter(train_data)) # just one batch
(all_test_data, all_test_labels) = next(iter(test_data)) # all in one batch since batch size is 10000
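To make the per-example transformation concrete, here is a NumPy sketch of what the `map` above does to each image and label. It scales the uint8 pixels to float32 values in [0, 1], adds a trailing channel axis, and one-hot encodes the label into 10 classes. The function name `preprocess` and the random stand-in images are illustrative assumptions. Note also that the pipeline batches before it shuffles, so `shuffle` reorders whole batches rather than individual images.

```python
import numpy as np

def preprocess(images, labels, num_classes=10):
  """NumPy version of the per-example `map` applied to both MNIST datasets above."""
  x = images.astype(np.float32)[..., np.newaxis] / 255.0  # (N, 28, 28) uint8 -> (N, 28, 28, 1) float32
  y = np.eye(num_classes, dtype=np.float32)[labels]       # integer class ids -> one-hot rows
  return x, y

# Random stand-in images; the real ones come from tf.keras.datasets.mnist.load_data().
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)
fake_labels = np.array([3, 0, 9, 1, 3])
x, y = preprocess(fake_images, fake_labels)
print(x.shape, y.shape)  # (5, 28, 28, 1) (5, 10)

# drop_remainder=True keeps only full batches: 60000 // 256 of them per epoch.
print(60000 // 256)  # 234
```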

Configure training

This notebook will create and train a simple model for demonstration purposes.

# Training hyperparameters.
JAX_EPOCHS = 3
TF_EPOCHS = 7
STEPS_PER_EPOCH = len(train_labels)//BATCH_SIZE
LEARNING_RATE = 0.01
LEARNING_RATE_EXP_DECAY = 0.6

# The learning rate schedule for JAX (with Optax).
jlr_decay = optax.exponential_decay(LEARNING_RATE, transition_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)

# The learning rate schedule for TensorFlow.
tflr_decay = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=LEARNING_RATE, decay_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)
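Both schedules are configured with `staircase=True`, so they should agree on the same piecewise-constant rate: LEARNING_RATE * LEARNING_RATE_EXP_DECAY ** floor(step / STEPS_PER_EPOCH). This assumes Optax's `transition_begin` is left at its default of 0, as it is here. The pure-Python sketch below evaluates that formula with the hyperparameters above. It also shows why the TensorFlow loop later sets `optimizer.iterations` to the number of steps already taken in JAX: doing so lets the schedule resume at the right value. `staircase_exp_decay` is a hypothetical helper, not part of either library.

```python
import math

LEARNING_RATE = 0.01
LEARNING_RATE_EXP_DECAY = 0.6
STEPS_PER_EPOCH = 60000 // 256  # 234 for MNIST with BATCH_SIZE = 256
JAX_EPOCHS = 3

def staircase_exp_decay(step, init=LEARNING_RATE, rate=LEARNING_RATE_EXP_DECAY,
                        steps=STEPS_PER_EPOCH):
  """lr(step) = init * rate ** floor(step / steps): constant within an epoch."""
  return init * rate ** math.floor(step / steps)

# First and last step of each epoch share the same learning rate.
for epoch in range(4):
  first_step = epoch * STEPS_PER_EPOCH
  print(epoch, staircase_exp_decay(first_step), staircase_exp_decay(first_step + STEPS_PER_EPOCH - 1))

# Step count at which TensorFlow training resumes after the JAX epochs.
resume_step = JAX_EPOCHS * STEPS_PER_EPOCH
print(staircase_exp_decay(resume_step))  # learning rate for the first TensorFlow epoch
```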

Create the model using Flax

class ConvModel(flax.linen.Module):

  @flax.linen.compact
  def __call__(self, x, train):
    x = flax.linen.Conv(features=12, kernel_size=(3,3), padding="SAME", use_bias=False)(x)
    x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)
    x = x.reshape((x.shape[0], -1)) # flatten
    x = flax.linen.Dense(features=200, use_bias=True)(x)
    x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)
    x = flax.linen.Dropout(rate=0.3, deterministic=not train)(x)
    x = flax.linen.relu(x)
    x = flax.linen.Dense(features=10)(x)
    #x = flax.linen.log_softmax(x)
    return x

  # JAX differentiation requires a function `f(params, other_state, data, labels)` -> `loss` (as a single number).
  # `jax.grad` will differentiate it against the first argument.
  # The user must split trainable and non-trainable variables into `params` and `other_state`.
  # Must pass a different RNG key each time for the dropout mask to be different.
  def loss(self, params, other_state, rng, data, labels, train):
    logits, batch_stats = self.apply({'params': params, **other_state},
                                     data,
                                     mutable=['batch_stats'],
                                     rngs={'dropout': rng},
                                     train=train)
    # The loss averaged across the batch dimension.
    loss = optax.softmax_cross_entropy(logits, labels).mean()
    return loss, batch_stats

  def predict(self, state, data):
    logits = self.apply(state, data, train=False) # predict and accuracy disable dropout and use accumulated batch norm stats (train=False)
    probabilities = flax.linen.log_softmax(logits)
    return probabilities

  def accuracy(self, state, data, labels):
    probabilities = self.predict(state, data)
    predictions = jnp.argmax(probabilities, axis=-1)
    dense_labels = jnp.argmax(labels, axis=-1)
    accuracy = jnp.equal(predictions, dense_labels).mean()
    return accuracy
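`ConvModel.loss` averages the per-example softmax cross-entropy over the batch. `predict` and `accuracy` apply `log_softmax` and then take the argmax. Because `log_softmax` is monotonic within each row, that argmax is the same as the argmax of the raw logits. The NumPy sketch below restates these computations for illustration. It is a plain re-derivation of the math, not Optax's or Flax's implementation.

```python
import numpy as np

def log_softmax(logits):
  # Subtract the row maximum for numerical stability before exponentiating.
  shifted = logits - logits.max(axis=-1, keepdims=True)
  return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def softmax_cross_entropy(logits, one_hot_labels):
  # Per-example loss: -sum_k y_k * log p_k.
  return -(one_hot_labels * log_softmax(logits)).sum(axis=-1)

def accuracy(logits, one_hot_labels):
  predictions = np.argmax(log_softmax(logits), axis=-1)
  return np.mean(predictions == np.argmax(one_hot_labels, axis=-1))

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
print(softmax_cross_entropy(logits, labels).mean())  # batch-averaged loss, as in `ConvModel.loss`
print(accuracy(logits, labels))  # 0.5: first example correct, second wrong
```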

Write the training step function

# The training step.
@partial(jax.jit, static_argnums=[0]) # this forces jax.jit to recompile for every new model
def train_step(model, state, optimizer_state, rng, data, labels):

  other_state, params = state.pop('params') # differentiate only against 'params' which represents trainable variables
  (loss, batch_stats), grads = jax.value_and_grad(model.loss, has_aux=True)(params, other_state, rng, data, labels, train=True)

  updates, optimizer_state = optimizer.update(grads, optimizer_state)
  params = optax.apply_updates(params, updates)
  new_state = state.copy(add_or_replace={**batch_stats, 'params': params})

  rng, _ = jax.random.split(rng)

  return new_state, optimizer_state, rng, loss
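`train_step` differentiates only with respect to the trainable `params` collection. Meanwhile, `batch_stats` travels alongside as non-trainable state and is replaced by the values that `model.apply(..., mutable=['batch_stats'])` returns. The plain-dict sketch below mirrors what `state.pop('params')` and `state.copy(add_or_replace=...)` do on a Flax `FrozenDict`: it splits the state and then rebuilds it without mutating the input. `split_params` and `merge_state` are hypothetical helper names, and the layer names in the toy state are illustrative only.

```python
def split_params(state):
  """Return (other_state, params): the non-trainable collections and the trainable 'params'."""
  other_state = {k: v for k, v in state.items() if k != "params"}
  return other_state, state["params"]

def merge_state(state, params, updated_collections):
  """Return a new state with new 'params' and updated mutable collections such as 'batch_stats'."""
  return {**state, **updated_collections, "params": params}

state = {"params": {"Dense_0": {"kernel": 1.0}},
         "batch_stats": {"BatchNorm_0": {"mean": 0.0}}}
other_state, params = split_params(state)
new_params = {"Dense_0": {"kernel": 0.9}}  # pretend an optimizer update produced this
new_state = merge_state(state, new_params, {"batch_stats": {"BatchNorm_0": {"mean": 0.1}}})
print(other_state)
print(new_state)
```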

Write the training loop

def train(model, state, optimizer_state, train_data, epochs, losses, avg_losses, eval_losses, eval_accuracies):
  p = Progress(STEPS_PER_EPOCH)
  rng = jax.random.PRNGKey(0)
  for epoch in range(epochs):

    # This is where the learning rate schedule state is stored in the optimizer state.
    optimizer_step = optimizer_state[1].count

    # Run an epoch of training.
    for step, (data, labels) in enumerate(train_data):
      p.step(reset=(step==0))
      state, optimizer_state, rng, loss = train_step(model, state, optimizer_state, rng, data.numpy(), labels.numpy())
      losses.append(loss)
    avg_loss = np.mean(losses[-(step+1):]) # average over all steps of this epoch
    avg_losses.append(avg_loss)

    # Run one epoch of evals (10,000 test images in a single batch).
    other_state, params = state.pop('params')
    # Gotcha: must discard modified batch_stats here
    eval_loss, _ = model.loss(params, other_state, rng, all_test_data.numpy(), all_test_labels.numpy(), train=False)
    eval_losses.append(eval_loss)
    eval_accuracy = model.accuracy(state, all_test_data.numpy(), all_test_labels.numpy())
    eval_accuracies.append(eval_accuracy)

    print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", jlr_decay(optimizer_step))

  return state, optimizer_state

Create the model and the optimizer (with Optax)

# The model.
model = ConvModel()
state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for "dropout"
# The code below relies on `FrozenDict.pop` and `.copy`; recent Flax releases may return a plain dict from `init`.
state = flax.core.freeze(state)

# The optimizer.
optimizer = optax.adam(learning_rate=jlr_decay) # Gotcha: it does not seem to be possible to pass just a callable as LR, must be an Optax Schedule
optimizer_state = optimizer.init(state['params'])

losses=[]
avg_losses=[]
eval_losses=[]
eval_accuracies=[]

Train the model

new_state, new_optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS+TF_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)

Partially train the model

You will continue training the model in TensorFlow shortly.

model = ConvModel()
state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for "dropout"
# The code below relies on `FrozenDict.pop` and `.copy`; recent Flax releases may return a plain dict from `init`.
state = flax.core.freeze(state)

# The optimizer.
optimizer = optax.adam(learning_rate=jlr_decay) # LR must be an Optax LR Schedule
optimizer_state = optimizer.init(state['params'])

losses, avg_losses, eval_losses, eval_accuracies = [], [], [], []
state, optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)

Save just enough for inference

If your goal is to deploy your JAX model (so that you can run inference using model.predict()), exporting it to SavedModel is sufficient. This section demonstrates how to do that.

# Test data with a different batch size to test polymorphic shapes.
x, y = next(iter(train_data.unbatch().batch(13)))

m = tf.Module()
# Wrap the JAX state in `tf.Variable` (needed when calling the converted JAX function).
state_vars = tf.nest.map_structure(tf.Variable, state)
# Keep the wrapped state as flat list (needed in TensorFlow fine-tuning).
m.vars = tf.nest.flatten(state_vars)
# Convert the desired JAX function (`model.predict`).
predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"])
# Wrap the converted function in `tf.function` with the correct `tf.TensorSpec` (necessary for dynamic shapes to work).
@tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])
def predict(data):
  return predict_fn(state_vars, data)
m.predict = predict
tf.saved_model.save(m, "./")
# Test the converted function.
print("Converted function predictions:", np.argmax(m.predict(x).numpy(), axis=-1))
# Reload the model.
reloaded_model = tf.saved_model.load("./")
# Test the reloaded converted function (the result should be the same).
print("Reloaded function predictions:", np.argmax(reloaded_model.predict(x).numpy(), axis=-1))
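`polymorphic_shapes` supplies one shape specification per positional argument of the converted function. As we read the jax2tf documentation, `"..."` keeps all of that argument's dimensions (here, the model state) fixed to the actual shapes, while `"(b, 28, 28, 1)"` makes the leading dimension a symbolic `b`. The full export in the next section uses the same `b` in the data and label specs, which ties the two batch sizes together. The sketch below is a hypothetical pure-Python checker that illustrates this binding rule. It is not how jax2tf implements shape polymorphism, and `bind_shapes` is an invented name.

```python
def bind_shapes(specs, shapes):
  """Check concrete shapes against specs such as "(b, 28, 28, 1)".

  Hypothetical illustration: "..." skips the argument, integer entries must
  match exactly, and a name such as `b` must take one value across all arguments.
  """
  bindings = {}
  for spec, shape in zip(specs, shapes):
    if spec == "...":
      continue
    dims = [d.strip() for d in spec.strip("()").split(",")]
    if len(dims) != len(shape):
      raise ValueError(f"rank mismatch: {spec} vs {shape}")
    for dim, size in zip(dims, shape):
      if dim.isdigit():
        if int(dim) != size:
          raise ValueError(f"expected {dim}, got {size} in {shape}")
      elif bindings.setdefault(dim, size) != size:
        raise ValueError(f"inconsistent value for {dim}: {bindings[dim]} vs {size}")
  return bindings

specs = ["...", "(b, 28, 28, 1)", "(b, 10)"]
print(bind_shapes(specs, [None, (13, 28, 28, 1), (13, 10)]))  # {'b': 13}
print(bind_shapes(specs, [None, (20, 28, 28, 1), (20, 10)]))  # {'b': 20}
```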

Save everything

If your goal is a comprehensive export (useful if you plan to bring the model into TensorFlow for fine-tuning, fusion, etc.), this section demonstrates how to save the model so that you can access methods such as:

  • model.predict

  • model.accuracy

  • model.loss (including the train=True/False flag, the RNG for dropout, and BatchNorm state updates)

from collections import abc

def _fix_frozen(d):
  """Changes any mappings (e.g. frozendict) back to dict."""
  if isinstance(d, list):
    return [_fix_frozen(v) for v in d]
  elif isinstance(d, tuple):
    return tuple(_fix_frozen(v) for v in d)
  elif not isinstance(d, abc.Mapping):
    return d
  d = dict(d)
  for k, v in d.items():
    d[k] = _fix_frozen(v)
  return d
class TFModel(tf.Module):
  def __init__(self, state, model):
    super().__init__()

    # Special care needed for the train=True/False parameter in the loss
    @jax.jit
    def loss_with_train_bool(state, rng, data, labels, train):
      other_state, params = state.pop('params')
      loss, batch_stats = jax.lax.cond(train,
                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=True),
                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=False),
                                       state, data, labels)
      # must use JAX to split the RNG, therefore, must do it in a @jax.jit function
      new_rng, _ = jax.random.split(rng)
      return loss, batch_stats, new_rng

    self.state_vars = tf.nest.map_structure(tf.Variable, state)
    self.vars = tf.nest.flatten(self.state_vars)
    self.jax_rng = tf.Variable(jax.random.PRNGKey(0))

    self.loss_fn = jax2tf.convert(loss_with_train_bool, polymorphic_shapes=["...", "...", "(b, 28, 28, 1)", "(b, 10)", "..."])
    self.accuracy_fn = jax2tf.convert(model.accuracy, polymorphic_shapes=["...", "(b, 28, 28, 1)", "(b, 10)"])
    self.predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"])

  # Must specify TensorSpec manually for variable batch size to work
  @tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])
  def predict(self, data):
    # Make sure the TFModel.predict function implicitly uses self.state_vars and not the JAX state directly;
    # otherwise, all model weights would be embedded in the TF graph as constants.
    return self.predict_fn(self.state_vars, data)

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def train_loss(self, data, labels):
    loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, True)
    # update batch norm stats
    flat_vars = tf.nest.flatten(self.state_vars['batch_stats'])
    flat_values = tf.nest.flatten(batch_stats['batch_stats'])
    for var, val in zip(flat_vars, flat_values):
      var.assign(val)
    # update RNG
    self.jax_rng.assign(new_rng)
    return loss

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def eval_loss(self, data, labels):
    loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, False)
    return loss

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def accuracy(self, data, labels):
    return self.accuracy_fn(self.state_vars, data, labels)
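`train_loss` writes the new batch norm statistics back by zipping two flattened lists: the `tf.Variable`s in `self.state_vars['batch_stats']` and the values returned by the converted loss. This pairing works because `tf.nest.flatten` visits dictionary entries in sorted key order. As a result, two dictionaries with the same keys flatten in the same order, whatever order their keys were inserted in. The standard-library sketch below demonstrates that property with a hypothetical `flatten` that is a stand-in, not `tf.nest` itself.

```python
def flatten(structure):
  """Flatten nested dicts into a list of leaves, visiting keys in sorted order.

  Hypothetical stand-in for illustration; tf.nest.flatten also orders
  dictionary entries by sorted key.
  """
  if isinstance(structure, dict):
    return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
  return [structure]

# Same keys, different insertion order: stored variables vs. freshly computed values.
stored = {"BatchNorm_1": {"var": "v1", "mean": "m1"}, "BatchNorm_0": {"var": "v0", "mean": "m0"}}
computed = {"BatchNorm_0": {"mean": 0.5, "var": 2.0}, "BatchNorm_1": {"mean": 0.7, "var": 3.0}}
pairs = list(zip(flatten(stored), flatten(computed)))
print(pairs)  # [('m0', 0.5), ('v0', 2.0), ('m1', 0.7), ('v1', 3.0)]
```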
# Instantiate the model.
tf_model = TFModel(state, model)

# Save the model.
tf.saved_model.save(tf_model, "./")

Reload the model

reloaded_model = tf.saved_model.load("./")

# Test if it works and that the batch size is indeed variable.
x,y = next(iter(train_data.unbatch().batch(13)))
print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))
x,y = next(iter(train_data.unbatch().batch(20)))
print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))

print(reloaded_model.accuracy(one_batch, one_batch_labels))
print(reloaded_model.accuracy(all_test_data, all_test_labels))

Continue training the converted JAX model in TensorFlow

optimizer = tf.keras.optimizers.Adam(learning_rate=tflr_decay)

# Set the iteration step for the learning rate to resume from where it left off in JAX.
optimizer.iterations.assign(len(eval_losses)*STEPS_PER_EPOCH)

p = Progress(STEPS_PER_EPOCH)

for epoch in range(JAX_EPOCHS, JAX_EPOCHS+TF_EPOCHS):

  # This is where the learning rate schedule state is stored in the optimizer state.
  # Take a snapshot of the value so the printed LR is the one used during this epoch.
  optimizer_step = optimizer.iterations.numpy()

  for step, (data, labels) in enumerate(train_data):
    p.step(reset=(step==0))
    with tf.GradientTape() as tape:
      #loss = reloaded_model.loss(data, labels, True)
      loss = reloaded_model.train_loss(data, labels)
    grads = tape.gradient(loss, reloaded_model.vars)
    optimizer.apply_gradients(zip(grads, reloaded_model.vars))
    losses.append(loss)
  avg_loss = np.mean(losses[-(step+1):]) # average over all steps of this epoch
  avg_losses.append(avg_loss)

  eval_loss = reloaded_model.eval_loss(all_test_data.numpy(), all_test_labels.numpy()).numpy()
  eval_losses.append(eval_loss)
  eval_accuracy = reloaded_model.accuracy(all_test_data.numpy(), all_test_labels.numpy()).numpy()
  eval_accuracies.append(eval_accuracy)

  print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", tflr_decay(optimizer_step).numpy())
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=2*STEPS_PER_EPOCH)

# The loss takes a hit when the training restarts, but does not go back to random levels.
# This is likely caused by the optimizer momentum being reinitialized.
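The comment above attributes the bump in loss to the optimizer's moment estimates being reinitialized. The Keras `Adam` optimizer starts with zeroed moments, while `optimizer.iterations` is advanced so that the learning-rate schedule resumes. The NumPy sketch below uses the textbook bias-corrected Adam update (b1=0.9, b2=0.999, eps=1e-7) to show one way this combination could matter. It assumes, as Keras does, that the bias correction uses the optimizer's step counter. Under that assumption, a zero-moment update at step 703 (after 3 epochs of 234 steps) comes out roughly 2.25 times larger than the update a brand-new optimizer would take at step 1. This is an illustration of the mechanism, not a measurement from this notebook.

```python
import numpy as np

def adam_update(grad, m, v, t, lr=0.01, b1=0.9, b2=0.999, eps=1e-7):
  """Return (delta, m, v) for one bias-corrected Adam step; t is the 1-based step count."""
  m = b1 * m + (1 - b1) * grad       # first-moment estimate
  v = b2 * v + (1 - b2) * grad ** 2  # second-moment estimate
  m_hat = m / (1 - b1 ** t)          # bias corrections depend on the step count
  v_hat = v / (1 - b2 ** t)
  return -lr * m_hat / (np.sqrt(v_hat) + eps), m, v

grad = 0.5
# A brand-new optimizer: zero moments and step count t = 1.
fresh, _, _ = adam_update(grad, m=0.0, v=0.0, t=1)
# Zero moments, but the step counter resumed after 3 epochs of 234 steps.
resumed, _, _ = adam_update(grad, m=0.0, v=0.0, t=3 * 234 + 1)
print(fresh, resumed, resumed / fresh)
```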

Next steps

You can learn more about JAX and Flax on their documentation websites, which contain detailed guides and examples. If you're new to JAX, be sure to explore the JAX 101 tutorials and check out the Flax quickstart. To learn more about converting JAX models to TensorFlow format, check out the jax2tf utility on GitHub. If you're interested in converting JAX models to run in the browser with TensorFlow.js, visit JAX on the Web with TensorFlow.js. If you'd like to prepare JAX models to run in TensorFlow Lite, visit the JAX Model Conversion For TFLite guide.