Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
automatic1111
GitHub Repository: automatic1111/stable-diffusion-webui
Path: blob/master/modules/memmon.py
3055 views
1
import threading
2
import time
3
from collections import defaultdict
4
5
import torch
6
7
8
class MemUsageMonitor(threading.Thread):
    """Background thread that samples CUDA memory usage for one device.

    Call ``monitor()`` to start sampling, ``stop()`` to stop and get the
    collected statistics, and ``read()`` for a snapshot at any time.  If the
    CUDA memory-introspection APIs are unavailable (AMD / CPU-only builds),
    the monitor disables itself instead of raising.
    """

    run_flag = None   # threading.Event; set while the sampling loop is active
    device = None     # torch.device being monitored
    disabled = False  # True when CUDA memory introspection is unavailable
    opts = None       # settings object; only memmon_poll_rate is read here
    data = None       # defaultdict(int) of collected statistics (bytes)

    def __init__(self, name, device, opts):
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts

        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

        # Probe the CUDA memory APIs once up front; if either call fails
        # (e.g. AMD builds), disable the monitor rather than crash later.
        try:
            self.cuda_mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:  # AMD or whatever
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def cuda_mem_get_info(self):
        """Return ``(free, total)`` memory in bytes for the monitored device."""
        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
        return torch.cuda.mem_get_info(index)

    def run(self):
        if self.disabled:
            return

        while True:
            self.run_flag.wait()

            # Fix: reset peaks on the monitored device, not whichever device
            # happens to be current on this thread — read() queries
            # memory_stats(self.device), so both must target the same device.
            torch.cuda.reset_peak_memory_stats(self.device)
            self.data.clear()

            if self.opts.memmon_poll_rate <= 0:
                self.run_flag.clear()
                continue

            self.data["min_free"] = self.cuda_mem_get_info()[0]

            while self.run_flag.is_set():
                free = self.cuda_mem_get_info()[0]
                self.data["min_free"] = min(self.data["min_free"], free)

                # Re-read the poll rate each pass so option changes take
                # effect, but guard against it being set <= 0 mid-run, which
                # previously killed the thread with ZeroDivisionError.
                poll_rate = self.opts.memmon_poll_rate
                if poll_rate <= 0:
                    break
                time.sleep(1 / poll_rate)

    def dump_debug(self):
        """Print the collected and raw torch memory statistics (values in MiB)."""
        print(self, 'recorded data:')
        for k, v in self.read().items():
            print(k, -(v // -(1024 ** 2)))  # ceiling-divide bytes to MiB

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        """Start (or resume) the background sampling loop."""
        self.run_flag.set()

    def read(self):
        """Refresh ``self.data`` with current device statistics and return it.

        When the monitor is disabled, returns the (empty) data dict untouched.
        All values are in bytes.
        """
        if not self.disabled:
            free, total = self.cuda_mem_get_info()
            self.data["free"] = free
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
            # Fix: "active.all.current" counts allocation *blocks*; the byte
            # value lives under "active_bytes.all.current", consistent with
            # active_peak / reserved / reserved_peak below (dump_debug treats
            # all of these as byte counts when converting to MiB).
            self.data["active"] = torch_stats["active_bytes.all.current"]
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            # Peak usage by everything on the device (including other
            # processes), derived from the lowest free value seen while polling.
            self.data["system_peak"] = total - self.data["min_free"]

        return self.data

    def stop(self):
        """Stop the sampling loop and return the final statistics."""
        self.run_flag.clear()
        return self.read()
93
94