Path: blob/master/modules/memmon.py
3055 views
import threading
import time
from collections import defaultdict

import torch


class MemUsageMonitor(threading.Thread):
    """Daemon thread that tracks CUDA memory usage during a monitoring pass.

    Call monitor() to start polling, then stop() (or read()) to collect the
    gathered statistics.  If CUDA memory queries are unavailable (AMD builds,
    CPU-only torch, ...), the monitor disables itself instead of raising.
    """

    # Class-level defaults kept for backward compatibility; __init__ always
    # shadows them with instance attributes.
    run_flag = None
    device = None
    disabled = False
    opts = None
    data = None

    def __init__(self, name, device, opts):
        """
        name: thread name (forwarded to threading.Thread.name).
        device: torch.device whose CUDA memory is monitored.
        opts: settings object; only ``memmon_poll_rate`` (polls/second) is read.
        """
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts

        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

        try:
            # Probe both CUDA memory APIs once up front; if either fails,
            # disable the monitor now instead of crashing later in run().
            self.cuda_mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:  # AMD or whatever
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def cuda_mem_get_info(self):
        """Return (free, total) memory in bytes for the monitored device."""
        # mem_get_info wants a concrete index; fall back to the current device
        # when self.device carries none (e.g. plain torch.device("cuda")).
        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
        return torch.cuda.mem_get_info(index)

    def run(self):
        """Thread body: poll free memory while run_flag is set."""
        if self.disabled:
            return

        while True:
            self.run_flag.wait()

            # Reset peak counters on the monitored device (not whichever
            # device happens to be current) so *_peak covers only this pass.
            torch.cuda.reset_peak_memory_stats(self.device)
            self.data.clear()

            if self.opts.memmon_poll_rate <= 0:
                self.run_flag.clear()
                continue

            self.data["min_free"] = self.cuda_mem_get_info()[0]

            while self.run_flag.is_set():
                free = self.cuda_mem_get_info()[0]
                self.data["min_free"] = min(self.data["min_free"], free)

                # Re-read the poll rate each iteration so it can be tuned
                # live; guard against a mid-run drop to zero, which would
                # otherwise raise ZeroDivisionError and silently kill this
                # daemon thread.
                poll_rate = self.opts.memmon_poll_rate
                if poll_rate <= 0:
                    break
                time.sleep(1 / poll_rate)

    def dump_debug(self):
        """Print collected data and raw torch memory stats (MiB, rounded up)."""
        print(self, 'recorded data:')
        for k, v in self.read().items():
            # -(v // -d) is ceiling division: bytes -> MiB rounded up.
            print(k, -(v // -(1024 ** 2)))

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        """Signal the polling loop to start a monitoring pass."""
        self.run_flag.set()

    def read(self):
        """Return the current stats dict (all values in bytes); empty when disabled."""
        if not self.disabled:
            free, total = self.cuda_mem_get_info()
            self.data["free"] = free
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
            self.data["active"] = torch_stats["active.all.current"]
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            # Peak memory consumed system-wide (other processes included)
            # during the monitoring pass.
            self.data["system_peak"] = total - self.data["min_free"]

        return self.data

    def stop(self):
        """Stop the current monitoring pass and return the final stats."""
        self.run_flag.clear()
        return self.read()