Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/master/models/tf.py
Views: 475
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license1"""2TensorFlow, Keras and TFLite versions of YOLOv53Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/112745Usage:6$ python models/tf.py --weights yolov5s.pt78Export:9$ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs10"""1112import argparse13import sys14from copy import deepcopy15from pathlib import Path1617FILE = Path(__file__).resolve()18ROOT = FILE.parents[1] # YOLOv5 root directory19if str(ROOT) not in sys.path:20sys.path.append(str(ROOT)) # add ROOT to PATH21# ROOT = ROOT.relative_to(Path.cwd()) # relative2223import numpy as np24import tensorflow as tf25import torch26import torch.nn as nn27from tensorflow import keras2829from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, Concat, Conv, DWConv, Focus, autopad30from models.experimental import CrossConv, MixConv2d, attempt_load31from models.yolo import Detect32from utils.activations import SiLU33from utils.general import LOGGER, make_divisible, print_args343536class TFBN(keras.layers.Layer):37# TensorFlow BatchNormalization wrapper38def __init__(self, w=None):39super().__init__()40self.bn = keras.layers.BatchNormalization(41beta_initializer=keras.initializers.Constant(w.bias.numpy()),42gamma_initializer=keras.initializers.Constant(w.weight.numpy()),43moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),44moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),45epsilon=w.eps)4647def call(self, inputs):48return self.bn(inputs)495051class TFPad(keras.layers.Layer):52def __init__(self, pad):53super().__init__()54self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])5556def call(self, inputs):57return tf.pad(inputs, self.pad, mode='constant', constant_values=0)585960class TFConv(keras.layers.Layer):61# Standard convolution62def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):63# ch_in, ch_out, weights, kernel, stride, padding, groups64super().__init__()65assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"66assert isinstance(k, int), "Convolution with multiple kernels are not allowed."67# TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)68# see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch6970conv = keras.layers.Conv2D(71c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=False if hasattr(w, 'bn') else True,72kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),73bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))74self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])75self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity7677# YOLOv5 activations78if isinstance(w.act, nn.LeakyReLU):79self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity80elif isinstance(w.act, nn.Hardswish):81self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity82elif isinstance(w.act, (nn.SiLU, SiLU)):83self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity84else:85raise Exception(f'no matching TensorFlow activation found for {w.act}')8687def call(self, inputs):88return self.act(self.bn(self.conv(inputs)))899091class TFFocus(keras.layers.Layer):92# Focus wh information into c-space93def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):94# ch_in, ch_out, kernel, stride, padding, groups95super().__init__()96self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)9798def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)99# inputs = inputs / 255 # normalize 0-255 to 0-1100return self.conv(tf.concat([inputs[:, ::2, ::2, :],101inputs[:, 1::2, ::2, :],102inputs[:, ::2, 1::2, :],103inputs[:, 1::2, 1::2, :]], 3))104105106class TFBottleneck(keras.layers.Layer):107# Standard bottleneck108def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion109super().__init__()110c_ = int(c2 * e) # hidden channels111self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)112self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)113self.add = shortcut and c1 == c2114115def call(self, inputs):116return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))117118119class TFConv2d(keras.layers.Layer):120# Substitution for PyTorch nn.Conv2D121def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):122super().__init__()123assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"124self.conv = keras.layers.Conv2D(125c2, k, s, 'VALID', use_bias=bias,126kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),127bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None, )128129def call(self, inputs):130return self.conv(inputs)131132133class TFBottleneckCSP(keras.layers.Layer):134# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks135def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):136# ch_in, ch_out, number, shortcut, groups, expansion137super().__init__()138c_ = int(c2 * e) # hidden channels139self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)140self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)141self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)142self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)143self.bn = TFBN(w.bn)144self.act = lambda x: keras.activations.relu(x, alpha=0.1)145self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])146147def call(self, inputs):148y1 = self.cv3(self.m(self.cv1(inputs)))149y2 = self.cv2(inputs)150return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))151152153class TFC3(keras.layers.Layer):154# CSP Bottleneck with 3 convolutions155def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):156# ch_in, ch_out, number, shortcut, groups, expansion157super().__init__()158c_ = int(c2 * e) # hidden channels159self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)160self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)161self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)162self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])163164def call(self, inputs):165return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))166167168class TFSPP(keras.layers.Layer):169# Spatial pyramid pooling layer used in YOLOv3-SPP170def __init__(self, c1, c2, k=(5, 9, 13), w=None):171super().__init__()172c_ = c1 // 2 # hidden channels173self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)174self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)175self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]176177def call(self, inputs):178x = self.cv1(inputs)179return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))180181182class TFSPPF(keras.layers.Layer):183# Spatial pyramid pooling-Fast layer184def __init__(self, c1, c2, k=5, w=None):185super().__init__()186c_ = c1 // 2 # hidden channels187self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)188self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)189self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')190191def call(self, inputs):192x = self.cv1(inputs)193y1 = self.m(x)194y2 = self.m(y1)195return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))196197198class TFDetect(keras.layers.Layer):199def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer200super().__init__()201self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)202self.nc = nc # number of classes203# self.no = nc + 5 # number of outputs per anchor204self.no = nc + 5 + 180 # number of outputs per anchor205self.nl = len(anchors) # number of detection layers206self.na = len(anchors[0]) // 2 # number of anchors207self.grid = [tf.zeros(1)] * self.nl # init grid208self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)209self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]),210[self.nl, 1, -1, 1, 2])211self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]212self.training = False # set to False after building model213self.imgsz = imgsz214for i in range(self.nl):215ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]216self.grid[i] = self._make_grid(nx, ny)217218def call(self, inputs):219z = [] # inference output220x = []221for i in range(self.nl):222x.append(self.m[i](inputs[i]))223# x(bs,20,20,255) to x(bs,3,20,20,85)224ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]225x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])226227if not self.training: # inference228y = tf.sigmoid(x[i])229xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy230wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]231# Normalize xywh to 0-1 to reduce calibration error232xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)233wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)234y = tf.concat([xy, wh, y[..., 4:]], -1)235z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))236237return x if self.training else (tf.concat(z, 1), x)238239@staticmethod240def _make_grid(nx=20, ny=20):241# yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])242# return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()243xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))244return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)245246247class TFUpsample(keras.layers.Layer):248def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments needed including 'w'249super().__init__()250assert scale_factor == 2, "scale_factor must be 2"251self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)252# self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)253# with default arguments: align_corners=False, half_pixel_centers=False254# self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,255# size=(x.shape[1] * 2, x.shape[2] * 2))256257def call(self, inputs):258return self.upsample(inputs)259260261class TFConcat(keras.layers.Layer):262def __init__(self, dimension=1, w=None):263super().__init__()264assert dimension == 1, "convert only NCHW to NHWC concat"265self.d = 3266267def call(self, inputs):268return tf.concat(inputs, self.d)269270271def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)272LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")273anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']274na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors275no = na * (nc + 5 + 180) # number of outputs = anchors * (classes + 5)276277layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out278for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args279m_str = m280m = eval(m) if isinstance(m, str) else m # eval strings281for j, a in enumerate(args):282try:283args[j] = eval(a) if isinstance(a, str) else a # eval strings284except NameError:285pass286287n = max(round(n * gd), 1) if n > 1 else n # depth gain288if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:289c1, c2 = ch[f], args[0]290c2 = make_divisible(c2 * gw, 8) if c2 != no else c2291292args = [c1, c2, *args[1:]]293if m in [BottleneckCSP, C3]:294args.insert(2, n)295n = 1296elif m is nn.BatchNorm2d:297args = [ch[f]]298elif m is Concat:299c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)300elif m is Detect:301args.append([ch[x + 1] for x in f])302if isinstance(args[1], int): # number of anchors303args[1] = [list(range(args[1] * 2))] * len(f)304args.append(imgsz)305else:306c2 = ch[f]307308tf_m = eval('TF' + m_str.replace('nn.', ''))309m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \310else tf_m(*args, w=model.model[i]) # module311312torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module313t = str(m)[8:-2].replace('__main__.', '') # module type314np = sum(x.numel() for x in torch_m_.parameters()) # number params315m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params316LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{np:>10} {t:<40}{str(args):<30}') # print317save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist318layers.append(m_)319ch.append(c2)320return keras.Sequential(layers), sorted(save)321322323class TFModel:324def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)): # model, channels, classes325super().__init__()326if isinstance(cfg, dict):327self.yaml = cfg # model dict328else: # is *.yaml329import yaml # for torch hub330self.yaml_file = Path(cfg).name331with open(cfg) as f:332self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict333334# Define model335if nc and nc != self.yaml['nc']:336LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")337self.yaml['nc'] = nc # override yaml value338self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)339340def predict(self, inputs, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,341conf_thres=0.25):342y = [] # outputs343x = inputs344for i, m in enumerate(self.model.layers):345if m.f != -1: # if not from previous layer346x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers347348x = m(x) # run349y.append(x if m.i in self.savelist else None) # save output350351# Add TensorFlow NMS352if tf_nms:353boxes = self._xywh2xyxy(x[0][..., :4])354probs = x[0][:, :, 4:5]355classes = x[0][:, :, 5:]356scores = probs * classes357if agnostic_nms:358nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)359return nms, x[1]360else:361boxes = tf.expand_dims(boxes, 2)362nms = tf.image.combined_non_max_suppression(363boxes, scores, topk_per_class, topk_all, iou_thres, conf_thres, clip_boxes=False)364return nms, x[1]365366return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]367# x = x[0][0] # [x(1,6300,85), ...] to x(6300,85)368# xywh = x[..., :4] # x(6300,4) boxes369# conf = x[..., 4:5] # x(6300,1) confidences370# cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes371# return tf.concat([conf, cls, xywh], 1)372373@staticmethod374def _xywh2xyxy(xywh):375# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right376x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)377return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)378379380class AgnosticNMS(keras.layers.Layer):381# TF Agnostic NMS382def call(self, input, topk_all, iou_thres, conf_thres):383# wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450384return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres), input,385fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),386name='agnostic_nms')387388@staticmethod389def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25): # agnostic NMS390boxes, classes, scores = x391class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)392scores_inp = tf.reduce_max(scores, -1)393selected_inds = tf.image.non_max_suppression(394boxes, scores_inp, max_output_size=topk_all, iou_threshold=iou_thres, score_threshold=conf_thres)395selected_boxes = tf.gather(boxes, selected_inds)396padded_boxes = tf.pad(selected_boxes,397paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],398mode="CONSTANT", constant_values=0.0)399selected_scores = tf.gather(scores_inp, selected_inds)400padded_scores = tf.pad(selected_scores,401paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],402mode="CONSTANT", constant_values=-1.0)403selected_classes = tf.gather(class_inds, selected_inds)404padded_classes = tf.pad(selected_classes,405paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],406mode="CONSTANT", constant_values=-1.0)407valid_detections = tf.shape(selected_inds)[0]408return padded_boxes, padded_scores, padded_classes, valid_detections409410411def representative_dataset_gen(dataset, ncalib=100):412# Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays413for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):414input = np.transpose(img, [1, 2, 0])415input = np.expand_dims(input, axis=0).astype(np.float32)416input /= 255417yield [input]418if n >= ncalib:419break420421422def run(weights=ROOT / 'yolov5s.pt', # weights path423imgsz=(640, 640), # inference size h,w424batch_size=1, # batch size425dynamic=False, # dynamic batch size426):427# PyTorch model428im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image429model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False)430y = model(im) # inference431model.info()432433# TensorFlow model434im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image435tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)436y = tf_model.predict(im) # inference437438# Keras model439im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)440keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))441keras_model.summary()442443LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')444445446def parse_opt():447parser = argparse.ArgumentParser()448parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')449parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')450parser.add_argument('--batch-size', type=int, default=1, help='batch size')451parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')452opt = parser.parse_args()453opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand454print_args(FILE.stem, opt)455return opt456457458def main(opt):459run(**vars(opt))460461462if __name__ == "__main__":463opt = parse_opt()464main(opt)465466467