Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/es-419/hub/tutorials/cropnet_cassava.ipynb
25118 views
Kernel: Python 3
# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ==============================================================================

CropNet: Cassava Disease Detection

En estas notas se muestra cómo usar el modelo clasificador de enfermedades en mandioca CropNet de TensorFlow Hub. El modelo clasifica imágenes de mandioca en una de las siguientes 6 clases: plaga bacteriana, enfermedad del virus de raya parda, ácaros verdes, enfermedad del virus del mosaico, sana o desconocida.

En este Colab se muestra cómo hacer lo siguiente::

  • Cargar el modelo https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 del TensorFlow Hub

  • Cargar el conjunto de datos de la mandioca de los conjuntos de datos de TensorFlow (TFDS)

  • Clasificar imágenes de hojas de mandioca en 4 categorías de enfermedades distintas, o como sana o desconocida.

  • Evaluar la exactitud del clasificador y observe cuán sólido es el modelo cuando se lo aplica a imágenes fuera de su ámbito.

Importaciones y preparación

!pip install matplotlib==3.2.2
import numpy as np import matplotlib.pyplot as plt import tensorflow as tf import tensorflow_datasets as tfds import tensorflow_hub as hub
#@title Helper function for displaying examples def plot(examples, predictions=None): # Get the images, labels, and optionally predictions images = examples['image'] labels = examples['label'] batch_size = len(images) if predictions is None: predictions = batch_size * [None] # Configure the layout of the grid x = np.ceil(np.sqrt(batch_size)) y = np.ceil(batch_size / x) fig = plt.figure(figsize=(x * 6, y * 7)) for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)): # Render the image ax = fig.add_subplot(x, y, i+1) ax.imshow(image, aspect='auto') ax.grid(False) ax.set_xticks([]) ax.set_yticks([]) # Display the label and optionally prediction x_label = 'Label: ' + name_map[class_names[label]] if prediction is not None: x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label ax.xaxis.label.set_color('green' if label == prediction else 'red') ax.set_xlabel(x_label) plt.show()

Conjunto de datos

Carguemos el conjunto de datos de mandioca (cassava) de TFDS

dataset, info = tfds.load('cassava', with_info=True)

Echemos un vistazo a la información del conjunto de datos para entender mejor. Observemos datos como los de la descripción, citas e información sobre cuántos ejemplos hay disponibles

info

El conjunto de datos sobre mandiocas tiene imágenes de mandiocas con 4 enfermedades diferentes y también de hojas sanas de mandioca. El modelo puede predecir todas estas clases, y también la sexta clase "desconocida" cuando el modelo no se siente confiado con su predicción.

# Extend the cassava dataset classes with 'unknown' class_names = info.features['label'].names + ['unknown'] # Map the class names to human readable names name_map = dict( cmd='Mosaic Disease', cbb='Bacterial Blight', cgm='Green Mite', cbsd='Brown Streak Disease', healthy='Healthy', unknown='Unknown') print(len(class_names), 'classes:') print(class_names) print([name_map[name] for name in class_names])

Antes de introducir los datos en el modelo, debemos hacer algo de preprocesamiento. El modelo espera imágenes de 224 × 224 con valores de canales RGB en [0, 1]. Normalicemos y ajustemos los tamaños de las imágenes.

def preprocess_fn(data): image = data['image'] # Normalize [0, 255] to [0, 1] image = tf.cast(image, tf.float32) image = image / 255. # Resize the images to 224 x 224 image = tf.image.resize(image, (224, 224)) data['image'] = image return data

Observemos algunos pocos ejemplos del conjunto de datos

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator() examples = next(batch) plot(examples)

Modelo

Carguemos el clasificador de Hub, obtengamos algunas predicciones y veamos otras del modelo en algunos ejemplos.

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2') probabilities = classifier(examples['image']) predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)

Evaluación y solidez

Midamos la exactitud de nuestro clasificador en un conjunto de datos separado. También podemos observar la solidez del modelo, mediante la evaluación del desempeño en un conjunto de datos que no sea el de mandiocas. Para los conjuntos de datos de plantas como iNaturalist o frijoles, el modelo, casi siempre, debería devolver desconocida.

#@title Parameters {run: "auto"} DATASET = 'cassava' #@param {type:"string"} ['cassava', 'beans', 'i_naturalist2017'] DATASET_SPLIT = 'test' #@param {type:"string"} ['train', 'test', 'validation'] BATCH_SIZE = 32 #@param {type:"integer"} MAX_EXAMPLES = 1000 #@param {type:"integer"}
def label_to_unknown_fn(data): data['label'] = 5 # Override label to unknown. return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets. ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES) dataset_description = DATASET if DATASET != 'cassava': ds = ds.map(label_to_unknown_fn) dataset_description += ' (labels mapped to unknown)' ds = ds.batch(BATCH_SIZE) # Calculate the accuracy of the model metric = tf.keras.metrics.Accuracy() for examples in ds: probabilities = classifier(examples['image']) predictions = tf.math.argmax(probabilities, axis=-1) labels = examples['label'] metric.update_state(labels, predictions) print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))

Más información