Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/es-419/lite/examples/on_device_training/overview.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

Entrenamiento en el dispositivo con TensorFlow Lite

Al implementar un modelo de aprendizaje automático de TensorFlow Lite en un dispositivo o app móvil, es posible que desee permitir que el modelo se mejore o personalice en función de las aportaciones del dispositivo o del usuario final. Usar técnicas de entrenamiento en el dispositivo le permite actualizar un modelo sin que los datos salgan de los dispositivos de sus usuarios, mejorando la privacidad del usuario y sin necesidad de que los usuarios actualicen el software del dispositivo.

Por ejemplo, puede tener un modelo en su app móvil que reconozca artículos de moda, pero quiere que los usuarios mejoren el rendimiento del reconocimiento con el tiempo en función de sus intereses. Habilitar el entrenamiento en el dispositivo permite a los usuarios interesados en el calzado mejorar en el reconocimiento de un estilo o marca de zapatos concretos cuanto más a menudo usen su app.

Este tutorial le muestra cómo construir un modelo TensorFlow Lite que puede ser entrenado y mejorado de forma incremental dentro de una app Android instalada.

Nota: La técnica de entrenamiento en el dispositivo puede añadirse a las implementaciones existentes de TensorFlow Lite, siempre que los dispositivos a los que se dirija admitan el almacenamiento local de archivos.

Configuración

Este tutorial usa Python para entrenar y convertir un modelo TensorFlow antes de incorporarlo a una app Android. Empiece instalando e importando los siguientes paquetes.

import matplotlib.pyplot as plt import numpy as np import tensorflow as tf print("TensorFlow version:", tf.__version__)
TensorFlow version: 2.8.0

Nota: Las API de entrenamiento en el dispositivo están disponibles a partir de la versión 2.7 de TensorFlow.

Clasifique imágenes de prendas de vestir

Este código de ejemplo usa el conjunto de datos Fashion MNIST para entrenar un modelo de red neuronal para clasificar imágenes de ropa. Este conjunto de datos contiene 60,000 imágenes pequeñas (28 x 28 pixel) en escala de grises que contienen 10 categorías diferentes de accesorios de moda, incluidos vestidos, camisas y sandalias.

<figure> <img src="https://tensorflow.org/images/fashion-mnist-sprite.png" alt="Fashion MNIST images"> <figcaption><b>Figure 1</b>: <a href="https://github.com/zalandoresearch/fashion-mnist">Fashion-MNIST samples</a> (by Zalando, MIT License).</figcaption> </figure>

Puede explorar este conjunto de datos en mayor profundidad en el Tutorial de clasificación de Keras.

Genere un modelo para el entrenamiento en el dispositivo

Los modelos TensorFlow Lite suelen tener un único método (o firma) de función expuesto que le permite llamar al modelo para ejecutar una inferencia. Para que un modelo pueda ser entrenado y usado en un dispositivo, debe ser capaz de realizar varias operaciones separadas, incluyendo las funciones de entrenamiento, inferencia, guardado y restauración del modelo. Puede habilitar esta funcionalidad ampliando primero su modelo TensorFlow para que tenga varias funciones, y luego exponiendo esas funciones como firmas cuando convierta su modelo al formato de modelo TensorFlow Lite.

El siguiente ejemplo de código muestra cómo añadir las siguientes funciones a un modelo TensorFlow:

  • la función train entrena el modelo con datos de entrenamiento.

  • la función infer invoca la inferencia.

  • la función save guarda las ponderaciones entrenables en el sistema de archivos.

  • la función restore carga las ponderaciones entrenables desde el sistema de archivos.

IMG_SIZE = 28 class Model(tf.Module): def __init__(self): self.model = tf.keras.Sequential([ tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE), name='flatten'), tf.keras.layers.Dense(128, activation='relu', name='dense_1'), tf.keras.layers.Dense(10, name='dense_2') ]) self.model.compile( optimizer='sgd', loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)) # The `train` function takes a batch of input images and labels. @tf.function(input_signature=[ tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32), tf.TensorSpec([None, 10], tf.float32), ]) def train(self, x, y): with tf.GradientTape() as tape: prediction = self.model(x) loss = self.model.loss(y, prediction) gradients = tape.gradient(loss, self.model.trainable_variables) self.model.optimizer.apply_gradients( zip(gradients, self.model.trainable_variables)) result = {"loss": loss} return result @tf.function(input_signature=[ tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32), ]) def infer(self, x): logits = self.model(x) probabilities = tf.nn.softmax(logits, axis=-1) return { "output": probabilities, "logits": logits } @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)]) def save(self, checkpoint_path): tensor_names = [weight.name for weight in self.model.weights] tensors_to_save = [weight.read_value() for weight in self.model.weights] tf.raw_ops.Save( filename=checkpoint_path, tensor_names=tensor_names, data=tensors_to_save, name='save') return { "checkpoint_path": checkpoint_path } @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)]) def restore(self, checkpoint_path): restored_tensors = {} for var in self.model.weights: restored = tf.raw_ops.Restore( file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype, name='restore') var.assign(restored) restored_tensors[var.name] = restored return restored_tensors

La función train del código anterior usa la clase GradientTape para registrar las operaciones de diferenciación automática. Para más información sobre cómo usar esta clase, consulte la Introducción a los gradientes y la diferenciación automática.

Usted podría usar el método Model.train_step del modelo Keras aquí en lugar de una implementación desde cero. Sólo tenga en cuenta que la pérdida (y las métricas) devueltas por Model.train_step es el promedio actual, y debe restablecerse regularmente (normalmente cada época). Consulte Personalizar Model.fit para más detalles.

Nota: Las ponderaciones generadas por este modelo se serializan en un archivo de puntos de verificación de formato TensorFlow 1.

Prepare los datos

Obtenga el conjunto de datos MNIST de moda para el entrenamiento de su modelo.

fashion_mnist = tf.keras.datasets.fashion_mnist (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

Preprocese el conjunto de datos

Los valores de los pixel en este conjunto de datos están entre 0 y 255, y deben ser normalizados a un valor entre 0 y 1 para ser procesados por el modelo. Divida los valores por 255 para realizar este ajuste.

train_images = (train_images / 255.0).astype(np.float32) test_images = (test_images / 255.0).astype(np.float32)

Convierta las etiquetas de los datos en valores categóricos realizando una codificación en un solo paso.

train_labels = tf.keras.utils.to_categorical(train_labels) test_labels = tf.keras.utils.to_categorical(test_labels)

Nota: Asegúrese de preprocesar sus conjuntos de datos de entrenamiento y de prueba del mismo modo, para que sus pruebas evalúen con precisión el rendimiento de su modelo.

Entrene el modelo

Antes de convertir y configurar su modelo TensorFlow Lite, complete el entrenamiento inicial de su modelo usando el conjunto de datos preprocesados y el método de firma train. El siguiente código ejecuta el entrenamiento del modelo durante 100 épocas, procesando lotes de 100 imágenes a la vez y mostrando el valor de pérdida después de cada 10 épocas. Dado que esta ejecución de entrenamiento procesa bastantes datos, puede tardar unos minutos en finalizar.

NUM_EPOCHS = 100 BATCH_SIZE = 100 epochs = np.arange(1, NUM_EPOCHS + 1, 1) losses = np.zeros([NUM_EPOCHS]) m = Model() train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels)) train_ds = train_ds.batch(BATCH_SIZE) for i in range(NUM_EPOCHS): for x,y in train_ds: result = m.train(x, y) losses[i] = result['loss'] if (i + 1) % 10 == 0: print(f"Finished {i+1} epochs") print(f" loss: {losses[i]:.3f}") # Save the trained weights to a checkpoint. m.save('/tmp/model.ckpt')
Finished 10 epochs loss: 0.428 Finished 20 epochs loss: 0.378 Finished 30 epochs loss: 0.344 Finished 40 epochs loss: 0.317 Finished 50 epochs loss: 0.299 Finished 60 epochs loss: 0.283 Finished 70 epochs loss: 0.266 Finished 80 epochs loss: 0.252 Finished 90 epochs loss: 0.240 Finished 100 epochs loss: 0.230
{'checkpoint_path': <tf.Tensor: shape=(), dtype=string, numpy=b'/tmp/model.ckpt'>}
plt.plot(epochs, losses, label='Pre-training') plt.ylim([0, max(plt.ylim())]) plt.xlabel('Epoch') plt.ylabel('Loss [Cross Entropy]') plt.legend();
Image in a Jupyter notebook

Nota: Debe completar el entrenamiento inicial de su modelo antes de convertirlo al formato TensorFlow Lite, para que el modelo tenga un conjunto inicial de ponderaciones y sea capaz de realizar inferencias razonables antes de empezar a recopilar datos y realizar ejecuciones de entrenamiento en el dispositivo.

Convierta el modelo al formato TensorFlow Lite

Una vez que haya ampliado su modelo TensorFlow para habilitar funciones adicionales para el entrenamiento en el dispositivo y haya completado el entrenamiento inicial del modelo, puede convertirlo al formato TensorFlow Lite. El siguiente código convierte y guarda su modelo a ese formato, incluyendo el conjunto de firmas que se usan con el modelo TensorFlow Lite en un dispositivo: train, infer, save, restore.

SAVED_MODEL_DIR = "saved_model" tf.saved_model.save( m, SAVED_MODEL_DIR, signatures={ 'train': m.train.get_concrete_function(), 'infer': m.infer.get_concrete_function(), 'save': m.save.get_concrete_function(), 'restore': m.restore.get_concrete_function(), }) # Convert the model converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR) converter.target_spec.supported_ops = [ tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. ] converter.experimental_enable_resource_variables = True tflite_model = converter.convert()

Configure las firmas de TensorFlow Lite

El modelo TensorFlow Lite que guardó en el paso anterior contiene varias firmas de funciones. Puede acceder a ellas a través de la clase tf.lite.Interpreter e invocar cada firma restore, train, save, y infer por separado.

interpreter = tf.lite.Interpreter(model_content=tflite_model) interpreter.allocate_tensors() infer = interpreter.get_signature_runner("infer")

Compare los resultados del modelo original y del modelo lite convertido:

logits_original = m.infer(x=train_images[:1])['logits'][0] logits_lite = infer(x=train_images[:1])['logits'][0]
#@title def compare_logits(logits): width = 0.35 offset = width/2 assert len(logits)==2 keys = list(logits.keys()) plt.bar(x = np.arange(len(logits[keys[0]]))-offset, height=logits[keys[0]], width=0.35, label=keys[0]) plt.bar(x = np.arange(len(logits[keys[1]]))+offset, height=logits[keys[1]], width=0.35, label=keys[1]) plt.legend() plt.grid(True) plt.ylabel('Logit') plt.xlabel('ClassID') delta = np.sum(np.abs(logits[keys[0]] - logits[keys[1]])) plt.title(f"Total difference: {delta:.3g}") compare_logits({'Original': logits_original, 'Lite': logits_lite})
Image in a Jupyter notebook

Arriba puede ver que el comportamiento del modelo no cambia por la conversión a TFLite.

Reentrene el modelo en un dispositivo

Tras convertir su modelo a TensorFlow Lite e implementarlo con su app, puede volver a entrenar el modelo en un dispositivo utilizando nuevos datos y el método de firma train de su modelo. Cada ejecución de entrenamiento genera un nuevo conjunto de ponderaciones que puede guardar para reutilizar y seguir mejorando el modelo, como se muestra en la siguiente sección.

Nota: Dado que las tareas de entrenamiento consumen muchos recursos, debería considerar realizarlas cuando los usuarios no estén interactuando activamente con el dispositivo, y como proceso en segundo plano. Considere usar la API WorkManager para programar el reentrenamiento del modelo como una tarea asíncrona.

En Android, puede realizar el entrenamiento en el dispositivo con TensorFlow Lite usando las API de Java o C++. En Java, use la clase Interpreter para cargar un modelo y conducir las tareas de entrenamiento del modelo. El siguiente ejemplo muestra cómo ejecutar el procedimiento de entrenamiento usando el método runSignature:

try (Interpreter interpreter = new Interpreter(modelBuffer)) { int NUM_EPOCHS = 100; int BATCH_SIZE = 100; int IMG_HEIGHT = 28; int IMG_WIDTH = 28; int NUM_TRAININGS = 60000; int NUM_BATCHES = NUM_TRAININGS / BATCH_SIZE; List<FloatBuffer> trainImageBatches = new ArrayList<>(NUM_BATCHES); List<FloatBuffer> trainLabelBatches = new ArrayList<>(NUM_BATCHES); // Prepare training batches. for (int i = 0; i < NUM_BATCHES; ++i) { FloatBuffer trainImages = FloatBuffer.allocateDirect(BATCH_SIZE * IMG_HEIGHT * IMG_WIDTH).order(ByteOrder.nativeOrder()); FloatBuffer trainLabels = FloatBuffer.allocateDirect(BATCH_SIZE * 10).order(ByteOrder.nativeOrder()); // Fill the data values... trainImageBatches.add(trainImages.rewind()); trainImageLabels.add(trainLabels.rewind()); } // Run training for a few steps. float[] losses = new float[NUM_EPOCHS]; for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) { for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) { Map<String, Object> inputs = new HashMap<>(); inputs.put("x", trainImageBatches.get(batchIdx)); inputs.put("y", trainLabelBatches.get(batchIdx)); Map<String, Object> outputs = new HashMap<>(); FloatBuffer loss = FloatBuffer.allocate(1); outputs.put("loss", loss); interpreter.runSignature(inputs, outputs, "train"); // Record the last loss. if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0); } // Print the loss output for every 10 epochs. if ((epoch + 1) % 10 == 0) { System.out.println( "Finished " + (epoch + 1) + " epochs, current loss: " + loss.get(0)); } } // ... }

Puede ver un ejemplo de código completo de reentrenamiento de modelos dentro de una app Android en la demo model personalization demo app.

Ejecute el entrenamiento durante algunas épocas para mejorar o personalizar el modelo. En la práctica, usted ejecutaría este entrenamiento adicional usando los datos recopilados en el dispositivo. Para simplificar, este ejemplo usa los mismos datos de entrenamiento que el paso de entrenamiento anterior.

train = interpreter.get_signature_runner("train") NUM_EPOCHS = 50 BATCH_SIZE = 100 more_epochs = np.arange(epochs[-1]+1, epochs[-1] + NUM_EPOCHS + 1, 1) more_losses = np.zeros([NUM_EPOCHS]) for i in range(NUM_EPOCHS): for x,y in train_ds: result = train(x=x, y=y) more_losses[i] = result['loss'] if (i + 1) % 10 == 0: print(f"Finished {i+1} epochs") print(f" loss: {more_losses[i]:.3f}")
Finished 10 epochs loss: 0.223 Finished 20 epochs loss: 0.216 Finished 30 epochs loss: 0.210 Finished 40 epochs loss: 0.204 Finished 50 epochs loss: 0.198
plt.plot(epochs, losses, label='Pre-training') plt.plot(more_epochs, more_losses, label='On device') plt.ylim([0, max(plt.ylim())]) plt.xlabel('Epoch') plt.ylabel('Loss [Cross Entropy]') plt.legend();
Image in a Jupyter notebook

Arriba puede ver que el entrenamiento en el dispositivo se retoma exactamente donde se detuvo el preentrenamiento.

Guarde las ponderaciones entrenadas

Cuando se completa una ejecución de entrenamiento en un dispositivo, el modelo actualiza el conjunto de ponderaciones que está usando en la memoria. Usando el método de firma save que creó en su modelo TensorFlow Lite, puede guardar estas ponderaciones en un archivo de punto de verificación para reutilizarlas posteriormente y mejorar su modelo.

save = interpreter.get_signature_runner("save") save(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.string_))
{'checkpoint_path': array(b'/tmp/model.ckpt', dtype=object)}

En su aplicación Android, puede almacenar las ponderaciones generadas como un archivo de punto de verificación en el espacio de almacenamiento interno asignado a su app.

try (Interpreter interpreter = new Interpreter(modelBuffer)) { // Conduct the training jobs. // Export the trained weights as a checkpoint file. File outputFile = new File(getFilesDir(), "checkpoint.ckpt"); Map&lt;String, Object&gt; inputs = new HashMap&lt;&gt;(); inputs.put("checkpoint_path", outputFile.getAbsolutePath()); Map&lt;String, Object&gt; outputs = new HashMap&lt;&gt;(); interpreter.runSignature(inputs, outputs, "save"); }

Recupere las ponderaciones entrenadas

Cada vez que cree un intérprete a partir de un modelo TFLite, el intérprete cargará inicialmente las ponderaciones originales del modelo.

Así que después de haber realizado algún entrenamiento y guardado un archivo de punto de verificación, tendrá que ejecutar el método de firma restore para cargar el punto de verificación.

Una buena regla es "Cada vez que cree un intérprete para un modelo, si el punto de verificación existe, cárguelo". Si necesita restablecer el modelo al comportamiento de la línea de referencia, sólo tiene que borrar el punto de verificación y crear un intérprete nuevo.

another_interpreter = tf.lite.Interpreter(model_content=tflite_model) another_interpreter.allocate_tensors() infer = another_interpreter.get_signature_runner("infer") restore = another_interpreter.get_signature_runner("restore")
logits_before = infer(x=train_images[:1])['logits'][0] # Restore the trained weights from /tmp/model.ckpt restore(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.string_)) logits_after = infer(x=train_images[:1])['logits'][0] compare_logits({'Before': logits_before, 'After': logits_after})
Image in a Jupyter notebook

El punto de verificación se generó entrenando y guardando con TFLite. Arriba puede ver que al aplicar el punto de verificación se actualiza el comportamiento del modelo.

Nota: La carga de las ponderaciones guardadas del punto de verificación puede llevar tiempo, en función del número de variables del modelo y del tamaño del archivo del punto de verificación.

En su app para Android, puede restaurar las ponderaciones serializadas y entrenadas desde el archivo de puntos de verificación que almacenó anteriormente.

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) { // Load the trained weights from the checkpoint file. File outputFile = new File(getFilesDir(), "checkpoint.ckpt"); Map<String, Object> inputs = new HashMap<>(); inputs.put("checkpoint_path", outputFile.getAbsolutePath()); Map<String, Object> outputs = new HashMap<>(); anotherInterpreter.runSignature(inputs, outputs, "restore"); }

Nota: Cuando se reinicie su aplicación, deberá volver a cargar las ponderaciones entrenadas antes de ejecutar nuevas inferencias.

Ejecute la inferencia usando ponderaciones entrenadas

Una vez que haya cargado las ponderaciones guardadas previamente desde un archivo de punto de verificación, la ejecución del método infer usa esas ponderaciones con su modelo original para mejorar las predicciones. Después de cargar las ponderaciones guardadas, puede usar el método de firma infer como se muestra a continuación.

Nota: No es necesario cargar las ponderaciones guardadas para ejecutar una inferencia, pero ejecutarla en esa configuración produce predicciones usando el modelo entrenado originalmente, sin mejoras.

infer = another_interpreter.get_signature_runner("infer") result = infer(x=test_images) predictions = np.argmax(result["output"], axis=1) true_labels = np.argmax(test_labels, axis=1)
result['output'].shape
(10000, 10)

Trace las etiquetas predichas.

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] def plot(images, predictions, true_labels): plt.figure(figsize=(10,10)) for i in range(25): plt.subplot(5,5,i+1) plt.xticks([]) plt.yticks([]) plt.grid(False) plt.imshow(images[i], cmap=plt.cm.binary) color = 'b' if predictions[i] == true_labels[i] else 'r' plt.xlabel(class_names[predictions[i]], color=color) plt.show() plot(test_images, predictions, true_labels)
Image in a Jupyter notebook
predictions.shape
(10000,)

En su aplicación Android, tras restaurar las ponderaciones entrenadas, ejecute las inferencias basadas en los datos cargados.

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) { // Restore the weights from the checkpoint file. int NUM_TESTS = 10; FloatBuffer testImages = FloatBuffer.allocateDirect(NUM_TESTS * 28 * 28).order(ByteOrder.nativeOrder()); FloatBuffer output = FloatBuffer.allocateDirect(NUM_TESTS * 10).order(ByteOrder.nativeOrder()); // Fill the test data. // Run the inference. Map<String, Object> inputs = new HashMap<>(); inputs.put("x", testImages.rewind()); Map<String, Object> outputs = new HashMap<>(); outputs.put("output", output); anotherInterpreter.runSignature(inputs, outputs, "infer"); output.rewind(); // Process the result to get the final category values. int[] testLabels = new int[NUM_TESTS]; for (int i = 0; i < NUM_TESTS; ++i) { int index = 0; for (int j = 1; j < 10; ++j) { if (output.get(i * 10 + index) < output.get(i * 10 + j)) index = testLabels[j]; } testLabels[i] = index; } }

¡Felicidades! Acaba de crear un modelo TensorFlow Lite compatible con el entrenamiento en el dispositivo. Para más detalles de codificación, consulte el ejemplo de implementación en la app de demostración de personalización de modelos.

Si está interesado en aprender más sobre la clasificación de imágenes, consulte el Tutorial de clasificación de Keras en la página de la guía oficial de TensorFlow. Este tutorial se basa en ese ejercicio y profundiza en el tema de la clasificación.