Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/utils/check_copies.py
1440 views
1
# coding=utf-8
2
# Copyright 2023 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import argparse
import glob
import importlib.util
import os
import re
import sys

import black
from doc_builder.style_doc import style_docstrings_in_code
24
25
26
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
DIFFUSERS_PATH = "src/diffusers"
REPO_PATH = "."


# This is to make sure the diffusers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
    "diffusers",
    os.path.join(DIFFUSERS_PATH, "__init__.py"),
    submodule_search_locations=[DIFFUSERS_PATH],
)
# `Loader.load_module()` is a deprecated API, so build and execute the module explicitly instead. The module has
# to be registered in `sys.modules` *before* it is executed so that imports performed while the package
# initializes resolve to this repo-local copy.
diffusers_module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = diffusers_module
spec.loader.exec_module(diffusers_module)
39
40
41
def _should_continue(line, indent):
42
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
43
44
45
def find_code_in_diffusers(object_name):
    """
    Find and return the source code of `object_name`.

    Args:
        object_name: Dotted path relative to `DIFFUSERS_PATH`: first the module path, then the (possibly nested)
            class / function names, e.g. `package.module.ClassName.method`.

    Returns:
        The lines that follow the `class` / `def` line of the object, joined into a single string, with trailing
        blank lines removed.

    Raises:
        ValueError: If no prefix of `object_name` names a `.py` file under `DIFFUSERS_PATH`, or if the object
            cannot be located inside that file.
    """
    parts = object_name.split(".")
    i = 0

    # First let's find the module where our object lives: extend the path one component at a time until it names
    # an existing `.py` file.
    module = parts[i]
    while i < len(parts) and not os.path.isfile(os.path.join(DIFFUSERS_PATH, f"{module}.py")):
        i += 1
        if i < len(parts):
            module = os.path.join(module, parts[i])
    if i >= len(parts):
        raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.")

    with open(os.path.join(DIFFUSERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Now let's find the class / func in the code! Each remaining name is searched one nesting level deeper than
    # the previous one, starting from the line after the previous match.
    indent = ""
    line_index = 0
    for name in parts[i + 1 :]:
        # `re.escape` keeps the name literal inside the pattern.
        header = re.compile(rf"^{indent}(class|def)\s+{re.escape(name)}(\(|\:)")
        while line_index < len(lines) and header.search(lines[line_index]) is None:
            line_index += 1
        # One nesting level is four spaces; a narrower step would never match the header of a nested definition.
        indent += "    "
        line_index += 1

    if line_index >= len(lines):
        raise ValueError(f" {object_name} does not match any function or class in {module}.")

    # We found the beginning of the class / func, now let's find the end (when the indent diminishes).
    start_index = line_index
    while line_index < len(lines) and _should_continue(lines[line_index], indent):
        line_index += 1
    # Clean up empty lines at the end (if any), without walking back past the start of the body.
    while line_index > start_index and len(lines[line_index - 1]) <= 1:
        line_index -= 1

    return "".join(lines[start_index:line_index])
86
87
88
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)")
89
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
90
_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
91
92
93
def get_indent(code):
    """
    Return the leading whitespace of the first line of `code` that contains a non-whitespace character.

    Empty lines and lines made only of whitespace are skipped, since neither carries usable indentation
    information. An empty string is returned when `code` has no line with visible content.
    """
    for line in code.split("\n"):
        match = re.search(r"^(\s*)\S", line)
        if match is not None:
            return match.group(1)
    return ""
101
102
103
def blackify(code):
    """
    Format `code` with black, then restyle its docstrings with `style_docstrings_in_code`.

    Indented code is temporarily placed under a dummy `class Bla:` header before formatting; that header is
    stripped from the result before returning.
    """
    wrapper = "class Bla:\n"
    needs_wrapper = len(get_indent(code)) > 0
    source = f"{wrapper}{code}" if needs_wrapper else code

    mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=119, preview=True)
    formatted, _ = style_docstrings_in_code(black.format_str(source, mode=mode))

    if needs_wrapper:
        return formatted[len(wrapper) :]
    return formatted
114
115
116
def is_copy_consistent(filename, overwrite=False):
    """
    Check that every block marked with a `# Copied from diffusers.xxx` comment in `filename` matches its original.

    Returns a list of `[object_name, start_index]` pairs, one per block that differs from the original. When
    `overwrite` is `True`, differing blocks are replaced by the expected code and the file is rewritten.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    diffs = []
    cursor = 0
    # A `while` loop rather than a `for` loop, because `lines` is replaced when `overwrite=True`.
    while cursor < len(lines):
        marker = _re_copy_warning.search(lines[cursor])
        if marker is None:
            cursor += 1
            continue

        # There is some copied code here, let's retrieve the original.
        marker_indent, object_name, replace_pattern = marker.groups()
        theoretical_code = find_code_in_diffusers(object_name)
        indent = get_indent(theoretical_code)

        # When the marker is indented like the original body, the copy starts on the next line; otherwise one more
        # line is skipped before the copy starts.
        start_index = cursor + 1 if marker_indent == indent else cursor + 2
        cursor = start_index

        # Walk over the observed code; stop when the indentation diminishes or an `# End copy` comment is seen.
        while cursor < len(lines):
            cursor += 1
            if cursor >= len(lines):
                break
            candidate = lines[cursor]
            if not _should_continue(candidate, indent):
                break
            if re.search(f"^{indent}# End copy", candidate) is not None:
                break
        # Clean up empty lines at the end (if any).
        while len(lines[cursor - 1]) <= 1:
            cursor -= 1

        observed_code = "".join(lines[start_index:cursor])

        # Remove any nested `Copied from` comments to avoid circular copies.
        theoretical_code = "\n".join(
            code_line for code_line in theoretical_code.split("\n") if _re_copy_warning.search(code_line) is None
        )

        # Before comparing, use the `replace_pattern` on the original code.
        if len(replace_pattern) > 0:
            for directive in replace_pattern.replace("with", "").split(","):
                parsed = _re_replace_pattern.search(directive)
                if parsed is None:
                    continue
                obj1, obj2, option = parsed.groups()
                theoretical_code = re.sub(obj1, obj2, theoretical_code)
                if option.strip() == "all-casing":
                    theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
                    theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)

            # Blackify after replacement. To be able to do that, we need the header (class or function definition)
            # from the previous line.
            # NOTE(review): this step is assumed to run only when a replace pattern is present -- confirm the
            # nesting against the upstream file.
            header = lines[start_index - 1]
            theoretical_code = blackify(header + theoretical_code)
            theoretical_code = theoretical_code[len(header) :]

        # Test for a diff and act accordingly.
        if observed_code != theoretical_code:
            diffs.append([object_name, start_index])
            if overwrite:
                lines = lines[:start_index] + [theoretical_code] + lines[cursor:]
                # NOTE(review): the cursor reset is assumed to apply only when overwriting -- confirm the nesting
                # against the upstream file.
                cursor = start_index + 1

    if overwrite and len(diffs) > 0:
        # Warn the user a file has been modified.
        print(f"Detected changes, rewriting {filename}.")
        with open(filename, "w", encoding="utf-8", newline="\n") as f:
            f.writelines(lines)
    return diffs
191
192
193
def check_copies(overwrite: bool = False):
    """
    Run `is_copy_consistent` on every `.py` file found recursively under `DIFFUSERS_PATH`.

    When `overwrite` is `False` and at least one inconsistency is reported, an `Exception` listing all of them is
    raised. When `overwrite` is `True`, nothing is raised here; `overwrite` is simply forwarded to
    `is_copy_consistent`.
    """
    messages = []
    for filename in glob.glob(os.path.join(DIFFUSERS_PATH, "**/*.py"), recursive=True):
        for object_name, line_number in is_copy_consistent(filename, overwrite):
            messages.append(f"- {filename}: copy does not match {object_name} at line {line_number}")

    if overwrite or len(messages) == 0:
        return

    raise Exception(
        "Found the following copy inconsistencies:\n"
        + "\n".join(messages)
        + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
    )
206
207
208
if __name__ == "__main__":
    # Command-line entry point: `--fix_and_overwrite` is passed to `check_copies` as its `overwrite` argument.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    options = cli.parse_args()

    check_copies(options.fix_and_overwrite)
214
215