Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
automatic1111
GitHub Repository: automatic1111/stable-diffusion-webui
Path: blob/master/modules/patches.py
3055 views
1
from collections import defaultdict
2
3
4
def patch(key, obj, field, replacement):
    """Replaces a function in a module or a class.

    Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
    If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.

    Arguments:
        key: identifying information for who is doing the replacement. You can use __name__.
        obj: the module or the class
        field: name of the function as a string
        replacement: the new function

    Returns:
        the original function

    Raises:
        RuntimeError: if this key has already patched obj.field and has not undone it.
        AttributeError: if obj has no attribute named field (raised by getattr).
        Any exception raised by setattr propagates unchanged; in that case nothing is recorded.
    """
    patch_key = (obj, field)
    if patch_key in originals[key]:
        raise RuntimeError(f"patch for {field} is already applied")

    original_func = getattr(obj, field)

    # Replace first, record second: if setattr raises (e.g. obj is a built-in
    # type whose attributes cannot be set), nothing is recorded, so the registry
    # never claims a patch that was not actually applied.
    setattr(obj, field, replacement)
    originals[key][patch_key] = original_func

    return original_func
30
31
32
def undo(key, obj, field):
    """Undoes the replacement made by patch().

    If the function is not replaced, raises an exception.

    Arguments:
        key: identifying information for who is doing the replacement. You can use __name__.
        obj: the module or the class
        field: name of the function as a string

    Returns:
        Always None

    Raises:
        RuntimeError: if key has no active patch for obj.field.
        Any exception raised by setattr propagates unchanged; in that case the
        patch stays recorded so the undo can be retried.
    """
    patch_key = (obj, field)

    # .get() instead of indexing: originals is a defaultdict, and indexing it
    # with an unknown key would insert an empty entry even though we then raise.
    patches = originals.get(key, {})
    if patch_key not in patches:
        raise RuntimeError(f"there is no patch for {field} to undo")

    # Restore first, forget second, so a failed setattr does not lose the
    # only reference to the original function.
    setattr(obj, field, patches[patch_key])
    del patches[patch_key]

    return None
55
56
57
def original(key, obj, field):
    """Returns the original function for the patch created by the patch() function.

    Arguments:
        key: the same identifying information that was passed to patch().
        obj: the module or the class
        field: name of the function as a string

    Returns:
        the function that patch() replaced, or None if key has no active patch for obj.field.
    """
    # Use .get() on the outer registry: indexing the defaultdict would insert an
    # empty entry for every unknown key merely by querying it.
    return originals.get(key, {}).get((obj, field))
62
63
64
# Registry of replaced callables, filled by patch() and drained by undo():
#   originals[key][(obj, field)] -> the function that was in place before patching.
# The dict factory gives each new key its own empty inner mapping.
originals: defaultdict = defaultdict(dict)
65
66