# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect
import os
import re
import warnings
from collections import OrderedDict
from difflib import get_close_matches
from pathlib import Path

from diffusers.models.auto import get_values
from diffusers.utils import ENV_VARS_TRUE_VALUES, is_flax_available, is_tf_available, is_torch_available


# All paths are set with the intent that you should run this script from the root of the repo with the command
# python utils/check_repo.py
PATH_TO_DIFFUSERS = "src/diffusers"
PATH_TO_TESTS = "tests"
PATH_TO_DOC = "docs/source/en"

# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
    "DPRSpanPredictor",
    "RealmBertModel",
    "T5Stack",
    "TFDPRSpanPredictor",
]

# Update this list with models that are not tested, with a comment explaining the reason they should not be.
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    # models to ignore for not tested
    "OPTDecoder",  # Building part of bigger (tested) model.
    "DecisionTransformerGPT2Model",  # Building part of bigger (tested) model.
    "SegformerDecodeHead",  # Building part of bigger (tested) model.
    "PLBartEncoder",  # Building part of bigger (tested) model.
    "PLBartDecoder",  # Building part of bigger (tested) model.
    "PLBartDecoderWrapper",  # Building part of bigger (tested) model.
    "BigBirdPegasusEncoder",  # Building part of bigger (tested) model.
    "BigBirdPegasusDecoder",  # Building part of bigger (tested) model.
    "BigBirdPegasusDecoderWrapper",  # Building part of bigger (tested) model.
    "DetrEncoder",  # Building part of bigger (tested) model.
    "DetrDecoder",  # Building part of bigger (tested) model.
    "DetrDecoderWrapper",  # Building part of bigger (tested) model.
    "M2M100Encoder",  # Building part of bigger (tested) model.
    "M2M100Decoder",  # Building part of bigger (tested) model.
    "Speech2TextEncoder",  # Building part of bigger (tested) model.
    "Speech2TextDecoder",  # Building part of bigger (tested) model.
    "LEDEncoder",  # Building part of bigger (tested) model.
    "LEDDecoder",  # Building part of bigger (tested) model.
    "BartDecoderWrapper",  # Building part of bigger (tested) model.
    "BartEncoder",  # Building part of bigger (tested) model.
    "BertLMHeadModel",  # Needs to be set up as decoder.
    "BlenderbotSmallEncoder",  # Building part of bigger (tested) model.
    "BlenderbotSmallDecoderWrapper",  # Building part of bigger (tested) model.
    "BlenderbotEncoder",  # Building part of bigger (tested) model.
    "BlenderbotDecoderWrapper",  # Building part of bigger (tested) model.
    "MBartEncoder",  # Building part of bigger (tested) model.
    "MBartDecoderWrapper",  # Building part of bigger (tested) model.
    "MegatronBertLMHeadModel",  # Building part of bigger (tested) model.
    "MegatronBertEncoder",  # Building part of bigger (tested) model.
    "MegatronBertDecoder",  # Building part of bigger (tested) model.
    "MegatronBertDecoderWrapper",  # Building part of bigger (tested) model.
    "PegasusEncoder",  # Building part of bigger (tested) model.
    "PegasusDecoderWrapper",  # Building part of bigger (tested) model.
    "DPREncoder",  # Building part of bigger (tested) model.
    "ProphetNetDecoderWrapper",  # Building part of bigger (tested) model.
    "RealmBertModel",  # Building part of bigger (tested) model.
    "RealmReader",  # Not regular model.
    "RealmScorer",  # Not regular model.
    "RealmForOpenQA",  # Not regular model.
    "ReformerForMaskedLM",  # Needs to be set up as decoder.
    "Speech2Text2DecoderWrapper",  # Building part of bigger (tested) model.
    "TFDPREncoder",  # Building part of bigger (tested) model.
    "TFElectraMainLayer",  # Building part of bigger (tested) model (should it be a TFModelMixin ?)
    "TFRobertaForMultipleChoice",  # TODO: fix
    "TrOCRDecoderWrapper",  # Building part of bigger (tested) model.
    "SeparableConv1D",  # Building part of bigger (tested) model.
    "FlaxBartForCausalLM",  # Building part of bigger (tested) model.
    "FlaxBertForCausalLM",  # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
    "OPTDecoderWrapper",
]

# Update this list with test files that don't have a tester with an `all_model_classes` variable and which don't
# trigger the common tests.
TEST_FILES_WITH_NO_COMMON_TESTS = [
    "models/decision_transformer/test_modeling_decision_transformer.py",
    "models/camembert/test_modeling_camembert.py",
    "models/mt5/test_modeling_flax_mt5.py",
    "models/mbart/test_modeling_mbart.py",
    "models/mt5/test_modeling_mt5.py",
    "models/pegasus/test_modeling_pegasus.py",
    "models/camembert/test_modeling_tf_camembert.py",
    "models/mt5/test_modeling_tf_mt5.py",
    "models/xlm_roberta/test_modeling_tf_xlm_roberta.py",
    "models/xlm_roberta/test_modeling_flax_xlm_roberta.py",
    "models/xlm_prophetnet/test_modeling_xlm_prophetnet.py",
    "models/xlm_roberta/test_modeling_xlm_roberta.py",
    "models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
    "models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py",
]

# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    # models to ignore for model xxx mapping
    "DPTForDepthEstimation",
    "DecisionTransformerGPT2Model",
    "GLPNForDepthEstimation",
    "ViltForQuestionAnswering",
    "ViltForImagesAndTextClassification",
    "ViltForImageAndTextRetrieval",
    "ViltForMaskedLM",
    "XGLMEncoder",
    "XGLMDecoder",
    "XGLMDecoderWrapper",
    "PerceiverForMultimodalAutoencoding",
    "PerceiverForOpticalFlow",
    "SegformerDecodeHead",
    "FlaxBeitForMaskedImageModeling",
    "PLBartEncoder",
    "PLBartDecoder",
    "PLBartDecoderWrapper",
    "BeitForMaskedImageModeling",
    "CLIPTextModel",
    "CLIPVisionModel",
    "TFCLIPTextModel",
    "TFCLIPVisionModel",
    "FlaxCLIPTextModel",
    "FlaxCLIPVisionModel",
    "FlaxWav2Vec2ForCTC",
    "DetrForSegmentation",
    "DPRReader",
    "FlaubertForQuestionAnswering",
    "FlavaImageCodebook",
    "FlavaTextModel",
    "FlavaImageModel",
    "FlavaMultimodalModel",
    "GPT2DoubleHeadsModel",
    "LukeForMaskedLM",
    "LukeForEntityClassification",
    "LukeForEntityPairClassification",
    "LukeForEntitySpanClassification",
    "OpenAIGPTDoubleHeadsModel",
    "RagModel",
    "RagSequenceForGeneration",
    "RagTokenForGeneration",
    "RealmEmbedder",
    "RealmForOpenQA",
    "RealmScorer",
    "RealmReader",
    "TFDPRReader",
    "TFGPT2DoubleHeadsModel",
    "TFOpenAIGPTDoubleHeadsModel",
    "TFRagModel",
    "TFRagSequenceForGeneration",
    "TFRagTokenForGeneration",
    "Wav2Vec2ForCTC",
    "HubertForCTC",
    "SEWForCTC",
    "SEWDForCTC",
    "XLMForQuestionAnswering",
    "XLNetForQuestionAnswering",
    "SeparableConv1D",
    "VisualBertForRegionToPhraseAlignment",
    "VisualBertForVisualReasoning",
    "VisualBertForQuestionAnswering",
    "VisualBertForMultipleChoice",
    "TFWav2Vec2ForCTC",
    "TFHubertForCTC",
    "MaskFormerForInstanceSegmentation",
]

# Update this list for models that have multiple model types for the same
# model doc.
MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
    [
        ("data2vec-text", "data2vec"),
        ("data2vec-audio", "data2vec"),
        ("data2vec-vision", "data2vec"),
    ]
)


# This is to make sure the diffusers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
    "diffusers",
    os.path.join(PATH_TO_DIFFUSERS, "__init__.py"),
    submodule_search_locations=[PATH_TO_DIFFUSERS],
)
diffusers = spec.loader.load_module()


def check_model_list():
    """Check the model list inside the diffusers library."""
    # Get the models from the directory structure of `src/diffusers/models/`
    models_dir = os.path.join(PATH_TO_DIFFUSERS, "models")
    _models = []
    for model in os.listdir(models_dir):
        model_dir = os.path.join(models_dir, model)
        if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir):
            _models.append(model)

    # Get the models registered in the `diffusers.models` module
    models = [model for model in dir(diffusers.models) if not model.startswith("__")]

    missing_models = sorted(list(set(_models).difference(models)))
    if missing_models:
        raise Exception(
            f"The following models should be included in {models_dir}/__init__.py: {','.join(missing_models)}."
        )
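
# For illustration (hypothetical directory name): if src/diffusers/models/unet_2d/ exists and
# contains an __init__.py, check_model_list() requires "unet_2d" to also be exposed as an
# attribute of `diffusers.models`, i.e. re-exported from src/diffusers/models/__init__.py.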


# If some modeling modules should be ignored for all checks, they should be added in the nested list
# _ignore_modules of this function.
def get_model_modules():
    """Get the model modules inside the diffusers library."""
    _ignore_modules = [
        "modeling_auto",
        "modeling_encoder_decoder",
        "modeling_marian",
        "modeling_mmbt",
        "modeling_outputs",
        "modeling_retribert",
        "modeling_utils",
        "modeling_flax_auto",
        "modeling_flax_encoder_decoder",
        "modeling_flax_utils",
        "modeling_speech_encoder_decoder",
        "modeling_flax_speech_encoder_decoder",
        "modeling_flax_vision_encoder_decoder",
        "modeling_transfo_xl_utilities",
        "modeling_tf_auto",
        "modeling_tf_encoder_decoder",
        "modeling_tf_outputs",
        "modeling_tf_pytorch_utils",
        "modeling_tf_utils",
        "modeling_tf_transfo_xl_utilities",
        "modeling_tf_vision_encoder_decoder",
        "modeling_vision_encoder_decoder",
    ]
    modules = []
    for model in dir(diffusers.models):
        # There are some magic dunder attributes in the dir, we ignore them
        if not model.startswith("__"):
            model_module = getattr(diffusers.models, model)
            for submodule in dir(model_module):
                if submodule.startswith("modeling") and submodule not in _ignore_modules:
                    modeling_module = getattr(model_module, submodule)
                    if inspect.ismodule(modeling_module):
                        modules.append(modeling_module)
    return modules


def get_models(module, include_pretrained=False):
    """Get the objects in module that are models."""
    models = []
    model_classes = (diffusers.ModelMixin, diffusers.TFModelMixin, diffusers.FlaxModelMixin)
    for attr_name in dir(module):
        if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
            continue
        attr = getattr(module, attr_name)
        if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:
            models.append((attr_name, attr))
    return models
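
# Illustrative output of get_models() (hypothetical module): for a modeling module that defines
# `class UNet2DModel(ModelMixin)`, it returns [("UNet2DModel", <class ...UNet2DModel>)]. Classes
# with "Pretrained"/"PreTrained" in their name are skipped unless include_pretrained=True.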


def is_a_private_model(model):
    """Returns True if the model should not be in the main init."""
    if model in PRIVATE_MODELS:
        return True

    # Wrapper, Encoder and Decoder are all private
    if model.endswith("Wrapper"):
        return True
    if model.endswith("Encoder"):
        return True
    if model.endswith("Decoder"):
        return True
    return False


def check_models_are_in_init():
    """Checks all models defined in the library are in the main init."""
    models_not_in_init = []
    dir_diffusers = dir(diffusers)
    for module in get_model_modules():
        models_not_in_init += [
            model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_diffusers
        ]

    # Remove private models
    models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)]
    if len(models_not_in_init) > 0:
        raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.")


# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
    """Get the model test files.

    The returned files do NOT contain the `tests` prefix (i.e. `PATH_TO_TESTS` defined in this script): they are
    paths relative to `tests`, and a caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files.
    """

    _ignore_files = [
        "test_modeling_common",
        "test_modeling_encoder_decoder",
        "test_modeling_flax_encoder_decoder",
        "test_modeling_flax_speech_encoder_decoder",
        "test_modeling_marian",
        "test_modeling_tf_common",
        "test_modeling_tf_encoder_decoder",
    ]
    test_files = []
    # Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models`
    model_test_root = os.path.join(PATH_TO_TESTS, "models")
    model_test_dirs = []
    for x in os.listdir(model_test_root):
        x = os.path.join(model_test_root, x)
        if os.path.isdir(x):
            model_test_dirs.append(x)

    for target_dir in [PATH_TO_TESTS] + model_test_dirs:
        for file_or_dir in os.listdir(target_dir):
            path = os.path.join(target_dir, file_or_dir)
            if os.path.isfile(path):
                filename = os.path.split(path)[-1]
                if "test_modeling" in filename and os.path.splitext(filename)[0] not in _ignore_files:
                    # Strip the leading `tests/` component to make the path relative to `PATH_TO_TESTS`.
                    file = os.path.join(*path.split(os.sep)[1:])
                    test_files.append(file)

    return test_files
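
# Illustrative return values of get_model_test_files() (hypothetical paths):
# "test_modeling_foo.py" for a file directly under tests/, or
# "models/foo/test_modeling_foo.py" for one under tests/models/foo/ -- both relative to `tests`.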


# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
# for the all_model_classes variable.
def find_tested_models(test_file):
    """Parse the content of test_file to detect what's in `all_model_classes`."""
    with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
        content = f.read()
    all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
    # Check with one less parenthesis as well
    all_models += re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
    if len(all_models) > 0:
        model_tested = []
        for entry in all_models:
            for line in entry.split(","):
                name = line.strip()
                if len(name) > 0:
                    model_tested.append(name)
        return model_tested
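
# Example of the snippet find_tested_models() parses in a test file (hypothetical class names):
#
#     all_model_classes = (UNet2DModel, UNet2DConditionModel) if is_torch_available() else ()
#
# -> returns ["UNet2DModel", "UNet2DConditionModel"]. The two regexes cover both this plain
# tuple form and the nested "((ModelA, ModelB), ...)" form.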


def check_models_are_tested(module, test_file):
    """Check models defined in module are tested in test_file."""
    # XxxModelMixin are not tested
    defined_models = get_models(module)
    tested_models = find_tested_models(test_file)
    if tested_models is None:
        if test_file.replace(os.path.sep, "/") in TEST_FILES_WITH_NO_COMMON_TESTS:
            return
        return [
            f"{test_file} should define `all_model_classes` to apply common tests to the models it tests. "
            + "If this is intentional, add the test filename to `TEST_FILES_WITH_NO_COMMON_TESTS` in the file "
            + "`utils/check_repo.py`."
        ]
    failures = []
    for model_name, _ in defined_models:
        if model_name not in tested_models and model_name not in IGNORE_NON_TESTED:
            failures.append(
                f"{model_name} is defined in {module.__name__} but is not tested in "
                + f"{os.path.join(PATH_TO_TESTS, test_file)}. Add it to the all_model_classes in that file. "
                + "If common tests should not be applied to that model, add its name to `IGNORE_NON_TESTED` "
                + "in the file `utils/check_repo.py`."
            )
    return failures


def check_all_models_are_tested():
    """Check all models are properly tested."""
    modules = get_model_modules()
    test_files = get_model_test_files()
    failures = []
    for module in modules:
        test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file]
        if len(test_file) == 0:
            failures.append(f"{module.__name__} does not have its corresponding test file.")
        elif len(test_file) > 1:
            failures.append(f"{module.__name__} has several test files: {test_file}.")
        else:
            test_file = test_file[0]
            new_failures = check_models_are_tested(module, test_file)
            if new_failures is not None:
                failures += new_failures
    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


def get_all_auto_configured_models():
    """Return the list of all models in at least one auto class."""
    result = set()  # To avoid duplicates we concatenate all model classes in a set.
    if is_torch_available():
        for attr_name in dir(diffusers.models.auto.modeling_auto):
            if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(diffusers.models.auto.modeling_auto, attr_name)))
    if is_tf_available():
        for attr_name in dir(diffusers.models.auto.modeling_tf_auto):
            if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(diffusers.models.auto.modeling_tf_auto, attr_name)))
    if is_flax_available():
        for attr_name in dir(diffusers.models.auto.modeling_flax_auto):
            if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(diffusers.models.auto.modeling_flax_auto, attr_name)))
    return list(result)
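
# For example, an auto-module attribute named MODEL_FOR_MASKED_LM_MAPPING_NAMES (hypothetical
# here) matches the MODEL_/TF_MODEL_/FLAX_MODEL_...MAPPING_NAMES filters above, and all of its
# values would be collected into the result.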


def ignore_unautoclassed(model_name):
    """Rules to determine if `model_name` should be in an auto class."""
    # Special whitelist
    if model_name in IGNORE_NON_AUTO_CONFIGURED:
        return True
    # Encoder and Decoder should be ignored
    if "Encoder" in model_name or "Decoder" in model_name:
        return True
    return False


def check_models_are_auto_configured(module, all_auto_models):
    """Check models defined in module are each in an auto class."""
    defined_models = get_models(module)
    failures = []
    for model_name, _ in defined_models:
        if model_name not in all_auto_models and not ignore_unautoclassed(model_name):
            failures.append(
                f"{model_name} is defined in {module.__name__} but is not present in any of the auto mappings. "
                "If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
                "`utils/check_repo.py`."
            )
    return failures


def check_all_models_are_auto_configured():
    """Check all models are each in an auto class."""
    missing_backends = []
    if not is_torch_available():
        missing_backends.append("PyTorch")
    if not is_tf_available():
        missing_backends.append("TensorFlow")
    if not is_flax_available():
        missing_backends.append("Flax")
    if len(missing_backends) > 0:
        missing = ", ".join(missing_backends)
        if os.getenv("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
            raise Exception(
                "Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the "
                f"Diffusers repo), the following are missing: {missing}."
            )
        else:
            warnings.warn(
                "Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the "
                f"Diffusers repo), the following are missing: {missing}. While it's probably fine as long as you "
                "didn't make any change in one of those backends' modeling files, you should probably execute the "
                "command above to be on the safe side."
            )
    modules = get_model_modules()
    all_auto_models = get_all_auto_configured_models()
    failures = []
    for module in modules:
        new_failures = check_models_are_auto_configured(module, all_auto_models)
        if new_failures is not None:
            failures += new_failures
    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


_re_decorator = re.compile(r"^\s*@(\S+)\s+$")


def check_decorator_order(filename):
    """Check that in the test file `filename` the slow decorator is always last, i.e. that no other decorator
    precedes a `parameterized` one."""
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    decorator_before = None
    errors = []
    for i, line in enumerate(lines):
        search = _re_decorator.search(line)
        if search is not None:
            decorator_name = search.groups()[0]
            if decorator_before is not None and decorator_name.startswith("parameterized"):
                errors.append(i)
            decorator_before = decorator_name
        elif decorator_before is not None:
            decorator_before = None
    return errors
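
# Illustrative orderings for check_decorator_order() (hypothetical tests):
#
#     @parameterized.expand(...)   # OK: the parameterized decorator comes first
#     @slow
#     def test_a(self): ...
#
#     @slow
#     @parameterized.expand(...)   # flagged: another decorator precedes it
#     def test_b(self): ...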


def check_all_decorator_order():
    """Check that in all test files, the slow decorator is always last (i.e. parameterized decorators come first)."""
    errors = []
    for fname in os.listdir(PATH_TO_TESTS):
        if fname.endswith(".py"):
            filename = os.path.join(PATH_TO_TESTS, fname)
            new_errors = check_decorator_order(filename)
            errors += [f"- {filename}, line {i}" for i in new_errors]
    if len(errors) > 0:
        msg = "\n".join(errors)
        raise ValueError(
            "The parameterized decorator (and its variants) should always be first, but this is not the case in the"
            f" following files:\n{msg}"
        )


def find_all_documented_objects():
    """Parse the content of all doc files to detect which classes and functions they document."""
    documented_obj = []
    for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
        with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
            content = f.read()
        raw_doc_objs = re.findall(r"(?:autoclass|autofunction):: diffusers.(\S+)\s+", content)
        documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs]
    for doc_file in Path(PATH_TO_DOC).glob("**/*.mdx"):
        with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
            content = f.read()
        raw_doc_objs = re.findall(r"\[\[autodoc\]\]\s+(\S+)\s+", content)
        documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs]
    return documented_obj
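
# Illustrative doc directives matched by find_all_documented_objects() (hypothetical object name):
#
#     reST (.rst):     .. autoclass:: diffusers.UNet2DModel
#     Markdown (.mdx): [[autodoc]] UNet2DModel
#
# In both cases only the last dotted component ("UNet2DModel") is recorded.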


# One good reason for not being documented is to be deprecated. Put deprecated objects in this list.
DEPRECATED_OBJECTS = [
    "AutoModelWithLMHead",
    "BartPretrainedModel",
    "DataCollator",
    "DataCollatorForSOP",
    "GlueDataset",
    "GlueDataTrainingArguments",
    "LineByLineTextDataset",
    "LineByLineWithRefDataset",
    "LineByLineWithSOPTextDataset",
    "PretrainedBartModel",
    "PretrainedFSMTModel",
    "SingleSentenceClassificationProcessor",
    "SquadDataTrainingArguments",
    "SquadDataset",
    "SquadExample",
    "SquadFeatures",
    "SquadV1Processor",
    "SquadV2Processor",
    "TFAutoModelWithLMHead",
    "TFBartPretrainedModel",
    "TextDataset",
    "TextDatasetForNextSentencePrediction",
    "Wav2Vec2ForMaskedLM",
    "Wav2Vec2Tokenizer",
    "glue_compute_metrics",
    "glue_convert_examples_to_features",
    "glue_output_modes",
    "glue_processors",
    "glue_tasks_num_labels",
    "squad_convert_examples_to_features",
    "xnli_compute_metrics",
    "xnli_output_modes",
    "xnli_processors",
    "xnli_tasks_num_labels",
    "TFTrainer",
    "TFTrainingArguments",
]

# Exceptionally, some objects should not be documented even after all the rules above have been applied.
# ONLY PUT SOMETHING IN THIS LIST AS A LAST RESORT!
UNDOCUMENTED_OBJECTS = [
    "AddedToken",  # This is a tokenizers class.
    "BasicTokenizer",  # Internal, should never have been in the main init.
    "CharacterTokenizer",  # Internal, should never have been in the main init.
    "DPRPretrainedReader",  # Like an Encoder.
    "DummyObject",  # Just picked by mistake sometimes.
    "MecabTokenizer",  # Internal, should never have been in the main init.
    "ModelCard",  # Internal type.
    "SqueezeBertModule",  # Internal building block (should have been called SqueezeBertLayer)
    "TFDPRPretrainedReader",  # Like an Encoder.
    "TransfoXLCorpus",  # Internal type.
    "WordpieceTokenizer",  # Internal, should never have been in the main init.
    "absl",  # External module
    "add_end_docstrings",  # Internal, should never have been in the main init.
    "add_start_docstrings",  # Internal, should never have been in the main init.
    "cached_path",  # Internal used for downloading models.
    "convert_tf_weight_name_to_pt_weight_name",  # Internal used to convert model weights
    "logger",  # Internal logger
    "logging",  # External module
    "requires_backends",  # Internal function
]

# This list should be empty. Objects in it should get their own doc page.
SHOULD_HAVE_THEIR_OWN_PAGE = [
    # Benchmarks
    "PyTorchBenchmark",
    "PyTorchBenchmarkArguments",
    "TensorFlowBenchmark",
    "TensorFlowBenchmarkArguments",
]


def ignore_undocumented(name):
    """Rules to determine if `name` should be undocumented."""
    # NOT DOCUMENTED ON PURPOSE.
    # Uppercase constants are not documented.
    if name.isupper():
        return True
    # ModelMixins / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
    if (
        name.endswith("ModelMixin")
        or name.endswith("Decoder")
        or name.endswith("Encoder")
        or name.endswith("Layer")
        or name.endswith("Embeddings")
        or name.endswith("Attention")
    ):
        return True
    # Submodules are not documented.
    if os.path.isdir(os.path.join(PATH_TO_DIFFUSERS, name)) or os.path.isfile(
        os.path.join(PATH_TO_DIFFUSERS, f"{name}.py")
    ):
        return True
    # Load functions are not documented.
    if name.startswith("load_tf") or name.startswith("load_pytorch"):
        return True
    # is_xxx_available functions are not documented.
    if name.startswith("is_") and name.endswith("_available"):
        return True
    # Deprecated objects are not documented.
    if name in DEPRECATED_OBJECTS or name in UNDOCUMENTED_OBJECTS:
        return True
    # MMBT model does not really work.
    if name.startswith("MMBT"):
        return True
    if name in SHOULD_HAVE_THEIR_OWN_PAGE:
        return True
    return False


def check_all_objects_are_documented():
    """Check all objects are properly documented."""
    documented_objs = find_all_documented_objects()
    modules = diffusers._modules
    objects = [c for c in dir(diffusers) if c not in modules and not c.startswith("_")]
    undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
    if len(undocumented_objs) > 0:
        raise Exception(
            "The following objects are in the public init so should be documented:\n - "
            + "\n - ".join(undocumented_objs)
        )
    check_docstrings_are_in_md()
    check_model_type_doc_match()


def check_model_type_doc_match():
    """Check all doc pages have a corresponding model type."""
    model_doc_folder = Path(PATH_TO_DOC) / "model_doc"
    model_docs = [m.stem for m in model_doc_folder.glob("*.mdx")]

    model_types = list(diffusers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
    model_types = [MODEL_TYPE_TO_DOC_MAPPING[m] if m in MODEL_TYPE_TO_DOC_MAPPING else m for m in model_types]

    errors = []
    for m in model_docs:
        if m not in model_types and m != "auto":
            close_matches = get_close_matches(m, model_types)
            error_message = f"{m} is not a proper model identifier."
            if len(close_matches) > 0:
                close_matches = "/".join(close_matches)
                error_message += f" Did you mean {close_matches}?"
            errors.append(error_message)

    if len(errors) > 0:
        raise ValueError(
            "Some model doc pages do not match any existing model type:\n"
            + "\n".join(errors)
            + "\nYou can add any missing model type to the `MODEL_NAMES_MAPPING` constant in "
            "models/auto/configuration_auto.py."
        )


# Re pattern to catch :obj:`xx`, :class:`xx`, :func:`xx` or :meth:`xx`.
_re_rst_special_words = re.compile(r":(?:obj|func|class|meth):`([^`]+)`")
# Re pattern to catch things between double backquotes.
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
# Re pattern to catch example introduction.
_re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE)
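
# Illustrative docstring fragments (hypothetical) that each pattern classifies as reST:
#
#     :obj:`torch.Tensor` or :class:`~diffusers.UNet2DModel`  -> _re_rst_special_words
#     ``attention_mask``                                      -> _re_double_backquotes
#     Example::                                               -> _re_rst_example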


def is_rst_docstring(docstring):
    """
    Returns `True` if `docstring` is written in rst.
    """
    if _re_rst_special_words.search(docstring) is not None:
        return True
    if _re_double_backquotes.search(docstring) is not None:
        return True
    if _re_rst_example.search(docstring) is not None:
        return True
    return False


def check_docstrings_are_in_md():
    """Check all docstrings are written in Markdown."""
    files_with_rst = []
    for file in Path(PATH_TO_DIFFUSERS).glob("**/*.py"):
        with open(file, "r", encoding="utf-8") as f:
            code = f.read()
        # Splitting on '"""' puts docstring contents at the odd indices of the result.
        docstrings = code.split('"""')

        for idx, docstring in enumerate(docstrings):
            if idx % 2 == 0 or not is_rst_docstring(docstring):
                continue
            files_with_rst.append(file)
            break

    if len(files_with_rst) > 0:
        raise ValueError(
            "The following files have docstrings written in rst:\n"
            + "\n".join([f"- {f}" for f in files_with_rst])
            + "\nTo fix this run `doc-builder convert path_to_py_file` after installing `doc-builder`\n"
            "(`pip install git+https://github.com/huggingface/doc-builder`)"
        )


def check_repo_quality():
    """Check all models are properly tested and documented."""
    print("Checking all models are included.")
    check_model_list()
    print("Checking all models are public.")
    check_models_are_in_init()
    print("Checking all models are properly tested.")
    check_all_decorator_order()
    check_all_models_are_tested()
    print("Checking all objects are properly documented.")
    check_all_objects_are_documented()
    print("Checking all models are in at least one auto class.")
    check_all_models_are_auto_configured()


if __name__ == "__main__":
    check_repo_quality()