Path: blob/master/utils/torch_utils.py
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
PyTorch utils
"""

import datetime
import math
import os
import platform
import subprocess
import time
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from utils.general import LOGGER

try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Decorator to make all processes in distributed training wait for each local_master to do something.
    """
    if local_rank not in [-1, 0]:
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])


def date_modified(path=__file__):
    # return human-readable file modification date, i.e. '2021-3-26'
    t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
    return f'{t.year}-{t.month}-{t.day}'


def git_describe(path=Path(__file__).parent):  # path must be a directory
    # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
    s = f'git -C {path} describe --tags --long --always'
    try:
        return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
    except subprocess.CalledProcessError as e:
        return ''  # not a git repository


def select_device(device='', batch_size=0, newline=True):
    # device = 'cpu' or '0' or '0,1,2,3'
    s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} '  # string
    device = str(device).strip().lower().replace('cuda:', '')  # to string, 'cuda:0' to '0'
    cpu = device == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2:.0f}MiB)\n"  # bytes to MB
    else:
        s += 'CPU\n'

    if not newline:
        s = s.rstrip()
    LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
    return torch.device('cuda:0' if cuda else 'cpu')


def time_sync():
    # pytorch-accurate time
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def profile(input, ops, n=10, device=None):
    # YOLOv5 speed/memory/FLOPs profiler
    #
    # Usage:
    #     input = torch.randn(16, 3, 640, 640)
    #     m1 = lambda x: x * torch.sigmoid(x)
    #     m2 = nn.SiLU()
    #     profile(input, [m1, m2], n=100)  # profile over 100 iterations

    results = []
    device = device or select_device()
    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
          f"{'input':>24s}{'output':>24s}")

    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, 'to') else m  # device
            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
            except:
                flops = 0

            try:
                for _ in range(n):
                    t[0] = time_sync()
                    y = m(x)
                    t[1] = time_sync()
                    try:
                        _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        t[2] = time_sync()
                    except Exception as e:  # no backward method
                        # print(e)  # for debug
                        t[2] = float('nan')
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
                s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
                p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0  # parameters
                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                print(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results


def is_parallel(model):
    # Returns True if model is of type DP or DDP
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def de_parallel(model):
    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
    return model.module if is_parallel(model) else model


def initialize_weights(model):
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True


def find_modules(model, mclass=nn.Conv2d):
    # Finds layer indices matching module class 'mclass'
    return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]


def sparsity(model):
    # Return global model sparsity
    a, b = 0, 0
    for p in model.parameters():
        a += p.numel()
        b += (p == 0).sum()
    return b / a


def prune(model, amount=0.3):
    # Prune model to requested global sparsity
    import torch.nn.utils.prune as prune
    print('Pruning model... ', end='')
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.l1_unstructured(m, name='weight', amount=amount)  # prune
            prune.remove(m, 'weight')  # make permanent
    print(' %.3g global sparsity' % sparsity(model))


def fuse_conv_and_bn(conv, bn):
    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv


def model_info(model, verbose=False, img_size=640):
    # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    if verbose:
        print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:  # FLOPs
        from thop import profile
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
        img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
        img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
        fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPs
    except (ImportError, Exception):
        fs = ''

    LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")


def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    # scales img(bs,3,y,x) by ratio constrained to gs-multiple
    if ratio == 1.0:
        return img
    else:
        h, w = img.shape[2:]
        s = (int(h * ratio), int(w * ratio))  # new size
        img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
        if not same_shape:  # pad/crop img
            h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
        return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean


def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


class EarlyStopping:
    # YOLOv5 simple early stopper
    def __init__(self, patience=30):
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float('inf')  # epochs to wait after fitness stops improving to stop
        self.possible_stop = False  # possible stop may occur next epoch

    def __call__(self, epoch, fitness):
        if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
        stop = delta >= self.patience  # stop training if patience exceeded
        if stop:
            LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
        return stop


class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
        # if next(model.parameters()).device.type != 'cpu':
        #     self.ema.half()  # FP16 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)
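As a quick orientation for the training-loop helpers above, here is a minimal usage sketch. It is an illustration only, not part of torch_utils.py: the toy model, random batches, and the 1/(1+loss) "fitness" value are placeholder assumptions; only select_device, ModelEMA, and EarlyStopping come from this module.

# Minimal usage sketch (illustrative only; the toy model, data, and fitness value are placeholders)
import torch
import torch.nn as nn

from utils.torch_utils import EarlyStopping, ModelEMA, select_device

device = select_device('')  # '' -> first CUDA device if available, else CPU
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SiLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
ema = ModelEMA(model)                # create after the model is on its device, before any DDP wrapping
stopper = EarlyStopping(patience=5)

for epoch in range(100):
    x = torch.randn(4, 3, 32, 32, device=device)   # placeholder batch
    loss = model(x).square().mean()                 # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)                               # EMA tracks the raw model after every optimizer step

    fitness = 1 / (1 + loss.item())                 # placeholder "higher is better" metric (e.g. mAP)
    if stopper(epoch, fitness):                     # True once `patience` epochs pass with no improvement
        break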