GitHub Repository: shivamshrirao/diffusers
Path: blob/main/utils/check_table.py
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import collections
import importlib.util
import os
import re

# All paths are set with the assumption that you run this script from the root of the repo with the command
# python utils/check_table.py
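# To update the table in place instead of only checking it, the same script can be run with the
# flag registered at the bottom of this file:
# python utils/check_table.py --fix_and_overwrite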
TRANSFORMERS_PATH = "src/diffusers"
PATH_TO_DOCS = "docs/source/en"
REPO_PATH = "."


def _find_text_in_file(filename, start_prompt, end_prompt):
    """
    Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
    lines.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    # Find the start prompt.
    start_index = 0
    while not lines[start_index].startswith(start_prompt):
        start_index += 1
    start_index += 1

    end_index = start_index
    while not lines[end_index].startswith(end_prompt):
        end_index += 1
    end_index -= 1

    while len(lines[start_index]) <= 1:
        start_index += 1
    while len(lines[end_index]) <= 1:
        end_index -= 1
    end_index += 1
    return "".join(lines[start_index:end_index]), start_index, end_index, lines
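# Illustrative behaviour of the helper above: given an index.mdx that contains
#   <!--This table is updated automatically from the auto modules ...
#   | Model | ... |
#   <!-- End table-->
# it returns the text between the two marker lines (empty lines at the edges stripped),
# the start/end indices of that text in the file, and the full list of file lines.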

# Add here suffixes that are used to identify models, separated by |
ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration"
# Regexes that match TF/Flax/PT model names.
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# Will match any TF or Flax model too, so it needs to be in an else branch after the two previous regexes.
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
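# For illustration: _re_flax_models.match("FlaxUNet2DConditionModel").groups()[0] == "UNet2DCondition",
# while _re_pt_models.match("UNet2DConditionModel").groups()[0] == "UNet2DCondition"; names that do not
# end with one of the suffixes above match none of the three regexes.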

# This is to make sure the diffusers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
    "diffusers",
    os.path.join(TRANSFORMERS_PATH, "__init__.py"),
    submodule_search_locations=[TRANSFORMERS_PATH],
)
diffusers_module = spec.loader.load_module()


# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
def camel_case_split(identifier):
    "Split a camelcased `identifier` into words."
    matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
    return [m.group(0) for m in matches]
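# For example, camel_case_split("StableDiffusionPipeline") returns ["Stable", "Diffusion", "Pipeline"];
# the lookup loop below uses this to drop the last word of a name until it matches a known prefix.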

def _center_text(text, width):
    text_length = 2 if text == "✅" or text == "❌" else len(text)
    left_indent = (width - text_length) // 2
    right_indent = width - text_length - left_indent
    return " " * left_indent + text + " " * right_indent
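# For example, _center_text("✅", 7) returns "  ✅   " (two spaces, the emoji, three spaces):
# the emoji is counted as two characters wide so the rendered table columns stay aligned.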

def get_model_table_from_auto_modules():
    """Generates an up-to-date model table from the content of the auto modules."""
    # Dictionary mapping model names to their config.
    config_mapping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
    model_name_to_config = {
        name: config_mapping_names[code]
        for code, name in diffusers_module.MODEL_NAMES_MAPPING.items()
        if code in config_mapping_names
    }
    model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}

    # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
    slow_tokenizers = collections.defaultdict(bool)
    fast_tokenizers = collections.defaultdict(bool)
    pt_models = collections.defaultdict(bool)
    tf_models = collections.defaultdict(bool)
    flax_models = collections.defaultdict(bool)

    # Let's loop through all the diffusers objects (once).
    for attr_name in dir(diffusers_module):
        lookup_dict = None
        if attr_name.endswith("Tokenizer"):
            lookup_dict = slow_tokenizers
            attr_name = attr_name[:-9]
        elif attr_name.endswith("TokenizerFast"):
            lookup_dict = fast_tokenizers
            attr_name = attr_name[:-13]
        elif _re_tf_models.match(attr_name) is not None:
            lookup_dict = tf_models
            attr_name = _re_tf_models.match(attr_name).groups()[0]
        elif _re_flax_models.match(attr_name) is not None:
            lookup_dict = flax_models
            attr_name = _re_flax_models.match(attr_name).groups()[0]
        elif _re_pt_models.match(attr_name) is not None:
            lookup_dict = pt_models
            attr_name = _re_pt_models.match(attr_name).groups()[0]

        if lookup_dict is not None:
            while len(attr_name) > 0:
                if attr_name in model_name_to_prefix.values():
                    lookup_dict[attr_name] = True
                    break
                # Try again after removing the last word in the name
                attr_name = "".join(camel_case_split(attr_name)[:-1])

    # Let's build that table!
    model_names = list(model_name_to_config.keys())
    model_names.sort(key=str.lower)
    columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"]
    # We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
    widths = [len(c) + 2 for c in columns]
    widths[0] = max([len(name) for name in model_names]) + 2

    # Build the table per se
    table = "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n"
    # Use ":-----:" format to center-align the table cell texts
    table += "|" + "|".join([":" + "-" * (w - 2) + ":" for w in widths]) + "|\n"

    check = {True: "✅", False: "❌"}
    for name in model_names:
        prefix = model_name_to_prefix[name]
        line = [
            name,
            check[slow_tokenizers[prefix]],
            check[fast_tokenizers[prefix]],
            check[pt_models[prefix]],
            check[tf_models[prefix]],
            check[flax_models[prefix]],
        ]
        table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n"
    return table
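# For illustration only, the markdown built above has this general shape (actual rows depend on
# which objects are found at run time):
# |   Model   | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support |
# |:---------:|:--------------:|:--------------:|:---------------:|:------------------:|:------------:|
# |    ...    |       ❌       |       ❌       |       ✅        |         ❌         |      ❌      |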


def check_model_table(overwrite=False):
    """Check the model table in index.mdx is consistent with the state of the lib and maybe `overwrite`."""
    current_table, start_index, end_index, lines = _find_text_in_file(
        filename=os.path.join(PATH_TO_DOCS, "index.mdx"),
        start_prompt="<!--This table is updated automatically from the auto modules",
        end_prompt="<!-- End table-->",
    )
    new_table = get_model_table_from_auto_modules()

    if current_table != new_table:
        if overwrite:
            with open(os.path.join(PATH_TO_DOCS, "index.mdx"), "w", encoding="utf-8", newline="\n") as f:
                f.writelines(lines[:start_index] + [new_table] + lines[end_index:])
        else:
            raise ValueError(
                "The model table in the `index.mdx` has not been updated. Run `make fix-copies` to fix this."
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    check_model_table(args.fix_and_overwrite)