Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/utils/custom_init_isort.py
1440 views
1
# coding=utf-8
2
# Copyright 2023 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import argparse
17
import os
18
import re
19
20
21
# Root folder scanned for `__init__.py` files (despite its name, this points at the diffusers sources).
PATH_TO_TRANSFORMERS = "src/diffusers"

# Pattern that looks at the indentation in a line; group 1 is the leading whitespace. Lines that are empty or contain
# only whitespace do not match (at least one non-space character is required).
_re_indent = re.compile(r"^(\s*)\S")
# Pattern that matches `"key":` and puts `key` in group 1 (i.e. `groups()[0]`).
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
# Pattern that matches `_import_structure["key"]` and puts `key` in group 1 (i.e. `groups()[0]`).
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
# Pattern that matches a line made only of `"key",` and puts `key` in group 1 (i.e. `groups()[0]`).
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
# Pattern that matches any `[stuff]` and puts `stuff` in group 1 (i.e. `groups()[0]`).
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
33
34
35
def get_indent(line):
    """Return the leading whitespace of `line`, or `""` when the line is empty or holds only whitespace."""
    match = _re_indent.search(line)
    if match is None:
        return ""
    return match.group(1)
39
40
41
def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
    """
    Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after
    `start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's
    after `end_prompt` as a last block, so joining the result with newlines gives back `code`).

    A new block starts at every non-empty line whose indentation is exactly `indent_level`, except when that line
    directly follows a line indented deeper (by `indent_level` plus spaces): it is then the closing line (`]`, `)`,
    `}`...) of the current block and is kept with it.

    Args:
        code (`str`): The code to split.
        indent_level (`str`, *optional*, defaults to `""`): Indentation of the lines that open a new block.
        start_prompt (`str`, *optional*): If set, everything before the first line starting with this prefix is
            returned as the first block. If no line starts with it, `[code]` is returned.
        end_prompt (`str`, *optional*): If set, splitting stops at the first line (after the start) starting with this
            prefix, and everything from that line on is returned as the last block.

    Returns:
        `List[str]`: The blocks.

    NOTE(review): when `start_prompt` is on the very first line, the first block is `""`, and joining adds one extra
    leading newline compared to `code`.
    """
    lines = code.split("\n")
    index = 0
    if start_prompt is not None:
        while index < len(lines) and not lines[index].startswith(start_prompt):
            index += 1
        if index == len(lines):
            # `start_prompt` never appears: there is nothing to split.
            return [code]
        blocks = ["\n".join(lines[:index])]
    else:
        blocks = []

    # We split into blocks until we get to the `end_prompt` (or the end of the code).
    current_block = [lines[index]]
    index += 1
    while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
        line = lines[index]
        if len(line) > 0 and get_indent(line) == indent_level:
            if not current_block:
                # The previous block was just closed: this line opens the next one.
                current_block = [line]
            elif get_indent(current_block[-1]).startswith(indent_level + " "):
                # Closing line of the current indented construct: keep it with that block and close the block.
                # The next line is processed normally, so it is still checked against `end_prompt`.
                current_block.append(line)
                blocks.append("\n".join(current_block))
                current_block = []
            else:
                blocks.append("\n".join(current_block))
                current_block = [line]
        else:
            current_block.append(line)
        index += 1

    # Adds current block if it's nonempty.
    if len(current_block) > 0:
        blocks.append("\n".join(current_block))

    # Add final block after end_prompt if provided.
    if end_prompt is not None and index < len(lines):
        blocks.append("\n".join(lines[index:]))

    return blocks
86
87
88
def ignore_underscore(key):
    """
    Turn `key` (a callable mapping an object to a `str`) into a sort key that lowercases the name and drops every
    underscore, so that e.g. `"Foo_Bar"` and `"foobar"` compare equal.
    """
    return lambda obj: key(obj).lower().replace("_", "")
95
96
97
def sort_objects(objects, key=None):
    """
    Sort a list of `objects` following the rules of isort: constants (all-uppercase names) first, then classes (names
    starting with a capital letter), then everything else (functions...). Inside each group, names are compared
    case-insensitively with underscores ignored.

    Args:
        objects (`Iterable[Any]`): The objects to sort.
        key (`Callable[[Any], str]`, *optional*): Maps an object to the name used for sorting. Defaults to the object
            itself (which must then be a `str`).

    Returns:
        `List[Any]`: A new list containing each object exactly once, sorted.
    """

    # If no key is provided, we use a noop.
    def _identity(x):
        return x

    name_of = _identity if key is None else key

    # Classify every object exactly once, so names such as `_CONST` (all uppercase but not starting with a capital)
    # land only in the constants group instead of being emitted twice.
    constants, classes, functions = [], [], []
    for obj in objects:
        name = name_of(obj)
        if name.isupper():
            constants.append(obj)
        elif name[:1].isupper():
            # `[:1]` instead of `[0]` so an empty name is treated like a function instead of raising IndexError.
            classes.append(obj)
        else:
            functions.append(obj)

    sort_key = ignore_underscore(name_of)
    return sorted(constants, key=sort_key) + sorted(classes, key=sort_key) + sorted(functions, key=sort_key)
115
116
117
def sort_objects_in_import(import_statement):
    """
    Return the same `import_statement` but with objects properly sorted.

    Three layouts are handled:
    - more than three lines: one `"name",` per line between the opening and closing line(s);
    - exactly three lines: all names on the middle line (optionally inside `[...]`);
    - one or two lines: every `[...]` group in the statement is sorted in place.

    Args:
        import_statement (`str`): One entry of `_import_structure`, e.g. `"models": ["B", "A"],`.

    Returns:
        `str`: The statement with its names ordered by `sort_objects`.
    """

    # This inner function sorts the names between [ ].
    def _replace(match):
        imports = match.groups()[0]
        if "," not in imports:
            return f"[{imports}]"
        keys = [part.strip().replace('"', "") for part in imports.split(",")]
        # We will have a final empty element if the line finished with a comma.
        if len(keys[-1]) == 0:
            keys = keys[:-1]
        return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"

    lines = import_statement.split("\n")
    if len(lines) > 3:
        # Here we have to sort internal imports that are on several lines (one per name):
        # key: [
        #     "object1",
        #     "object2",
        #     ...
        # ]

        # We may have to ignore one or two lines on each side.
        # NOTE(review): every line in between must match `_re_strip_line` (a single `"name",`); a comment or a line
        # holding several names makes `.search()` return None and this raises AttributeError.
        idx = 2 if lines[1].strip() == "[" else 1
        keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
        sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
        sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
        return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
    elif len(lines) == 3:
        # Here we have to sort internal imports that are on one separate line:
        # key: [
        #     "object1", "object2", ...
        # ]
        if _re_bracket_content.search(lines[1]) is not None:
            lines[1] = _re_bracket_content.sub(_replace, lines[1])
        else:
            keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
            # We will have a final empty element if the line finished with a comma.
            if len(keys[-1]) == 0:
                keys = keys[:-1]
            # Keep the trailing comma if there was one, so an already sorted line comes back unchanged (dropping it
            # would make check mode report a file that only differs by that comma).
            trailing_comma = "," if lines[1].rstrip().endswith(",") else ""
            sorted_names = ", ".join([f'"{k}"' for k in sort_objects(keys)])
            lines[1] = get_indent(lines[1]) + sorted_names + trailing_comma
        return "\n".join(lines)
    else:
        # Finally we have to deal with imports fitting on one line
        import_statement = _re_bracket_content.sub(_replace, import_statement)
        return import_statement
166
167
168
def sort_imports(file, check_only=True):
    """
    Sort `_import_structure` imports in `file`, `check_only` determines if we only check or overwrite.

    The region between the line starting with `_import_structure = {` and the line starting with
    `if TYPE_CHECKING:` is split into top-level blocks. In each of them, the `_import_structure` entries are reordered
    by key and the names listed in every entry are sorted with `sort_objects_in_import`.

    Args:
        file (`str`): Path to the `__init__.py` file to process.
        check_only (`bool`, *optional*, defaults to `True`): If `True`, only report whether the file would change;
            otherwise rewrite it in place when needed.

    Returns:
        `bool`: `True` if `check_only` is set and the file is not properly sorted, `False` otherwise.
    """
    with open(file, "r", encoding="utf-8") as f:
        code = f.read()

    if "_import_structure" not in code:
        return False

    # Blocks of indent level 0
    main_blocks = split_code_in_indented_blocks(
        code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
    )

    # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
    for block_idx in range(1, len(main_blocks) - 1):
        main_blocks[block_idx] = _sort_import_structure_block(main_blocks[block_idx])

    new_code = "\n".join(main_blocks)
    if new_code == code:
        return False
    if check_only:
        return True

    print(f"Overwriting {file}.")
    # Write UTF-8 with `\n` line endings whatever the platform defaults are.
    with open(file, "w", encoding="utf-8", newline="\n") as f:
        f.write(new_code)
    return False


def _sort_import_structure_block(block):
    """
    Return `block` (one top-level block of an init) with its `_import_structure` entries sorted, or `block` unchanged
    when there is nothing to sort in it.
    """
    block_lines = block.split("\n")

    # Get to the start of the imports; dummy import blocks are left untouched.
    line_idx = 0
    while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
        if "import dummy" in block_lines[line_idx]:
            return block
        line_idx += 1
    # Nothing to sort if `_import_structure` never shows up, or only on the last line (which is kept aside as the
    # closing line, leaving no content in between).
    if line_idx >= len(block_lines) - 1:
        return block

    # Ignore the last line: it doesn't contain anything to sort.
    internal_block_code = "\n".join(block_lines[line_idx:-1])
    indent = get_indent(block_lines[1])
    # Split the internal block into blocks of indent level 1.
    internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
    # We have two categories of import key: list or _import_structure[key].append/extend
    pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key

    # Grab the keys, but there is a trap: some lines are empty or just comments (their key is None).
    keys = []
    for internal_block in internal_blocks:
        match = pattern.search(internal_block)
        keys.append(None if match is None else match.groups()[0])

    # We only sort the blocks with a key (`sorted` is stable, so equal keys keep their relative order).
    keyed = [(i, key) for i, key in enumerate(keys) if key is not None]
    sorted_positions = iter([i for i, _ in sorted(keyed, key=lambda x: x[1])])

    # Empty lines/comments stay where they are; keyed slots are filled with the sorted entries.
    reordered_blocks = []
    for internal_block, key in zip(internal_blocks, keys):
        if key is None:
            reordered_blocks.append(internal_block)
        else:
            reordered_blocks.append(sort_objects_in_import(internal_blocks[next(sorted_positions)]))

    # And we put our main block back together with its first lines and its last line.
    return "\n".join(block_lines[:line_idx] + reordered_blocks + [block_lines[-1]])
234
235
236
def sort_imports_in_all_inits(check_only=True):
    """
    Run `sort_imports` on every `__init__.py` found under `PATH_TO_TRANSFORMERS`.

    Args:
        check_only (`bool`, *optional*, defaults to `True`): If `True`, only check the files; otherwise rewrite the
            ones that are not properly sorted.

    Raises:
        `ValueError`: If at least one init would be rewritten (only reported when `check_only` is set).
    """
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" not in files:
            continue
        init_file = os.path.join(root, "__init__.py")
        if sort_imports(init_file, check_only=check_only):
            # Accumulate every offending file (not just the last one) so the reported count is accurate.
            failures.append(init_file)
    if len(failures) > 0:
        raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
245
246
247
if __name__ == "__main__":
    # Command-line entry point: without `--check_only` the inits are rewritten in place.
    cli = argparse.ArgumentParser()
    cli.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    cli_args = cli.parse_args()
    sort_imports_in_all_inits(check_only=cli_args.check_only)
253
254