Copyright 2020 The TensorFlow Hub Authors.
#@title Copyright 2020 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
TensorFlow Hub Object Detection Colab
Welcome to the TensorFlow Hub Object Detection Colab! This notebook will take you through the steps of running an "out-of-the-box" object detection model on images.
Imports and Setup
Let's start with the base imports.
# This Colab requires a recent numpy version.
!pip install numpy==1.24.3
!pip install protobuf==3.20.3
import os
import pathlib
import matplotlib
import matplotlib.pyplot as plt
import io
import scipy.misc
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont
from six.moves.urllib.request import urlopen

import tensorflow as tf
import tensorflow_hub as hub

tf.get_logger().setLevel('ERROR')
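As a quick sanity check (not part of the original notebook), you can print the library versions to confirm the pinned installs above took effect:

# Optional sanity check: confirm the installed versions.
print('TensorFlow version:', tf.__version__)
print('TensorFlow Hub version:', hub.__version__)
print('NumPy version:', np.__version__)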
Utilities
Run the following cell to create some utilities that will be needed later:
A helper method to load an image
A map of model names to TF Hub handles
A list of tuples with the human keypoints for the COCO 2017 dataset, needed for models with keypoints
# @title Run this!!
def load_image_into_numpy_array(path):
  """Load an image from file into a numpy array.

  Puts image into numpy array to feed into tensorflow graph.
  Note that by convention we put it into a numpy array with shape
  (height, width, channels), where channels=3 for RGB.

  Args:
    path: the file path to the image

  Returns:
    uint8 numpy array with shape (img_height, img_width, 3)
  """
  image = None
  if(path.startswith('http')):
    response = urlopen(path)
    image_data = response.read()
    image_data = BytesIO(image_data)
    image = Image.open(image_data)
  else:
    image_data = tf.io.gfile.GFile(path, 'rb').read()
    image = Image.open(BytesIO(image_data))

  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (1, im_height, im_width, 3)).astype(np.uint8)


ALL_MODELS = {
'CenterNet HourGlass104 512x512' : 'https://tfhub.dev/tensorflow/centernet/hourglass_512x512/1',
'CenterNet HourGlass104 Keypoints 512x512' : 'https://tfhub.dev/tensorflow/centernet/hourglass_512x512_kpts/1',
'CenterNet HourGlass104 1024x1024' : 'https://tfhub.dev/tensorflow/centernet/hourglass_1024x1024/1',
'CenterNet HourGlass104 Keypoints 1024x1024' : 'https://tfhub.dev/tensorflow/centernet/hourglass_1024x1024_kpts/1',
'CenterNet Resnet50 V1 FPN 512x512' : 'https://tfhub.dev/tensorflow/centernet/resnet50v1_fpn_512x512/1',
'CenterNet Resnet50 V1 FPN Keypoints 512x512' : 'https://tfhub.dev/tensorflow/centernet/resnet50v1_fpn_512x512_kpts/1',
'CenterNet Resnet101 V1 FPN 512x512' : 'https://tfhub.dev/tensorflow/centernet/resnet101v1_fpn_512x512/1',
'CenterNet Resnet50 V2 512x512' : 'https://tfhub.dev/tensorflow/centernet/resnet50v2_512x512/1',
'CenterNet Resnet50 V2 Keypoints 512x512' : 'https://tfhub.dev/tensorflow/centernet/resnet50v2_512x512_kpts/1',
'EfficientDet D0 512x512' : 'https://tfhub.dev/tensorflow/efficientdet/d0/1',
'EfficientDet D1 640x640' : 'https://tfhub.dev/tensorflow/efficientdet/d1/1',
'EfficientDet D2 768x768' : 'https://tfhub.dev/tensorflow/efficientdet/d2/1',
'EfficientDet D3 896x896' : 'https://tfhub.dev/tensorflow/efficientdet/d3/1',
'EfficientDet D4 1024x1024' : 'https://tfhub.dev/tensorflow/efficientdet/d4/1',
'EfficientDet D5 1280x1280' : 'https://tfhub.dev/tensorflow/efficientdet/d5/1',
'EfficientDet D6 1280x1280' : 'https://tfhub.dev/tensorflow/efficientdet/d6/1',
'EfficientDet D7 1536x1536' : 'https://tfhub.dev/tensorflow/efficientdet/d7/1',
'SSD MobileNet v2 320x320' : 'https://tfhub.dev/tensorflow/ssd_mobilenet_v2/2',
'SSD MobileNet V1 FPN 640x640' : 'https://tfhub.dev/tensorflow/ssd_mobilenet_v1/fpn_640x640/1',
'SSD MobileNet V2 FPNLite 320x320' : 'https://tfhub.dev/tensorflow/ssd_mobilenet_v2/fpnlite_320x320/1',
'SSD MobileNet V2 FPNLite 640x640' : 'https://tfhub.dev/tensorflow/ssd_mobilenet_v2/fpnlite_640x640/1',
'SSD ResNet50 V1 FPN 640x640 (RetinaNet50)' : 'https://tfhub.dev/tensorflow/retinanet/resnet50_v1_fpn_640x640/1',
'SSD ResNet50 V1 FPN 1024x1024 (RetinaNet50)' : 'https://tfhub.dev/tensorflow/retinanet/resnet50_v1_fpn_1024x1024/1',
'SSD ResNet101 V1 FPN 640x640 (RetinaNet101)' : 'https://tfhub.dev/tensorflow/retinanet/resnet101_v1_fpn_640x640/1',
'SSD ResNet101 V1 FPN 1024x1024 (RetinaNet101)' : 'https://tfhub.dev/tensorflow/retinanet/resnet101_v1_fpn_1024x1024/1',
'SSD ResNet152 V1 FPN 640x640 (RetinaNet152)' : 'https://tfhub.dev/tensorflow/retinanet/resnet152_v1_fpn_640x640/1',
'SSD ResNet152 V1 FPN 1024x1024 (RetinaNet152)' : 'https://tfhub.dev/tensorflow/retinanet/resnet152_v1_fpn_1024x1024/1',
'Faster R-CNN ResNet50 V1 640x640' : 'https://tfhub.dev/tensorflow/faster_rcnn/resnet50_v1_640x640/1',
'Faster R-CNN ResNet50 V1 1024x1024' : 'https://tfhub.dev/tensorflow/faster_rcnn/resnet50_v1_1024x1024/1',
'Faster R-CNN ResNet50 V1 800x1333' : 'https://tfhub.dev/tensorflow/faster_rcnn/resnet50_v1_800x1333/1',
'Faster R-CNN ResNet101 V1 640x640' : 'https://tfhub.dev/tensorflow/faster_rcnn/resnet101_v1_640x640/1',
'Faster R-CNN ResNet101 V1 1024x1024' : 'https://tfhub.dev/tensorflow/faster_rcnn/resnet101_v1_1024x1024/1',
'Faster R-CNN ResNet101 V1 800x1333' : 'https://tfhub.dev/tensorflow/faster_rcnn/resnet101_v1_800x1333/1',
'Faster R-CNN ResNet152 V1 640x640' : 'https://tfhub.dev/tensorflow/faster_rcnn/resnet152_v1_640x640/1',
'Faster R-CNN ResNet152 V1 1024x1024' : 'https://tfhub.dev/tensorflow/faster_rcnn/resnet152_v1_1024x1024/1',
'Faster R-CNN ResNet152 V1 800x1333' : 'https://tfhub.dev/tensorflow/faster_rcnn/resnet152_v1_800x1333/1',
'Faster R-CNN Inception ResNet V2 640x640' : 'https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_640x640/1',
'Faster R-CNN Inception ResNet V2 1024x1024' : 'https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_1024x1024/1',
'Mask R-CNN Inception ResNet V2 1024x1024' : 'https://tfhub.dev/tensorflow/mask_rcnn/inception_resnet_v2_1024x1024/1'
}

IMAGES_FOR_TEST = {
  'Beach' : 'models/research/object_detection/test_images/image2.jpg',
  'Dogs' : 'models/research/object_detection/test_images/image1.jpg',
  # By Heiko Gorski, Source: https://commons.wikimedia.org/wiki/File:Naxos_Taverna.jpg
  'Naxos Taverna' : 'https://upload.wikimedia.org/wikipedia/commons/6/60/Naxos_Taverna.jpg',
  # Source: https://commons.wikimedia.org/wiki/File:The_Coleoptera_of_the_British_islands_(Plate_125)_(8592917784).jpg
  'Beatles' : 'https://upload.wikimedia.org/wikipedia/commons/1/1b/The_Coleoptera_of_the_British_islands_%28Plate_125%29_%288592917784%29.jpg',
  # By Américo Toledano, Source: https://commons.wikimedia.org/wiki/File:Biblioteca_Maim%C3%B3nides,_Campus_Universitario_de_Rabanales_007.jpg
  'Phones' : 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/0d/Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg/1024px-Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg',
  # Source: https://commons.wikimedia.org/wiki/File:The_smaller_British_birds_(8053836633).jpg
  'Birds' : 'https://upload.wikimedia.org/wikipedia/commons/0/09/The_smaller_British_birds_%288053836633%29.jpg',
}

COCO17_HUMAN_POSE_KEYPOINTS = [(0, 1),
 (0, 2),
 (1, 3),
 (2, 4),
 (0, 5),
 (0, 6),
 (5, 7),
 (7, 9),
 (6, 8),
 (8, 10),
 (5, 6),
 (5, 11),
 (6, 12),
 (11, 12),
 (11, 13),
 (13, 15),
 (12, 14),
 (14, 16)]
Visualization tools
To visualize the images with the proper detected boxes, keypoints and segmentation, we will use the TensorFlow Object Detection API. To install it, we will clone the repo.
# Clone the tensorflow models repository
!git clone --depth 1 https://github.com/tensorflow/models
Installing the Object Detection API
%%bash
sudo apt install -y protobuf-compiler
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install .
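A quick, hedged way (not part of the original notebook) to confirm the install above succeeded before moving on: the top-level package import should complete without errors.

# Optional sanity check: this import fails if the install above did not work.
import object_detection
print('Object Detection API imported successfully')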
Now we can import the dependencies we will need later.
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import ops as utils_ops

%matplotlib inline
Load label map data (for plotting)
Label maps map index numbers to category names, so that when our convolutional network predicts 5, we know that this corresponds to airplane. Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine.

For simplicity, we will load the label map from the same repository that we loaded the Object Detection API code from.
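For illustration only, a minimal hand-built sketch of the structure the visualization utilities expect: a dict from class id to a dict with 'id' and 'name' keys. The two entries below are assumptions for demonstration, not the full COCO label map; real code should load the complete label map as done in the next cell.

# Hypothetical minimal category index with the same structure as the one
# produced by label_map_util below; illustration only.
manual_category_index = {
    1: {'id': 1, 'name': 'person'},
    5: {'id': 5, 'name': 'airplane'},
}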
PATH_TO_LABELS = './models/research/object_detection/data/mscoco_label_map.pbtxt'
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
Build a detection model and load pre-trained model weights
Here we will choose which object detection model to use. Selecting the architecture is enough to load the model automatically. If you want to change the model to try other architectures later, just change the next cell and execute the following ones.
Tip: if you want to read more details about the selected model, you can follow the link (the model handle) and read the additional documentation on TF Hub. After you select a model, we will print its handle to make this easier.
#@title Model Selection { display-mode: "form", run: "auto" }
model_display_name = 'CenterNet HourGlass104 Keypoints 512x512' # @param ['CenterNet HourGlass104 512x512','CenterNet HourGlass104 Keypoints 512x512','CenterNet HourGlass104 1024x1024','CenterNet HourGlass104 Keypoints 1024x1024','CenterNet Resnet50 V1 FPN 512x512','CenterNet Resnet50 V1 FPN Keypoints 512x512','CenterNet Resnet101 V1 FPN 512x512','CenterNet Resnet50 V2 512x512','CenterNet Resnet50 V2 Keypoints 512x512','EfficientDet D0 512x512','EfficientDet D1 640x640','EfficientDet D2 768x768','EfficientDet D3 896x896','EfficientDet D4 1024x1024','EfficientDet D5 1280x1280','EfficientDet D6 1280x1280','EfficientDet D7 1536x1536','SSD MobileNet v2 320x320','SSD MobileNet V1 FPN 640x640','SSD MobileNet V2 FPNLite 320x320','SSD MobileNet V2 FPNLite 640x640','SSD ResNet50 V1 FPN 640x640 (RetinaNet50)','SSD ResNet50 V1 FPN 1024x1024 (RetinaNet50)','SSD ResNet101 V1 FPN 640x640 (RetinaNet101)','SSD ResNet101 V1 FPN 1024x1024 (RetinaNet101)','SSD ResNet152 V1 FPN 640x640 (RetinaNet152)','SSD ResNet152 V1 FPN 1024x1024 (RetinaNet152)','Faster R-CNN ResNet50 V1 640x640','Faster R-CNN ResNet50 V1 1024x1024','Faster R-CNN ResNet50 V1 800x1333','Faster R-CNN ResNet101 V1 640x640','Faster R-CNN ResNet101 V1 1024x1024','Faster R-CNN ResNet101 V1 800x1333','Faster R-CNN ResNet152 V1 640x640','Faster R-CNN ResNet152 V1 1024x1024','Faster R-CNN ResNet152 V1 800x1333','Faster R-CNN Inception ResNet V2 640x640','Faster R-CNN Inception ResNet V2 1024x1024','Mask R-CNN Inception ResNet V2 1024x1024']
model_handle = ALL_MODELS[model_display_name]

print('Selected model: ' + model_display_name)
print('Model Handle at TensorFlow Hub: {}'.format(model_handle))
Loading the selected model from TensorFlow Hub
Here we just need the selected model handle and use the TensorFlow Hub library to load it into memory.
print('loading model...')
hub_model = hub.load(model_handle)
print('model loaded!')
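These models are large, so a quick check (not part of the original notebook) that a GPU is visible to TensorFlow can save a long wait on CPU-only runtimes:

# Optional: list the accelerators TensorFlow can see. An empty list means
# inference will run on CPU, which is much slower for these models.
gpus = tf.config.list_physical_devices('GPU')
print('GPUs available:', gpus if gpus else 'none (running on CPU)')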
Loading an image
Let's try the model on a simple image. To help with this, we provide a list of test images.
Here are some simple things to try out if you are curious:
Try running inference on your own images: just upload them to Colab and load them the same way it's done in the cell below.
Modify some of the input images and see if detection still works. Some simple things to try here include flipping the image horizontally or converting it to grayscale (note that the input image is still expected to have 3 channels).
Note: when using images with an alpha channel, the model expects 3-channel images, and the alpha channel will count as a 4th.
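If your own image has an alpha channel, here is a minimal sketch (assuming Pillow, as imported above) for dropping it before inference; 'my_image.png' is a hypothetical placeholder for a file you uploaded to Colab:

# Hypothetical example: strip an alpha channel so the model gets 3 channels.
rgba_image = Image.open('my_image.png')
rgb_image = rgba_image.convert('RGB')  # drops the alpha channel
rgb_np = np.expand_dims(np.array(rgb_image, dtype=np.uint8), axis=0)
print(rgb_np.shape)  # (1, height, width, 3)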
#@title Image Selection (don't forget to execute the cell!) { display-mode: "form"}
selected_image = 'Beach' # @param ['Beach', 'Dogs', 'Naxos Taverna', 'Beatles', 'Phones', 'Birds']
flip_image_horizontally = False #@param {type:"boolean"}
convert_image_to_grayscale = False #@param {type:"boolean"}

image_path = IMAGES_FOR_TEST[selected_image]
image_np = load_image_into_numpy_array(image_path)

# Flip horizontally
if(flip_image_horizontally):
  image_np[0] = np.fliplr(image_np[0]).copy()

# Convert image to grayscale
if(convert_image_to_grayscale):
  image_np[0] = np.tile(
    np.mean(image_np[0], 2, keepdims=True), (1, 1, 3)).astype(np.uint8)

plt.figure(figsize=(24,32))
plt.imshow(image_np[0])
plt.show()
Doing the inference
To do the inference we just need to call our TF Hub loaded model.
Things you can try:
Print out result['detection_boxes'] and try to match the box locations to the boxes in the image. Notice that the coordinates are given in normalized form (i.e., in the interval [0, 1]); a small conversion sketch follows the inference cell below.
Inspect the other output keys present in the result. Full documentation can be found on the model's documentation page (point your browser to the model handle printed earlier).
# running inference
results = hub_model(image_np)

# different object detection models have additional results
# all of them are explained in the documentation
result = {key:value.numpy() for key,value in results.items()}
print(result.keys())
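As a hedged illustration of the normalized box format these models produce (boxes are [ymin, xmin, ymax, xmax] with values in [0, 1], and detections are typically sorted by score), the sketch below converts the first box to pixel coordinates:

# Convert the first detection box from normalized to pixel coordinates.
height, width = image_np.shape[1], image_np.shape[2]
ymin, xmin, ymax, xmax = result['detection_boxes'][0][0]
print('First box in pixels: (%d, %d) to (%d, %d), score %.2f' % (
    ymin * height, xmin * width, ymax * height, xmax * width,
    result['detection_scores'][0][0]))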
Visualizing the results
Here is where we will need the TensorFlow Object Detection API to show the boxes from the inference step (and the keypoints when available).
The full documentation of this method can be seen here.
One thing you can tweak here is setting min_score_thresh to other values (between 0 and 1) to allow more detections in or to filter out more detections.
label_id_offset = 0
image_np_with_detections = image_np.copy()

# Use keypoints if available in detections
keypoints, keypoint_scores = None, None
if 'detection_keypoints' in result:
  keypoints = result['detection_keypoints'][0]
  keypoint_scores = result['detection_keypoint_scores'][0]

viz_utils.visualize_boxes_and_labels_on_image_array(
      image_np_with_detections[0],
      result['detection_boxes'][0],
      (result['detection_classes'][0] + label_id_offset).astype(int),
      result['detection_scores'][0],
      category_index,
      use_normalized_coordinates=True,
      max_boxes_to_draw=200,
      min_score_thresh=.30,
      agnostic_mode=False,
      keypoints=keypoints,
      keypoint_scores=keypoint_scores,
      keypoint_edges=COCO17_HUMAN_POSE_KEYPOINTS)

plt.figure(figsize=(24,32))
plt.imshow(image_np_with_detections[0])
plt.show()
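To see how a different min_score_thresh would change the picture before re-drawing, a small sketch (not in the original notebook) that counts the detections surviving each threshold:

# Count how many detections would survive a given score threshold.
# 0.30 matches the min_score_thresh used in the visualization above.
for thresh in (0.1, 0.3, 0.5, 0.8):
  kept = int(np.sum(result['detection_scores'][0] >= thresh))
  print('min_score_thresh=%.1f keeps %d detections' % (thresh, kept))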
[Optional]
Among the available object detection models there is Mask R-CNN, whose output allows instance segmentation.
To visualize it, we will use the same method as before, adding one additional parameter: instance_masks=result.get('detection_masks_reframed', None)
# Handle models with masks:
image_np_with_mask = image_np.copy()

if 'detection_masks' in result:
  # we need to convert np.arrays to tensors
  detection_masks = tf.convert_to_tensor(result['detection_masks'][0])
  detection_boxes = tf.convert_to_tensor(result['detection_boxes'][0])

  # Reframe the bbox mask to the image size.
  detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
            detection_masks, detection_boxes,
            image_np.shape[1], image_np.shape[2])
  detection_masks_reframed = tf.cast(detection_masks_reframed > 0.5,
                                     tf.uint8)
  result['detection_masks_reframed'] = detection_masks_reframed.numpy()

viz_utils.visualize_boxes_and_labels_on_image_array(
      image_np_with_mask[0],
      result['detection_boxes'][0],
      (result['detection_classes'][0] + label_id_offset).astype(int),
      result['detection_scores'][0],
      category_index,
      use_normalized_coordinates=True,
      max_boxes_to_draw=200,
      min_score_thresh=.30,
      agnostic_mode=False,
      instance_masks=result.get('detection_masks_reframed', None),
      line_thickness=8)

plt.figure(figsize=(24,32))
plt.imshow(image_np_with_mask[0])
plt.show()