Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/utils/check_dummies.py
1440 views
1
# coding=utf-8
2
# Copyright 2023 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import argparse
17
import os
18
import re
19
20
21
# Every path below is relative to the repository root, so run this script from there:
#     python utils/check_dummies.py
PATH_TO_DIFFUSERS = "src/diffusers"

# Captures the backend name of each `is_xxx_available()` call on a line (e.g. "torch").
_re_backend = re.compile(r"is\_([a-z_]*)_available\(\)")
# Captures the imported names of a whitespace-preceded, one-line `from xxx import a, b`
# statement; an opening parenthesis right after `import` (multi-line import) does not match.
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
29
30
31
# Templates for the placeholder objects written to the dummy modules.
# `{0}` is the object name; `{1}` is the backend list literal built by `create_dummy_files`
# (for example '["torch"]'). The content must be valid, properly indented Python because it is
# written verbatim into `dummy_*_objects.py`.

# Used for all upper-case names.
DUMMY_CONSTANT = """
{0} = None
"""

# Used for names that are neither all upper-case nor all lower-case.
DUMMY_CLASS = """
class {0}(metaclass=DummyObject):
    _backends = {1}

    def __init__(self, *args, **kwargs):
        requires_backends(self, {1})

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, {1})

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, {1})
"""


# Used for all lower-case names.
DUMMY_FUNCTION = """
def {0}(*args, **kwargs):
    requires_backends({0}, {1})
"""
56
57
58
def find_backend(line):
    """Return the backend key named by the ``is_xxx_available()`` calls found in ``line``.

    When a line holds several such calls, their names are joined with ``"_and_"`` in order of
    appearance (``is_a_available() and is_b_available()`` gives ``"a_and_b"``). A line without
    any such call yields ``None``.
    """
    names = _re_backend.findall(line)
    return "_and_".join(names) if names else None
65
66
67
def read_init():
    """Read the main ``__init__.py`` and collect the objects that depend on optional backends.

    The file is scanned line by line. A line containing one or more ``is_xxx_available()`` calls
    names a backend key (see ``find_backend``). Scanning then jumps to the next line that starts
    with ``else:`` and collects the names imported on the lines after it, up to the first blank
    line (a line holding only the newline). Names come from either:

    * a one-line ``from xxx import a, b`` statement, or
    * a line indented by 8 spaces, taken as a single name (presumably one entry of a
      parenthesized multi-line import).

    Returns:
        dict: Maps each backend key (e.g. ``"torch"``) to the list of object names collected for it.
        Backends for which nothing was collected are omitted. If the same backend key guards several
        blocks, their objects are merged in file order.

    Raises:
        ValueError: If a line referencing a backend is not followed by a line starting with ``else:``.
    """
    with open(os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    backend_specific_objects = {}
    line_index = 0
    while line_index < len(lines):
        backend = find_backend(lines[line_index])
        if backend is None:
            line_index += 1
            continue

        # Jump to the `else:` branch that holds the real imports for this backend.
        while line_index < len(lines) and not lines[line_index].startswith("else:"):
            line_index += 1
        if line_index == len(lines):
            raise ValueError(
                f"Found a check for the `{backend}` backend in the main __init__ that is not followed by an "
                "`else:` line."
            )
        line_index += 1

        objects = []
        # Collect imported names until the first blank line.
        while line_index < len(lines) and len(lines[line_index]) > 1:
            line = lines[line_index]
            single_line_import_search = _re_single_line_import.search(line)
            if single_line_import_search is not None:
                objects.extend(single_line_import_search.group(1).split(", "))
            elif line.startswith(" " * 8):
                # Drop the 8-space indent and the last two characters (expected to be "," and the newline).
                objects.append(line[8:-2])
            line_index += 1

        if objects:
            # Merge rather than overwrite so a backend guarded in several places keeps all of its objects.
            backend_specific_objects.setdefault(backend, []).extend(objects)

    return backend_specific_objects
100
101
102
def create_dummy_object(name, backend_name):
    """Return the source of the dummy placeholder for the object called ``name``.

    The template is chosen from the casing of ``name``:

    * all upper-case names use ``DUMMY_CONSTANT``;
    * all lower-case names use ``DUMMY_FUNCTION``;
    * anything else uses ``DUMMY_CLASS``.

    ``backend_name`` is substituted into the function and class templates only.
    """
    if name.isupper():
        return DUMMY_CONSTANT.format(name)
    template = DUMMY_FUNCTION if name.islower() else DUMMY_CLASS
    return template.format(name, backend_name)
110
111
112
def create_dummy_files(backend_specific_objects=None):
    """Build the full text of every dummy module.

    Args:
        backend_specific_objects (dict, optional): Maps a backend key (possibly several backends
            joined with ``"_and_"``) to the object names it guards. When ``None``, the mapping is
            read from the main init with ``read_init``.

    Returns:
        dict: Maps each backend key to the content of its dummy module.
    """
    if backend_specific_objects is None:
        backend_specific_objects = read_init()

    header = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
    header += "from ..utils import DummyObject, requires_backends\n\n"

    dummy_files = {}
    for backend, objects in backend_specific_objects.items():
        # e.g. "torch_and_flax" -> '["torch", "flax"]'
        quoted_backends = [f'"{b}"' for b in backend.split("_and_")]
        backend_name = "[" + ", ".join(quoted_backends) + "]"
        placeholders = [create_dummy_object(o, backend_name) for o in objects]
        dummy_files[backend] = header + "\n".join(placeholders)
    return dummy_files
127
128
129
def check_dummies(overwrite=False):
    """Verify that the dummy modules in ``utils`` match what the main ``__init__`` requires.

    The expected content of each ``dummy_<name>_objects.py`` file comes from ``create_dummy_files``.
    It is compared with the file on disk, where a missing file counts as empty. All files are read
    before any of them is rewritten.

    Args:
        overwrite (bool): If True, every out-of-date dummy file is rewritten with the generated
            content and a message is printed. If False, the first out-of-date file raises.

    Raises:
        ValueError: If a dummy file is out of date and ``overwrite`` is False.
    """
    expected_files = create_dummy_files()
    # Backends whose dummy module on disk uses a different (shorter) name than the backend key.
    short_names = {"torch": "pt"}

    utils_dir = os.path.join(PATH_TO_DIFFUSERS, "utils")
    dummy_file_paths = {}
    for backend in expected_files:
        module_file = f"dummy_{short_names.get(backend, backend)}_objects.py"
        dummy_file_paths[backend] = os.path.join(utils_dir, module_file)

    # Snapshot the current content of every dummy module first.
    actual_dummies = {}
    for backend, file_path in dummy_file_paths.items():
        if not os.path.isfile(file_path):
            actual_dummies[backend] = ""
            continue
        with open(file_path, "r", encoding="utf-8", newline="\n") as f:
            actual_dummies[backend] = f.read()

    for backend, expected in expected_files.items():
        if expected == actual_dummies[backend]:
            continue
        short = short_names.get(backend, backend)
        if not overwrite:
            raise ValueError(
                "The main __init__ has objects that are not present in "
                f"diffusers.utils.dummy_{short}_objects.py. Run `make fix-copies` "
                "to fix this."
            )
        print(f"Updating diffusers.utils.dummy_{short}_objects.py as the main " "__init__ has new objects.")
        with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
            f.write(expected)
165
166
167
if __name__ == "__main__":
    # With --fix_and_overwrite, out-of-date dummy files are rewritten instead of raising an error.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    check_dummies(overwrite=cli.parse_args().fix_and_overwrite)
173
174