Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/utils/check_inits.py
1440 views
1
# coding=utf-8
2
# Copyright 2023 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import collections
17
import importlib.util
18
import os
19
import re
20
from pathlib import Path
21
22
23
# Root of the package whose inits are checked, relative to the current working directory (it is passed as-is to
# `os.walk`). NOTE(review): this copy lives in a diffusers checkout yet scans `src/transformers` — confirm intended.
PATH_TO_TRANSFORMERS = "src/transformers"


# All patterns below are raw strings: several contain regex escapes such as `\s` or `\[` that are not valid Python
# string escapes and would otherwise trigger an "invalid escape sequence" warning.

# Matches is_xxx_available and captures "xxx". NOTE: the trailing `()` is an empty capture group, not literal
# parentheses, so `findall` returns `(name, "")` tuples; `find_backend` reads the name with `b[0]`.
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Catches a one-line _import_structure = {xxx}
_re_one_line_import_struct = re.compile(r"^_import_structure\s+=\s+\{([^\}]+)\}")
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if not is_foo_available
_re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)")
# Catches a line _import_structure["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_structure["bla"].extend(["foo", "bar"]) or _import_structure["bla"] = ["foo", "bar"]
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")
# Catches a line with an object between quotes and a comma: "MyModel",
_re_quote_object = re.compile(r'^\s+"([^"]+)",')
# Catches a line with objects between brackets only: ["foo", "bar"],
_re_between_brackets = re.compile(r"^\s+\[([^\]]+)\]")
# Catches a line with from foo import bar, bla, boo
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
# Catches a line with try:
_re_try = re.compile(r"^\s*try:")
# Catches a line with else:
_re_else = re.compile(r"^\s*else:")
48
49
50
def find_backend(line):
    """
    Return the backend(s) guarded by an `if not is_xxx_available()` line of an init, or None.

    Lines that do not start such a test return None. When several `is_xxx_available()` calls appear on the line,
    their names are sorted and joined with "_and_" (e.g. "flax_and_torch").
    """
    if not _re_test_backend.search(line):
        return None
    # `_re_backend.findall` yields (name, "") tuples; keep only the captured name.
    return "_and_".join(sorted(match[0] for match in _re_backend.findall(line)))
57
58
59
def parse_init(init_file):
    """
    Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
    defined.

    Args:
        init_file (`str` or `os.PathLike`): Path to the `__init__.py` file to parse.

    Returns:
        `None` if no line starts with `_import_structure = {` (a "traditional" init). Otherwise a tuple
        `(import_dict_objects, type_hint_objects)` of dicts mapping a backend name (`"none"` for objects that need no
        backend, otherwise the value returned by `find_backend`) to the list of object names found in the
        `_import_structure` half and in the `TYPE_CHECKING` half respectively.

    Note: the parser relies on the exact layout of the init (fixed indentation widths, `try`/`except`/`else` blocks
    around backend checks). An init that has `_import_structure` but no `if TYPE_CHECKING` line makes the index run
    past the end of the file and raises `IndexError`.
    """
    with open(init_file, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    line_index = 0
    while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"):
        line_index += 1

    # If this is a traditional init, just return.
    if line_index >= len(lines):
        return None

    # First grab the objects without a specific backend in _import_structure
    objects = []
    while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
        line = lines[line_index]
        # If we have everything on a single line, let's deal with it.
        one_line_search = _re_one_line_import_struct.search(line)
        if one_line_search is not None:
            content = one_line_search.groups()[0]
            # Raw string: `\[` and `\]` are regex escapes, not valid Python string escapes.
            imports = re.findall(r"\[([^\]]+)\]", content)
            for imp in imports:
                # Strip the surrounding quotes of each '"name"' entry.
                objects.extend([obj[1:-1] for obj in imp.split(", ")])
            line_index += 1
            continue
        single_line_import_search = _re_import_struct_key_value.search(line)
        if single_line_import_search is not None:
            imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0]
            objects.extend(imports)
        elif line.startswith(" " * 8 + '"'):
            # `        "name",\n`: drop the 8-space indent + opening quote and the closing `",\n`.
            objects.append(line[9:-3])
        line_index += 1

    import_dict_objects = {"none": objects}
    # Let's continue with backend-specific objects in _import_structure
    while not lines[line_index].startswith("if TYPE_CHECKING"):
        # If the line is an if not is_backend_available, we grab all objects associated.
        backend = find_backend(lines[line_index])
        # Check if the backend declaration is inside a try block:
        if _re_try.search(lines[line_index - 1]) is None:
            backend = None

        if backend is not None:
            line_index += 1

            # Scroll until we hit the else block of try-except-else
            while _re_else.search(lines[line_index]) is None:
                line_index += 1

            line_index += 1

            objects = []
            # Until we unindent, add backend objects to the list (blank lines, of length <= 1, do not stop the scan)
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
                line = lines[line_index]
                if _re_import_struct_add_one.search(line) is not None:
                    objects.append(_re_import_struct_add_one.search(line).groups()[0])
                elif _re_import_struct_add_many.search(line) is not None:
                    imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                elif _re_between_brackets.search(line) is not None:
                    imports = _re_between_brackets.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                elif _re_quote_object.search(line) is not None:
                    objects.append(_re_quote_object.search(line).groups()[0])
                elif line.startswith(" " * 8 + '"'):
                    objects.append(line[9:-3])
                elif line.startswith(" " * 12 + '"'):
                    objects.append(line[13:-3])
                line_index += 1

            import_dict_objects[backend] = objects
        else:
            line_index += 1

    # At this stage we are in the TYPE_CHECKING part, first grab the objects without a specific backend
    objects = []
    while (
        line_index < len(lines)
        and find_backend(lines[line_index]) is None
        and not lines[line_index].startswith("else")
    ):
        line = lines[line_index]
        single_line_import_search = _re_import.search(line)
        if single_line_import_search is not None:
            objects.extend(single_line_import_search.groups()[0].split(", "))
        elif line.startswith(" " * 8):
            # `        Name,\n`: drop the 8-space indent and the trailing `,\n`.
            objects.append(line[8:-2])
        line_index += 1

    type_hint_objects = {"none": objects}
    # Let's continue with backend-specific objects
    while line_index < len(lines):
        # If the line is an if is_backend_available, we grab all objects associated.
        backend = find_backend(lines[line_index])
        # Check if the backend declaration is inside a try block:
        if _re_try.search(lines[line_index - 1]) is None:
            backend = None

        if backend is not None:
            line_index += 1

            # Scroll until we hit the else block of try-except-else
            while _re_else.search(lines[line_index]) is None:
                line_index += 1

            line_index += 1

            objects = []
            # Until we unindent, add backend objects to the list
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
                line = lines[line_index]
                single_line_import_search = _re_import.search(line)
                if single_line_import_search is not None:
                    objects.extend(single_line_import_search.groups()[0].split(", "))
                elif line.startswith(" " * 12):
                    objects.append(line[12:-2])
                line_index += 1

            type_hint_objects[backend] = objects
        else:
            line_index += 1

    return import_dict_objects, type_hint_objects
188
189
190
def analyze_results(import_dict_objects, type_hint_objects):
    """
    Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.

    Args:
        import_dict_objects (`dict`): Backend name -> list of objects found in the `_import_structure` half.
        type_hint_objects (`dict`): Backend name -> list of objects found in the `TYPE_CHECKING` half.

    Returns:
        `list[str]`: Human-readable error lines, empty when both halves agree. If the two dicts do not list the
        same backends in the same order (keys are compared as lists, so order matters), a single error is returned
        and the objects themselves are not compared.
    """

    def find_duplicates(seq):
        return [k for k, v in collections.Counter(seq).items() if v > 1]

    if list(import_dict_objects.keys()) != list(type_hint_objects.keys()):
        return ["Both sides of the init do not have the same backends!"]

    errors = []
    for key, import_objects in import_dict_objects.items():
        type_objects = type_hint_objects[key]
        duplicate_imports = find_duplicates(import_objects)
        if duplicate_imports:
            errors.append(f"Duplicate _import_structure definitions for: {duplicate_imports}")
        duplicate_type_hints = find_duplicates(type_objects)
        if duplicate_type_hints:
            errors.append(f"Duplicate TYPE_CHECKING objects for: {duplicate_type_hints}")

        # Sets make each membership test below O(1) instead of scanning a list; the lists themselves are still
        # iterated, so the order (and any repetition) of the reported lines is unchanged.
        import_set = set(import_objects)
        type_set = set(type_objects)
        if import_set != type_set:
            name = "base imports" if key == "none" else f"{key} backend"
            errors.append(f"Differences for {name}:")
            for a in type_objects:
                if a not in import_set:
                    errors.append(f" {a} in TYPE_HINT but not in _import_structure.")
            for a in import_objects:
                if a not in type_set:
                    errors.append(f" {a} in _import_structure but not in TYPE_HINT.")
    return errors
220
221
222
def check_all_inits():
    """
    Walk every `__init__.py` under PATH_TO_TRANSFORMERS and raise a `ValueError` describing each init whose
    `_import_structure` half and `TYPE_CHECKING` half do not define the same objects.
    """
    problems = []
    for directory, _, filenames in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" not in filenames:
            continue
        init_path = os.path.join(directory, "__init__.py")
        parsed = parse_init(init_path)
        # Traditional inits (no `_import_structure`) come back as None and are not checked.
        if parsed is None:
            continue
        init_errors = analyze_results(*parsed)
        if not init_errors:
            continue
        header = f"Problem in {init_path}, both halves do not define the same objects."
        # The header is prefixed to the first error line, then all lines are joined.
        problems.append("\n".join([f"{header}\n{init_errors[0]}"] + init_errors[1:]))
    if problems:
        raise ValueError("\n\n".join(problems))
239
240
241
def get_transformers_submodules(path_to_transformers=None):
    """
    Returns the list of Transformers submodules.

    The list contains every non-private folder (name not starting with "_") that holds at least one `.py` file,
    as a dotted path such as `models.bert`. It also contains the top-level `.py` modules other than `__init__.py`,
    such as `utils`.

    Args:
        path_to_transformers (`str`, *optional*):
            Root of the package to scan. Defaults to `PATH_TO_TRANSFORMERS`.

    Returns:
        `list[str]`: The submodule names, in `os.walk` order.
    """
    root = PATH_TO_TRANSFORMERS if path_to_transformers is None else path_to_transformers
    submodules = []
    for path, directories, files in os.walk(root):
        # Ignore private modules. Prune `directories` in place (slice assignment) so `os.walk` does not descend
        # into them. Calling `remove` while iterating the same list would skip the entry after each removed one.
        directories[:] = [folder for folder in directories if not folder.startswith("_")]
        for folder in directories:
            folder_path = Path(path) / folder
            # Ignore leftovers from branches (empty folders apart from pycache)
            if not any(folder_path.glob("*.py")):
                continue
            short_path = str(folder_path.relative_to(root))
            submodules.append(short_path.replace(os.path.sep, "."))
        for fname in files:
            # Only Python modules count; `__init__.py` is the package itself, not a submodule.
            if fname == "__init__.py" or not fname.endswith(".py"):
                continue
            short_path = str((Path(path) / fname).relative_to(root))
            submodule = short_path[: -len(".py")].replace(os.path.sep, ".")
            # Only keep top-level modules (nested ones contain a "." once separators are replaced).
            if len(submodule.split(".")) == 1:
                submodules.append(submodule)
    return submodules
266
267
268
# Submodules that `check_submodules` does not require to appear as keys of the main init's `_import_structure`.
IGNORE_SUBMODULES = ["convert_pytorch_checkpoint_to_tf2", "modeling_flax_pytorch_utils"]
272
273
274
def check_submodules():
    """
    Check that every submodule found by `get_transformers_submodules` is registered in the main init.

    Submodules listed in `IGNORE_SUBMODULES` are exempt. The main init is the one at PATH_TO_TRANSFORMERS, and a
    submodule counts as registered when it is a key of that init's `_import_structure`.

    Raises:
        `ValueError`: Listing every submodule missing from `_import_structure`.
    """
    # Local import: only needed here to register the freshly created package module.
    import sys

    # This is to make sure the transformers module imported is the one in the repo.
    spec = importlib.util.spec_from_file_location(
        "transformers",
        os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
        submodule_search_locations=[PATH_TO_TRANSFORMERS],
    )
    # `Loader.load_module()` is deprecated: create the module from its spec and execute it instead. As
    # `load_module()` did, register it in `sys.modules` before executing so the package's relative imports resolve.
    transformers = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = transformers
    spec.loader.exec_module(transformers)

    registered = transformers._import_structure.keys()
    module_not_registered = [
        module
        for module in get_transformers_submodules()
        if module not in IGNORE_SUBMODULES and module not in registered
    ]
    if len(module_not_registered) > 0:
        list_of_modules = "\n".join(f"- {module}" for module in module_not_registered)
        raise ValueError(
            "The following submodules are not properly registered in the main init of Transformers:\n"
            f"{list_of_modules}\n"
            "Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value."
        )
295
296
297
if __name__ == "__main__":
    # Run both consistency checks in order; each raises `ValueError` on failure, which stops the script.
    for check in (check_all_inits, check_submodules):
        check()
300
301