import argparse
import os
import re
# Root folder that is walked when looking for `__init__.py` files to process.
PATH_TO_TRANSFORMERS = "src/diffusers"
# Captures the leading whitespace of a string that contains at least one non-whitespace character.
_re_indent = re.compile(r"^(\s*)\S")
# Captures `key` from text starting (after optional whitespace) with `"key":`.
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
# Captures `key` from text starting (after optional whitespace) with `_import_structure["key"]`.
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
# Captures `name` from a line made only of `"name",` (surrounding whitespace allowed).
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
# Captures the (non-empty) text found between an opening `[` and the next `]`.
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
def get_indent(line):
    """Return the leading whitespace of `line`, or an empty string when `_re_indent` finds no match."""
    found = _re_indent.search(line)
    if found is None:
        return ""
    return found.group(1)
def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
    """
    Cut `code` into the blocks that sit at `indent_level`.

    When `start_prompt` is given, everything before the first line starting with it is returned untouched as the
    first block. When `end_prompt` is given, splitting stops at the first line starting with it and the remainder is
    returned untouched as the last block. Joining the returned blocks with newlines gives back `code`.
    """
    all_lines = code.split("\n")
    total = len(all_lines)
    pos = 0
    result = []

    if start_prompt is not None:
        # Skip ahead to the start prompt and keep the preamble as a block of its own.
        while not all_lines[pos].startswith(start_prompt):
            pos += 1
        result.append("\n".join(all_lines[:pos]))

    nested_prefix = indent_level + " "
    pending = [all_lines[pos]]
    pos += 1

    while pos < total:
        line = all_lines[pos]
        if end_prompt is not None and line.startswith(end_prompt):
            break

        at_level = len(line) > 0 and get_indent(line) == indent_level
        if not at_level:
            # Empty, nested or otherwise differently-indented line: it belongs to the block being built.
            pending.append(line)
        elif pending and get_indent(pending[-1]).startswith(nested_prefix):
            # A line back at `indent_level` right after deeper content closes the block (e.g. a closing bracket).
            pending.append(line)
            result.append("\n".join(pending))
            if pos < total - 1:
                # The following line opens the next block without being examined.
                pos += 1
                pending = [all_lines[pos]]
            else:
                pending = []
        else:
            # A line at `indent_level` that does not close anything starts a new block.
            result.append("\n".join(pending))
            pending = [line]
        pos += 1

    if pending:
        result.append("\n".join(pending))

    # Whatever is left from the end prompt onwards is kept verbatim.
    if end_prompt is not None and pos < total:
        result.append("\n".join(all_lines[pos:]))

    return result
def ignore_underscore(key):
    """Wrap `key` (a callable mapping an object to a string) so its result is lower-cased with underscores removed."""

    def normalized(item):
        text = key(item)
        return text.lower().replace("_", "")

    return normalized
def sort_objects(objects, key=None):
    """
    Sort a list of `objects` following the rules of isort: constants (all upper case) first, then classes (first
    character upper case), then everything else. Inside each group, ordering ignores case and underscores.

    Args:
        objects: The items to sort.
        key: Optional callable mapping an item to the string it is sorted by. Items are used as-is when omitted.

    Returns:
        A new list containing every item of `objects` exactly once.
    """
    if key is None:

        def key(x):
            return x

    # Every item is assigned to exactly one group. Testing the groups independently would list names such as
    # "_FOO" twice: the whole name is upper case but its first character is not.
    constants, classes, functions = [], [], []
    for obj in objects:
        name = key(obj)
        entry = (name.lower().replace("_", ""), obj)
        if name.isupper():
            constants.append(entry)
        elif name[:1].isupper():  # slicing keeps an empty name from raising
            classes.append(entry)
        else:
            functions.append(entry)

    def by_normalized_name(entry):
        return entry[0]

    # Sorting on the normalized name only keeps the sort stable for items whose names normalize identically.
    return [obj for group in (constants, classes, functions) for _, obj in sorted(group, key=by_normalized_name)]
def sort_objects_in_import(import_statement):
    """
    Return `import_statement` with the object names it lists put in sorted order.

    Three layouts are handled, selected by the number of lines in the statement: more than three lines (one name
    per line), exactly three lines (all names on the middle line), and anything shorter (names inside brackets).
    """

    def _sorted_quoted(raw):
        # Split a comma-separated run of quoted names, sort them and re-quote them on a single line.
        names = [piece.strip().replace('"', "") for piece in raw.split(",")]
        if not names[-1]:
            # A trailing comma leaves an empty last piece.
            names.pop()
        return ", ".join([f'"{name}"' for name in sort_objects(names)])

    def _replace(match):
        imports = match.group(1)
        # A single object needs no sorting and is kept exactly as written.
        if "," not in imports:
            return f"[{imports}]"
        return "[" + _sorted_quoted(imports) + "]"

    lines = import_statement.split("\n")
    line_count = len(lines)

    if line_count > 3:
        # One name per line; the opening bracket may sit on a line of its own, in which case two lines are
        # skipped at both ends instead of one.
        idx = 2 if lines[1].strip() == "[" else 1
        keyed = [(i, _re_strip_line.search(line).group(1)) for i, line in enumerate(lines[idx:-idx])]
        ordered = sort_objects(keyed, key=lambda x: x[1])
        body = [lines[position + idx] for position, _ in ordered]
        return "\n".join(lines[:idx] + body + lines[-idx:])

    if line_count == 3:
        middle = lines[1]
        if _re_bracket_content.search(middle) is not None:
            lines[1] = _re_bracket_content.sub(_replace, middle)
        else:
            # No brackets on the middle line: it is a bare comma-separated list of names.
            lines[1] = get_indent(middle) + _sorted_quoted(middle)
        return "\n".join(lines)

    # Short statement: sort whatever is found between brackets.
    return _re_bracket_content.sub(_replace, import_statement)
def sort_imports(file, check_only=True):
    """
    Sort the `_import_structure` entries found in `file`.

    Args:
        file: Path of the file to process.
        check_only: When true the file is never modified; when false it is overwritten if sorting changes it.

    Returns:
        `True` when `check_only` is set and sorting would change the file. `None` in every other case (no
        `_import_structure` in the file, nothing to change, or file overwritten).
    """
    # An explicit encoding keeps reading and writing independent from the platform's locale.
    with open(file, "r", encoding="utf-8") as f:
        code = f.read()

    if "_import_structure" not in code:
        return

    # Blocks at indent level 0 between the start of `_import_structure` and the `TYPE_CHECKING` section.
    main_blocks = split_code_in_indented_blocks(
        code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
    )

    # The first and the last block are left alone (text before the start prompt / from the end prompt onwards).
    for block_idx in range(1, len(main_blocks) - 1):
        block_lines = main_blocks[block_idx].split("\n")

        # Move to the first line mentioning `_import_structure`; a block containing "import dummy" before that
        # point is skipped entirely.
        line_idx = 0
        while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
            if "import dummy" in block_lines[line_idx]:
                line_idx = len(block_lines)
            else:
                line_idx += 1
        if line_idx >= len(block_lines):
            continue

        # A one-line block has no inner entries to reorder and no second line to read the indent from.
        if len(block_lines) < 2:
            continue

        # The last line of the block is kept aside and re-attached unchanged after sorting.
        internal_block_code = "\n".join(block_lines[line_idx:-1])
        indent = get_indent(block_lines[1])
        internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)

        # Keys look like `"key":` inside the dict literal and like `_import_structure["key"]` elsewhere.
        pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key
        keys = []
        for internal_block in internal_blocks:
            found = pattern.search(internal_block)
            keys.append(None if found is None else found.groups()[0])

        keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
        sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]

        # Blocks without a key stay where they are; the slots of keyed blocks are filled in sorted key order.
        count = 0
        reordered_blocks = []
        for i, internal_block in enumerate(internal_blocks):
            if keys[i] is None:
                reordered_blocks.append(internal_block)
            else:
                reordered_blocks.append(sort_objects_in_import(internal_blocks[sorted_indices[count]]))
                count += 1

        main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reordered_blocks + [block_lines[-1]])

    new_code = "\n".join(main_blocks)
    if code != new_code:
        if check_only:
            return True
        print(f"Overwriting {file}.")
        with open(file, "w", encoding="utf-8") as f:
            f.write(new_code)
def sort_imports_in_all_inits(check_only=True):
    """
    Run `sort_imports` on every `__init__.py` found under `PATH_TO_TRANSFORMERS`.

    Args:
        check_only: Forwarded to `sort_imports`.

    Raises:
        ValueError: If `sort_imports` reported at least one file that would be overwritten.
    """
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" not in files:
            continue
        init_path = os.path.join(root, "__init__.py")
        if sort_imports(init_path, check_only=check_only):
            # Accumulate every offending file so the count in the error message is accurate.
            failures.append(init_path)
    if failures:
        raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
if __name__ == "__main__":
    # Command-line entry point: `--check_only` is forwarded to `sort_imports_in_all_inits`.
    cli = argparse.ArgumentParser()
    cli.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    options = cli.parse_args()
    sort_imports_in_all_inits(check_only=options.check_only)