Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/utils/check_doc_toc.py
1440 views
1
# coding=utf-8
2
# Copyright 2023 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import argparse
17
from collections import defaultdict
18
19
import yaml
20
21
22
# Table of content (toctree) of the `en` documentation that both checks below read and, when asked, rewrite.
# The path is relative, so it is resolved against the current working directory (presumably the repository
# root, e.g. when run through `make style` — TODO confirm how the script is invoked).
PATH_TO_TOC = "docs/source/en/_toctree.yml"
23
24
25
def clean_doc_toc(doc_list):
    """
    Cleans the table of content of the model documentation by removing duplicates and sorting models alphabetically.

    An entry titled "Overview" (case-insensitive) is pulled out and always placed first. Entries that share the same
    `local` path are collapsed into a single `{"local": ..., "title": ...}` entry. Entries without a `local` key
    (e.g. nested groups that only carry a `title` and their own `sections`) are kept untouched. All non-overview
    entries are sorted by lower-cased title.

    Args:
        doc_list (`List[dict]`):
            The entries of one `sections` list of the toctree. Every entry must have a `title` key.

    Returns:
        `List[dict]`: The cleaned entries, overview first.

    Raises:
        `ValueError`: If there is more than one overview entry, or if a duplicated `local` path is listed with
        different titles.
    """
    counts = defaultdict(int)
    overview_doc = []
    new_doc_list = []
    for doc in doc_list:
        if "local" in doc:
            counts[doc["local"]] += 1

        if doc["title"].lower() == "overview":
            overview_doc.append({"local": doc["local"], "title": doc["title"]})
        else:
            new_doc_list.append(doc)

    # "overview" gets special treatment and is always first, so only one is allowed. This is checked before the
    # duplicate handling: two overviews sharing a path would otherwise fail there with an unhelpful IndexError.
    if len(overview_doc) > 1:
        raise ValueError(f"{doc_list} has two 'overview' docs which is not allowed.")

    duplicates = [key for key, value in counts.items() if value > 1]

    new_doc = []
    for duplicate_key in duplicates:
        # `.get` because entries without a `local` key (nested groups) may live in the same list.
        titles = list({doc["title"] for doc in new_doc_list if doc.get("local") == duplicate_key})
        if len(titles) > 1:
            raise ValueError(
                f"{duplicate_key} is present several times in the documentation table of content at "
                "`docs/source/en/_toctree.yml` with different *Title* values. Choose one of those and remove the "
                "others."
            )
        # Only add this once
        new_doc.append({"local": duplicate_key, "title": titles[0]})

    # Add the non-duplicated entries: those without a `local` path, and those whose path appears exactly once.
    new_doc.extend(doc for doc in new_doc_list if "local" not in doc or counts[doc["local"]] == 1)

    # Sort
    new_doc.sort(key=lambda doc: doc["title"].lower())

    return overview_doc + new_doc
68
69
70
def check_scheduler_doc(overwrite=False):
    """
    Checks that the "Schedulers" section under "API" in the table of content is deduplicated and sorted.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            If `True`, rewrite `PATH_TO_TOC` with the cleaned section instead of raising.

    Raises:
        `ValueError`: If the "API" or "Schedulers" section cannot be found, or if the section is not clean and
        `overwrite` is `False`.
    """
    with open(PATH_TO_TOC, encoding="utf-8") as f:
        content = yaml.safe_load(f.read())

    # Get to the API doc
    api_idx = next((i for i, section in enumerate(content) if section.get("title") == "API"), None)
    if api_idx is None:
        raise ValueError(f"Could not find an 'API' section in {PATH_TO_TOC}.")
    api_doc = content[api_idx]["sections"]

    # Then to the scheduler doc
    scheduler_idx = next((i for i, section in enumerate(api_doc) if section.get("title") == "Schedulers"), None)
    if scheduler_idx is None:
        raise ValueError(f"Could not find a 'Schedulers' section under 'API' in {PATH_TO_TOC}.")

    scheduler_doc = api_doc[scheduler_idx]["sections"]
    new_scheduler_doc = clean_doc_toc(scheduler_doc)

    if new_scheduler_doc == scheduler_doc:
        return

    if not overwrite:
        raise ValueError(
            "The scheduler doc part of the table of content is not properly sorted, run `make style` to fix this."
        )

    # `api_doc` is the list object stored inside `content`, so this update is part of what gets dumped below.
    api_doc[scheduler_idx]["sections"] = new_scheduler_doc
    with open(PATH_TO_TOC, "w", encoding="utf-8") as f:
        f.write(yaml.dump(content, allow_unicode=True))
103
104
105
def check_pipeline_doc(overwrite=False):
    """
    Checks that the "Pipelines" section under "API" in the table of content is deduplicated and sorted, including
    every nested group of pipelines it contains.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            If `True`, rewrite `PATH_TO_TOC` with the cleaned section instead of raising.

    Raises:
        `ValueError`: If the "API" or "Pipelines" section cannot be found, or if the section is not clean and
        `overwrite` is `False`.
    """
    with open(PATH_TO_TOC, encoding="utf-8") as f:
        content = yaml.safe_load(f.read())

    # Get to the API doc
    api_idx = next((i for i, section in enumerate(content) if section.get("title") == "API"), None)
    if api_idx is None:
        raise ValueError(f"Could not find an 'API' section in {PATH_TO_TOC}.")
    api_doc = content[api_idx]["sections"]

    # Then to the pipeline doc
    pipeline_idx = next((i for i, section in enumerate(api_doc) if section.get("title") == "Pipelines"), None)
    if pipeline_idx is None:
        raise ValueError(f"Could not find a 'Pipelines' section under 'API' in {PATH_TO_TOC}.")

    pipeline_docs = api_doc[pipeline_idx]["sections"]

    # Sort sub pipeline docs. Nested groups use the same "sections" key as every other level of the toctree.
    # A new dict is built instead of mutating `pipeline_doc` in place: `pipeline_docs` must keep the original
    # ordering so the comparison below can detect changes made inside nested groups.
    new_pipeline_docs = []
    for pipeline_doc in pipeline_docs:
        if "sections" in pipeline_doc:
            pipeline_doc = {**pipeline_doc, "sections": clean_doc_toc(pipeline_doc["sections"])}
        new_pipeline_docs.append(pipeline_doc)

    # Sort overall pipeline doc
    new_pipeline_docs = clean_doc_toc(new_pipeline_docs)

    if new_pipeline_docs == pipeline_docs:
        return

    if not overwrite:
        raise ValueError(
            "The pipeline doc part of the table of content is not properly sorted, run `make style` to fix this."
        )

    # `api_doc` is the list object stored inside `content`, so this update is part of what gets dumped below.
    api_doc[pipeline_idx]["sections"] = new_pipeline_docs
    with open(PATH_TO_TOC, "w", encoding="utf-8") as f:
        f.write(yaml.dump(content, allow_unicode=True))
150
151
152
if __name__ == "__main__":
    # Command-line entry point: `--fix_and_overwrite` is forwarded as the `overwrite` argument of each check.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    fix_and_overwrite = cli.parse_args().fix_and_overwrite

    # Schedulers first, then pipelines; the first failing check stops the script.
    for run_check in (check_scheduler_doc, check_pipeline_doc):
        run_check(fix_and_overwrite)
159
160