# utils/loggers/wandb/wandb_utils.py
"""Utilities and tools for tracking runs with Weights & Biases."""12import logging3import os4import sys5from contextlib import contextmanager6from pathlib import Path7from typing import Dict89import yaml10from tqdm import tqdm1112FILE = Path(__file__).resolve()13ROOT = FILE.parents[3] # YOLOv5 root directory14if str(ROOT) not in sys.path:15sys.path.append(str(ROOT)) # add ROOT to PATH1617from utils.datasets import LoadImagesAndLabels, img2label_paths18from utils.general import LOGGER, check_dataset, check_file1920try:21import wandb2223assert hasattr(wandb, '__version__') # verify package import not local dir24except (ImportError, AssertionError):25wandb = None2627RANK = int(os.getenv('RANK', -1))28WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'293031def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):32return from_string[len(prefix):]333435def check_wandb_config_file(data_config_file):36wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path37if Path(wandb_config).is_file():38return wandb_config39return data_config_file404142def check_wandb_dataset(data_file):43is_trainset_wandb_artifact = False44is_valset_wandb_artifact = False45if check_file(data_file) and data_file.endswith('.yaml'):46with open(data_file, errors='ignore') as f:47data_dict = yaml.safe_load(f)48is_trainset_wandb_artifact = (isinstance(data_dict['train'], str) and49data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX))50is_valset_wandb_artifact = (isinstance(data_dict['val'], str) and51data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX))52if is_trainset_wandb_artifact or is_valset_wandb_artifact:53return data_dict54else:55return check_dataset(data_file)565758def get_run_info(run_path):59run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))60run_id = run_path.stem61project = run_path.parent.stem62entity = run_path.parent.parent.stem63model_artifact_name = 'run_' + run_id + '_model'64return entity, project, run_id, model_artifact_name656667def check_wandb_resume(opt):68process_wandb_config_ddp_mode(opt) if RANK not in [-1, 0] else None69if isinstance(opt.resume, str):70if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):71if RANK not in [-1, 0]: # For resuming DDP runs72entity, project, run_id, model_artifact_name = get_run_info(opt.resume)73api = wandb.Api()74artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')75modeldir = artifact.download()76opt.weights = str(Path(modeldir) / "last.pt")77return True78return None798081def process_wandb_config_ddp_mode(opt):82with open(check_file(opt.data), errors='ignore') as f:83data_dict = yaml.safe_load(f) # data dict84train_dir, val_dir = None, None85if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):86api = wandb.Api()87train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias)88train_dir = train_artifact.download()89train_path = Path(train_dir) / 'data/images/'90data_dict['train'] = str(train_path)9192if isinstance(data_dict['val'], str) and data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX):93api = wandb.Api()94val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias)95val_dir = val_artifact.download()96val_path = Path(val_dir) / 'data/images/'97data_dict['val'] = str(val_path)98if train_dir or val_dir:99ddp_data_path = str(Path(val_dir) / 'wandb_local_data.yaml')100with open(ddp_data_path, 'w') as f:101yaml.safe_dump(data_dict, f)102opt.data = ddp_data_path103104105class WandbLogger():106"""Log 

class WandbLogger():
    """Log training runs, datasets, models, and predictions to Weights & Biases.

    This logger sends information to W&B at wandb.ai. By default, this information
    includes hyperparameters, system configuration and metrics, model metrics,
    and basic data metrics and analyses.

    By providing additional command line arguments to train.py, datasets,
    models and predictions can also be logged.

    For more on how this logger is used, see the Weights & Biases documentation:
    https://docs.wandb.com/guides/integrations/yolov5
    """

    def __init__(self, opt, run_id=None, job_type='Training'):
        """
        - Initialize WandbLogger instance
        - Upload dataset if opt.upload_dataset is True
        - Set up training processes if job_type is 'Training'

        arguments:
        opt (namespace) -- Commandline arguments for this run
        run_id (str) -- Run ID of W&B run to be resumed
        job_type (str) -- To set the job_type for this run

        """
        # Pre-training routine --
        self.job_type = job_type
        self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run
        self.val_artifact, self.train_artifact = None, None
        self.train_artifact_path, self.val_artifact_path = None, None
        self.result_artifact = None
        self.val_table, self.result_table = None, None
        self.bbox_media_panel_images = []
        self.val_table_path_map = None
        self.max_imgs_to_log = 16
        self.wandb_artifact_data_dict = None
        self.data_dict = None
        # It's more elegant to stick to one wandb.init call,
        # but useful config data is overwritten in the WandbLogger's wandb.init call
        if isinstance(opt.resume, str):  # checks resume from artifact
            if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
                entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
                model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
                assert wandb, 'install wandb to resume wandb runs'
                # Resume wandb-artifact:// runs here; workaround for not overwriting wandb.config
                self.wandb_run = wandb.init(id=run_id,
                                            project=project,
                                            entity=entity,
                                            resume='allow',
                                            allow_val_change=True)
                opt.resume = model_artifact_name
        elif self.wandb:
            self.wandb_run = wandb.init(config=opt,
                                        resume="allow",
                                        project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
                                        entity=opt.entity,
                                        name=opt.name if opt.name != 'exp' else None,
                                        job_type=job_type,
                                        id=run_id,
                                        allow_val_change=True) if not wandb.run else wandb.run
        if self.wandb_run:
            if self.job_type == 'Training':
                if opt.upload_dataset:
                    if not opt.resume:
                        self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt)

                if opt.resume:
                    # resume from artifact
                    if isinstance(opt.resume, str) and opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
                        self.data_dict = dict(self.wandb_run.config.data_dict)
                    else:  # local resume
                        self.data_dict = check_wandb_dataset(opt.data)
                else:
                    self.data_dict = check_wandb_dataset(opt.data)
                    self.wandb_artifact_data_dict = self.wandb_artifact_data_dict or self.data_dict

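                    # NOTE: self.data_dict here reflects the local data.yaml, while
                    # self.wandb_artifact_data_dict holds the artifact-linked copy
                    # when a dataset upload happened; the artifact copy is what gets
                    # persisted to the wandb config below.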
                    # write data_dict to config, useful for resuming from artifacts. Do this only when not resuming.
                    self.wandb_run.config.update({'data_dict': self.wandb_artifact_data_dict},
                                                 allow_val_change=True)
                self.setup_training(opt)

            if self.job_type == 'Dataset Creation':
                self.wandb_run.config.update({"upload_dataset": True})
                self.data_dict = self.check_and_upload_dataset(opt)

    def check_and_upload_dataset(self, opt):
        """
        Check if the dataset format is compatible and upload it as a W&B artifact

        arguments:
        opt (namespace) -- Commandline arguments for current run

        returns:
        Updated dataset info dictionary where local dataset paths are replaced by WANDB_ARTIFACT_PREFIX links.
        """
        assert wandb, 'Install wandb to upload dataset'
        config_path = self.log_dataset_artifact(opt.data,
                                                opt.single_cls,
                                                'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
        with open(config_path, errors='ignore') as f:
            wandb_data_dict = yaml.safe_load(f)
        return wandb_data_dict

    def setup_training(self, opt):
        """
        Set up the necessary processes for training YOLO models:
        - Attempt to download model checkpoint and dataset artifacts if opt.resume starts with WANDB_ARTIFACT_PREFIX
        - Update data_dict to contain info of previous run if resumed and the paths of dataset artifact if downloaded
        - Setup log_dict, initialize bbox_interval

        arguments:
        opt (namespace) -- commandline arguments for this run

        """
        self.log_dict, self.current_epoch = {}, 0
        self.bbox_interval = opt.bbox_interval
        if isinstance(opt.resume, str):
            modeldir, _ = self.download_model_artifact(opt)
            if modeldir:
                self.weights = Path(modeldir) / "last.pt"
                config = self.wandb_run.config
                opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
                    self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
                    config.hyp
        data_dict = self.data_dict
        if self.val_artifact is None:  # If --upload_dataset is set, use the existing artifact, don't download
            self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
                                                                                           opt.artifact_alias)
            self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
                                                                                       opt.artifact_alias)

        if self.train_artifact_path is not None:
            train_path = Path(self.train_artifact_path) / 'data/images/'
            data_dict['train'] = str(train_path)
        if self.val_artifact_path is not None:
            val_path = Path(self.val_artifact_path) / 'data/images/'
            data_dict['val'] = str(val_path)

        if self.val_artifact is not None:
            self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
            columns = ["epoch", "id", "ground truth", "prediction"]
            columns.extend(self.data_dict['names'])
            self.result_table = wandb.Table(columns)
            self.val_table = self.val_artifact.get("val")
            if self.val_table_path_map is None:
                self.map_val_table_path()
        if opt.bbox_interval == -1:
            self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
        train_from_artifact = self.train_artifact_path is not None and self.val_artifact_path is not None
        # Update the data_dict to point to local artifacts dir
        if train_from_artifact:
            self.data_dict = data_dict
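
    # Sketch of the resume flow handled by setup_training() (values hypothetical):
    # given opt.resume = 'wandb-artifact://my-team/YOLOv5/3kpfmkcg', the checkpoint
    # artifact 'run_3kpfmkcg_model:latest' is downloaded, opt.weights is pointed at
    # '<download_dir>/last.pt', and save_period/batch_size/bbox_interval/epochs/hyp
    # are restored from the resumed run's wandb config.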

    def download_dataset_artifact(self, path, alias):
        """
        download the dataset artifact if the path starts with WANDB_ARTIFACT_PREFIX

        arguments:
        path -- path of the dataset to be used for training
        alias (str) -- alias of the artifact to be downloaded/used for training

        returns:
        (str, wandb.Artifact) -- path of the downloaded dataset and its corresponding artifact object if the dataset
        is found, otherwise returns (None, None)
        """
        if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
            artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
            dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\", "/"))
            assert dataset_artifact is not None, "Error: W&B dataset artifact doesn't exist"
            datadir = dataset_artifact.download()
            return datadir, dataset_artifact
        return None, None

    def download_model_artifact(self, opt):
        """
        download the model checkpoint artifact if the resume path starts with WANDB_ARTIFACT_PREFIX

        arguments:
        opt (namespace) -- Commandline arguments for this run
        """
        if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
            model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
            assert model_artifact is not None, "Error: W&B model artifact doesn't exist"
            modeldir = model_artifact.download()
            epochs_trained = model_artifact.metadata.get('epochs_trained')
            total_epochs = model_artifact.metadata.get('total_epochs')
            is_finished = total_epochs is None
            assert not is_finished, 'training is finished, can only resume incomplete runs.'
            return modeldir, model_artifact
        return None, None

    def log_model(self, path, opt, epoch, fitness_score, best_model=False):
        """
        Log the model checkpoint as a W&B artifact

        arguments:
        path (Path) -- Path of directory containing the checkpoints
        opt (namespace) -- Command line arguments for this run
        epoch (int) -- Current epoch number
        fitness_score (float) -- fitness score for current epoch
        best_model (boolean) -- Boolean representing if the current checkpoint is the best yet.
        """
        model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
            'original_url': str(path),
            'epochs_trained': epoch + 1,
            'save period': opt.save_period,
            'project': opt.project,
            'total_epochs': opt.epochs,
            'fitness_score': fitness_score})
        model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
        wandb.log_artifact(model_artifact,
                           aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
        LOGGER.info(f"Saving model artifact on epoch {epoch + 1}")
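
    # Each checkpoint upload above carries aliases such as ['latest', 'last',
    # 'epoch 7', 'best'], so a specific epoch can later be pulled back with,
    # e.g. (identifiers hypothetical):
    #
    #   api = wandb.Api()
    #   ckpt_dir = api.artifact('my-team/YOLOv5/run_3kpfmkcg_model:latest').download()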

    def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
        """
        Log the dataset as a W&B artifact and return the new data file with W&B links

        arguments:
        data_file (str) -- the .yaml file with information about the dataset like - path, classes etc.
        single_cls (boolean) -- train multi-class data as single-class
        project (str) -- project name. Used to construct the artifact path
        overwrite_config (boolean) -- overwrites the data.yaml file if set to true otherwise creates a new
        file with _wandb postfix, e.g. data_wandb.yaml

        returns:
        the new .yaml file with artifact links. It can be used to start training directly from artifacts
        """
        upload_dataset = self.wandb_run.config.upload_dataset
        log_val_only = isinstance(upload_dataset, str) and upload_dataset == 'val'
        self.data_dict = check_dataset(data_file)  # parse and check
        data = dict(self.data_dict)
        nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
        names = {k: v for k, v in enumerate(names)}  # to index dictionary

        # log train set
        if not log_val_only:
            self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
                data['train'], rect=True, batch_size=1), names, name='train') if data.get('train') else None
            if data.get('train'):
                data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')

        self.val_artifact = self.create_dataset_table(LoadImagesAndLabels(
            data['val'], rect=True, batch_size=1), names, name='val') if data.get('val') else None
        if data.get('val'):
            data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')

        path = Path(data_file)
        # create a _wandb.yaml file with artifact links if both train and test set are logged
        if not log_val_only:
            path = (path.stem if overwrite_config else path.stem + '_wandb') + '.yaml'  # updated data.yaml path
            path = Path('data') / path
        data.pop('download', None)
        data.pop('path', None)
        with open(path, 'w') as f:
            yaml.safe_dump(data, f)
        LOGGER.info(f"Created dataset config file {path}")

        if self.job_type == 'Training':  # builds correct artifact pipeline graph
            if not log_val_only:
                self.wandb_run.log_artifact(
                    self.train_artifact)  # calling use_artifact downloads the dataset. NOT NEEDED!
            self.wandb_run.use_artifact(self.val_artifact)
            self.val_artifact.wait()
            self.val_table = self.val_artifact.get('val')
            self.map_val_table_path()
        else:
            self.wandb_run.log_artifact(self.train_artifact)
            self.wandb_run.log_artifact(self.val_artifact)
        return path

    def map_val_table_path(self):
        """
        Map the validation dataset Table: name of file -> its id in the W&B Table.
        Useful for referencing artifacts for evaluation.
        """
        self.val_table_path_map = {}
        LOGGER.info("Mapping dataset")
        for i, data in enumerate(tqdm(self.val_table.data)):
            self.val_table_path_map[data[3]] = data[0]
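
    # The resulting map is keyed by the table's "name" column (column 3) and
    # stores the row "id" (column 0), e.g. {'000000000139.jpg': 0, ...}, so a
    # prediction for an image file can be joined back to its ground-truth row.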

    def create_dataset_table(self, dataset: LoadImagesAndLabels, class_to_id: Dict[int, str], name: str = 'dataset'):
        """
        Create and return W&B artifact containing W&B Table of the dataset.

        arguments:
        dataset -- instance of LoadImagesAndLabels class used to iterate over the data to build Table
        class_to_id -- hash map that maps class ids to labels
        name -- name of the artifact

        returns:
        dataset artifact to be logged or used
        """
        # TODO: Explore multiprocessing to split this loop in parallel; it is essential for speeding up the logging
        artifact = wandb.Artifact(name=name, type="dataset")
        img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None
        img_files = tqdm(dataset.img_files) if not img_files else img_files
        for img_file in img_files:
            if Path(img_file).is_dir():
                artifact.add_dir(img_file, name='data/images')
                labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
                artifact.add_dir(labels_path, name='data/labels')
            else:
                artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
                label_file = Path(img2label_paths([img_file])[0])
                artifact.add_file(str(label_file),
                                  name='data/labels/' + label_file.name) if label_file.exists() else None
        table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
        class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
        for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
            box_data, img_classes = [], {}
            for cls, *xywh in labels[:, 1:].tolist():
                cls = int(cls)
                box_data.append({"position": {"middle": [xywh[0], xywh[1]], "width": xywh[2], "height": xywh[3]},
                                 "class_id": cls,
                                 "box_caption": "%s" % (class_to_id[cls])})
                img_classes[cls] = class_to_id[cls]
            boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}}  # inference-space
            table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), list(img_classes.values()),
                           Path(paths).name)
        artifact.add(table, name)
        return artifact

    def log_training_progress(self, predn, path, names):
        """
        Build evaluation Table. Uses reference from validation dataset table.

        arguments:
        predn (list): list of predictions in the native space in the format - [xmin, ymin, xmax, ymax, confidence, class]
        path (str): local path of the current evaluation image
        names (dict(int, str)): hash map that maps class ids to labels
        """
        class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
        box_data = []
        avg_conf_per_class = [0] * len(self.data_dict['names'])
        pred_class_count = {}
        for *xyxy, conf, cls in predn.tolist():
            if conf >= 0.25:
                cls = int(cls)
                box_data.append(
                    {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
                     "class_id": cls,
                     "box_caption": f"{names[cls]} {conf:.3f}",
                     "scores": {"class_score": conf},
                     "domain": "pixel"})
                avg_conf_per_class[cls] += conf

                if cls in pred_class_count:
                    pred_class_count[cls] += 1
                else:
                    pred_class_count[cls] = 1

        for pred_class in pred_class_count.keys():
            avg_conf_per_class[pred_class] = avg_conf_per_class[pred_class] / pred_class_count[pred_class]

        boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
        id = self.val_table_path_map[Path(path).name]
        self.result_table.add_data(self.current_epoch,
                                   id,
                                   self.val_table.data[id][1],
                                   wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
                                   *avg_conf_per_class)
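
    # The box dictionaries built above follow the W&B bounding-box overlay schema;
    # a single entry looks like (values hypothetical):
    #
    #   {"position": {"minX": 34.0, "minY": 50.0, "maxX": 220.0, "maxY": 310.0},
    #    "class_id": 0, "box_caption": "person 0.873",
    #    "scores": {"class_score": 0.873}, "domain": "pixel"}
    #
    # "domain": "pixel" tells W&B the coordinates are absolute pixels rather
    # than fractions of the image size.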

    def val_one_image(self, pred, predn, path, names, im):
        """
        Log validation data for one image. Updates the result Table if the validation dataset is uploaded and logs
        the bbox media panel.

        arguments:
        pred (list): list of scaled predictions in the format - [xmin, ymin, xmax, ymax, confidence, class]
        predn (list): list of predictions in the native space - [xmin, ymin, xmax, ymax, confidence, class]
        path (str): local path of the current evaluation image
        names (dict(int, str)): hash map that maps class ids to labels
        im: the image being evaluated, used for the media panel
        """
        if self.val_table and self.result_table:  # Log Table if Val dataset is uploaded as artifact
            self.log_training_progress(predn, path, names)

        if len(self.bbox_media_panel_images) < self.max_imgs_to_log and self.current_epoch > 0:
            if self.current_epoch % self.bbox_interval == 0:
                box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
                             "class_id": int(cls),
                             "box_caption": f"{names[cls]} {conf:.3f}",
                             "scores": {"class_score": conf},
                             "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
                boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
                self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))

    def log(self, log_dict):
        """
        Save the metrics to the logging dictionary

        arguments:
        log_dict (Dict) -- metrics/media to be logged in current step
        """
        if self.wandb_run:
            for key, value in log_dict.items():
                self.log_dict[key] = value

    def end_epoch(self, best_result=False):
        """
        commit the log_dict, model artifacts and Tables to W&B and flush the log_dict.

        arguments:
        best_result (boolean): Boolean representing if the result of this evaluation is best or not
        """
        if self.wandb_run:
            with all_logging_disabled():
                if self.bbox_media_panel_images:
                    self.log_dict["BoundingBoxDebugger"] = self.bbox_media_panel_images
                try:
                    wandb.log(self.log_dict)
                except BaseException as e:
                    LOGGER.info(
                        f"An error occurred in wandb logger. The training will proceed without interruption. More info\n{e}")
                    self.wandb_run.finish()
                    self.wandb_run = None

                self.log_dict = {}
                self.bbox_media_panel_images = []
            if self.result_artifact:
                self.result_artifact.add(self.result_table, 'result')
                wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
                                                                  ('best' if best_result else '')])

                wandb.log({"evaluation": self.result_table})
                columns = ["epoch", "id", "ground truth", "prediction"]
                columns.extend(self.data_dict['names'])
                self.result_table = wandb.Table(columns)
                self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")

    def finish_run(self):
        """
        Log metrics if any and finish the current W&B run
        """
        if self.wandb_run:
            if self.log_dict:
                with all_logging_disabled():
                    wandb.log(self.log_dict)
            wandb.run.finish()


@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """ source - https://gist.github.com/simon-weber/7853144
    A context manager that will prevent any logging messages triggered during the body from being processed.
    :param highest_level: the maximum logging level in use.
        This would only need to be changed if a custom level greater than CRITICAL is defined.
    """
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(previous_level)
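
# Minimal usage sketch (not part of the upstream module): exercises
# all_logging_disabled(), the only piece here that runs without a live W&B
# session or YOLOv5 training options.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    logging.info('visible: emitted before suppression')
    with all_logging_disabled():
        logging.critical('suppressed: never reaches a handler')
    logging.info('visible: logging is restored on exit')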