Path: blob/master/FaceMaskOverlay/lib/core/function.py
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Created by Tianheng Cheng([email protected])
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import logging

import torch
import numpy as np

from .evaluation import decode_preds, compute_nme

logger = logging.getLogger(__name__)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def train(config, train_loader, model, criterion, optimizer,
          epoch, writer_dict):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    model.train()
    nme_count = 0
    nme_batch_sum = 0

    end = time.time()

    for i, (inp, target, meta) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # compute the output
        output = model(inp)
        target = target.cuda(non_blocking=True)

        loss = criterion(output, target)

        # NME: decode the heatmaps back to image coordinates and accumulate
        # the normalized mean error over the epoch
        score_map = output.data.cpu()
        preds = decode_preds(score_map, meta['center'], meta['scale'], [64, 64])

        nme_batch = compute_nme(preds, meta)
        nme_batch_sum = nme_batch_sum + np.sum(nme_batch)
        nme_count = nme_count + preds.size(0)

        # optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item(), inp.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        if i % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{0}][{1}/{2}]\t' \
                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                  'Speed {speed:.1f} samples/s\t' \
                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                  'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      speed=inp.size(0) / batch_time.val,
                      data_time=data_time, loss=losses)
            logger.info(msg)

            if writer_dict:
                writer = writer_dict['writer']
                global_steps = writer_dict['train_global_steps']
                writer.add_scalar('train_loss', losses.val, global_steps)
                writer_dict['train_global_steps'] = global_steps + 1

        end = time.time()

    nme = nme_batch_sum / nme_count
    msg = 'Train Epoch {} time:{:.4f} loss:{:.4f} nme:{:.4f}' \
        .format(epoch, batch_time.avg, losses.avg, nme)
    logger.info(msg)


def validate(config, val_loader, model, criterion, epoch, writer_dict):
    batch_time = AverageMeter()
    data_time = AverageMeter()

    losses = AverageMeter()

    num_classes = config.MODEL.NUM_JOINTS
    predictions = torch.zeros((len(val_loader.dataset), num_classes, 2))

    model.eval()

    nme_count = 0
    nme_batch_sum = 0
    count_failure_008 = 0
    count_failure_010 = 0
    end = time.time()

    with torch.no_grad():
        for i, (inp, target, meta) in enumerate(val_loader):
            data_time.update(time.time() - end)
            output = model(inp)
            target = target.cuda(non_blocking=True)

            score_map = output.data.cpu()
            # loss
            loss = criterion(output, target)

            preds = decode_preds(score_map, meta['center'], meta['scale'], [64, 64])
            # NME
            nme_temp = compute_nme(preds, meta)
            # failure rate under different thresholds
            failure_008 = (nme_temp > 0.08).sum()
            failure_010 = (nme_temp > 0.10).sum()
            count_failure_008 += failure_008
            count_failure_010 += failure_010

            nme_batch_sum += np.sum(nme_temp)
            nme_count = nme_count + preds.size(0)
            # scatter per-sample predictions into the dataset-sized buffer
            for n in range(score_map.size(0)):
                predictions[meta['index'][n], :, :] = preds[n, :, :]

            losses.update(loss.item(), inp.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    nme = nme_batch_sum / nme_count
    failure_008_rate = count_failure_008 / nme_count
    failure_010_rate = count_failure_010 / nme_count

    msg = 'Test Epoch {} time:{:.4f} loss:{:.4f} nme:{:.4f} [008]:{:.4f} ' \
          '[010]:{:.4f}'.format(epoch, batch_time.avg, losses.avg, nme,
                                failure_008_rate, failure_010_rate)
    logger.info(msg)

    if writer_dict:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', losses.avg, global_steps)
        writer.add_scalar('valid_nme', nme, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1

    return nme, predictions


def inference(config, data_loader, model):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # no loss is computed during inference, so this meter stays at zero and
    # the logged loss below is always 0.0
    losses = AverageMeter()

    num_classes = config.MODEL.NUM_JOINTS
    predictions = torch.zeros((len(data_loader.dataset), num_classes, 2))

    model.eval()

    nme_count = 0
    nme_batch_sum = 0
    count_failure_008 = 0
    count_failure_010 = 0
    end = time.time()

    with torch.no_grad():
        for i, (inp, target, meta) in enumerate(data_loader):
            data_time.update(time.time() - end)
            output = model(inp)
            score_map = output.data.cpu()
            preds = decode_preds(score_map, meta['center'], meta['scale'], [64, 64])

            # NME
            nme_temp = compute_nme(preds, meta)

            # failure rate under different thresholds
            failure_008 = (nme_temp > 0.08).sum()
            failure_010 = (nme_temp > 0.10).sum()
            count_failure_008 += failure_008
            count_failure_010 += failure_010

            nme_batch_sum += np.sum(nme_temp)
            nme_count = nme_count + preds.size(0)
            for n in range(score_map.size(0)):
                predictions[meta['index'][n], :, :] = preds[n, :, :]

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    nme = nme_batch_sum / nme_count
    failure_008_rate = count_failure_008 / nme_count
    failure_010_rate = count_failure_010 / nme_count

    msg = 'Test Results time:{:.4f} loss:{:.4f} nme:{:.4f} [008]:{:.4f} ' \
          '[010]:{:.4f}'.format(batch_time.avg, losses.avg, nme,
                                failure_008_rate, failure_010_rate)
    logger.info(msg)

    return nme, predictions
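
# ------------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original pipeline).
# It demonstrates the AverageMeter bookkeeping that train()/validate() rely
# on, and the writer_dict layout those functions expect. The SummaryWriter
# line in the comment is an assumption: any object exposing
# add_scalar(tag, value, step) will do.
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # AverageMeter keeps the last value and a sample-weighted running mean,
    # exactly as train() does for the per-batch loss
    meter = AverageMeter()
    for batch_size, loss_value in [(32, 0.50), (32, 0.40), (16, 0.30)]:
        meter.update(loss_value, n=batch_size)
    # avg = (0.50*32 + 0.40*32 + 0.30*16) / 80 = 0.42
    print('last={:.2f} avg={:.2f}'.format(meter.val, meter.avg))

    # writer_dict layout read and updated by train() and validate():
    # writer_dict = {
    #     'writer': SummaryWriter(log_dir='log'),  # e.g. torch.utils.tensorboard
    #     'train_global_steps': 0,
    #     'valid_global_steps': 0,
    # }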