Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/scripts/docstrings.py
7804 views
1
"""Lightweight fork of Keras-Autodocs."""
2
3
import warnings
4
import black
5
import re
6
import inspect
7
import importlib
8
import itertools
9
from collections import defaultdict
10
from collections import namedtuple
11
import copy
12
13
import render_presets
14
15
16
# Maximum number of links to guides and examples
17
MAX_EXAMPLE_LINKS = 8
18
19
ExampleInfo = namedtuple("ExampleInfo", ["url", "title", "is_guide"])
20
21
22
class KerasDocumentationGenerator:
    """Render Markdown API documentation for importable objects.

    Attributes:
        project_url: Value handed to `make_source_link` to build the
            "[source]" link of each rendered object; `None` disables links.
        api_to_example: Maps an API path to the list of `ExampleInfo`
            entries (guides and examples) registered for it.
    """

    def __init__(self, project_url=None):
        self.project_url = project_url
        self.api_to_example = defaultdict(list)

    def add_example_apis(self, url, title, is_guide, apis):
        """Register one guide/example as a user of every API in `apis`."""
        example = ExampleInfo(url, title, is_guide)
        for api in apis:
            self.api_to_example[api].append(example)

    def process_docstring(self, docstring):
        """Convert a Google-style docstring to Markdown.

        Section headers such as `Args:` are turned into `# Arguments`
        markers, `Reference(s):` headers into bold headings, and runs of
        doctest lines (opened by a line starting with `>>>`, closed by a
        blank line) are wrapped in fenced `python` blocks. The result is
        then passed to the module-level `process_docstring`.

        Args:
            docstring: The raw docstring text.

        Returns:
            The Markdown rendering of `docstring`.
        """
        # Applied in order; each marker is later picked up as a section
        # start by the module-level section parser.
        section_markers = (
            ("Args:", "# Arguments"),
            ("Arguments:", "# Arguments"),
            ("Attributes:", "# Attributes"),
            ("Returns:", "# Returns"),
            ("Raises:", "# Raises"),
            ("Input shape:", "# Input shape"),
            ("Output shape:", "# Output shape"),
            ("Call arguments:", "# Call arguments"),
            ("Example:", "# Example\n"),
            ("Examples:", "# Examples\n"),
        )
        for header, marker in section_markers:
            docstring = docstring.replace(header, marker)

        docstring = re.sub(r"\nReference:\n\s*", "\n**Reference**\n\n", docstring)
        docstring = re.sub(r"\nReferences:\n\s*", "\n**References**\n\n", docstring)

        # Fix typo
        docstring = docstring.replace("\n >>> ", "\n>>> ")

        usable_lines = []
        doctest_lines = []

        def flush_doctest():
            # Emit the pending doctest run as one fenced code block.
            usable_lines.append("```python")
            usable_lines.extend(doctest_lines)
            usable_lines.append("```")
            usable_lines.append("")
            doctest_lines.clear()

        for line in docstring.split("\n"):
            if doctest_lines:
                # Inside a doctest run: an empty or spaces-only line ends it.
                if line.strip(" "):
                    doctest_lines.append(line)
                else:
                    flush_doctest()
            elif line.startswith(">>>"):
                doctest_lines.append(line)
            else:
                usable_lines.append(line)
        if doctest_lines:
            flush_doctest()

        return process_docstring("\n".join(usable_lines))

    def process_signature(self, signature):
        """Shorten `tensorflow.keras` to `tf.keras` and drop `*args, **kwargs`."""
        signature = signature.replace("tensorflow.keras", "tf.keras")
        signature = signature.replace("*args, **kwargs", "")
        return signature

    def render(self, element):
        """Render the documentation of `element`.

        Args:
            element: Either a dotted import path (resolved with
                `import_object`) or the object itself.

        Returns:
            The Markdown block produced by `render_from_object`.
        """
        if isinstance(element, str):
            object_ = import_object(element)
            if ismethod(object_):
                # we remove the modules when displaying the methods
                signature_override = ".".join(element.split(".")[-2:])
            else:
                signature_override = element
        else:
            signature_override = None
            object_ = element
        return self.render_from_object(object_, signature_override, element)

    def render_from_object(self, object_, signature_override: str, element):
        """Assemble the Markdown block documenting `object_`.

        Args:
            object_: The object to document.
            signature_override: Text used as the start of the rendered
                signature, or `None` to derive it from `object_`.
            element: What the caller asked to render: a dotted path string
                or the object itself. Used for the preset table and as the
                key for looking up guides and examples.

        Returns:
            Source link (when available), title, signature, processed
            docstring, optional preset table and optional examples list,
            joined by blank lines and followed by a `----` rule.
        """
        subblocks = []
        source_link = make_source_link(object_, self.project_url)
        if source_link is not None:
            subblocks.append(source_link)
        signature = get_signature(object_, signature_override)
        signature = self.process_signature(signature)
        subblocks.append(f"### `{get_name(object_)}` {get_type(object_)}\n")
        subblocks.append(code_snippet(signature))

        docstring = inspect.getdoc(object_)
        if docstring:
            docstring = self.process_docstring(docstring)
            subblocks.append(docstring)
            # Render preset table for KerasCV and KerasHub.
            # `element` is only a dotted path when `render` received a
            # string; objects passed directly have no path to inspect.
            if isinstance(element, str) and element.endswith("from_preset"):
                table = render_presets.render_table(
                    import_object(element.rsplit(".", 1)[0])
                )
                if table is not None:
                    subblocks.append(table)

        examples = self.api_to_example.get(element, None)
        if examples:
            subblocks.append(get_examples_block(element, examples))

        return "\n\n".join(subblocks) + "\n\n----\n\n"
124
125
126
def ismethod(function):
    """Return True when `get_class_from_method` finds an owning class."""
    owner = get_class_from_method(function)
    return owner is not None
128
129
130
def import_object(string: str):
    """Resolve a dotted path to the object it names.

    Each prefix of the path is first tried as a module import; once that
    fails, the remaining components are looked up as attributes, so
    functions, classes and methods are all reachable.
    For example: `'keras.layers.Dense.get_weights'` is valid.
    """
    parts = string.split(".")
    resolved = None
    for depth, part in enumerate(parts, start=1):
        dotted_path = ".".join(parts[:depth])
        try:
            resolved = importlib.import_module(dotted_path)
        except ModuleNotFoundError:
            # Not a module: it must be an attribute of what we have so far.
            assert resolved is not None, f"Failed to import path {string}"
            resolved = getattr(resolved, part)
    return resolved
146
147
148
def make_source_link(cls, project_url):
    """Build the right-floated "[source]" Markdown link for an object.

    Args:
        cls: The documented object. Its `__module__` gives the file path
            and `inspect.getsourcelines` gives the line number.
        project_url: Mapping from a top-level package name to the base URL
            of its source tree. The URL must end with "/"; its
            second-to-last path component is read as the version, with an
            optional leading "v" ignored.

    Returns:
        An HTML `<span>` wrapping the Markdown link, or `None` when `cls`
        has no `__module__` or `project_url` is empty.

    Raises:
        RuntimeError: If the version in the URL differs from the imported
            package's `__version__` (with any ".dev" suffix removed). The
            check is skipped for URLs containing "keras-rs".
    """
    if not hasattr(cls, "__module__"):
        return None
    if not project_url:
        return None

    base_module = cls.__module__.split(".")[0]
    project_url = project_url[base_module]
    assert project_url.endswith("/"), f"{base_module} not found"
    project_url_version = project_url.split("/")[-2].removeprefix("v")
    module_version = importlib.import_module(base_module).__version__
    if ".dev" in module_version:
        # Compare on the release part of the *installed* version. Slicing
        # the URL version here instead would let any dev build pass the
        # check below regardless of which release the URL points at.
        module_version = module_version[: module_version.find(".dev")]
    # TODO: Remove keras-rs condition, this is just a temporary thing.
    if "keras-rs" not in project_url and module_version != project_url_version:
        raise RuntimeError(
            f"For project {base_module}, URL {project_url} "
            f"has version number {project_url_version} which does not match the "
            f"current imported package version {module_version}"
        )
    path = cls.__module__.replace(".", "/")
    if base_module in ("tf_keras",):
        path = path.replace("/src/", "/")
    line = inspect.getsourcelines(cls)[-1]
    return (
        f'<span style="float:right;">'
        f"[[source]]({project_url}{path}.py#L{line})"
        f"</span>"
    )
177
178
179
def code_snippet(snippet):
    """Wrap `snippet` in a fenced Markdown `python` code block."""
    return "```python\n{}\n```\n".format(snippet)
181
182
183
def get_type(object_) -> str:
    """Classify `object_` as "class", "method", "function" or "property".

    Raises:
        TypeError: If `object_` is none of those.
    """
    if inspect.isclass(object_):
        return "class"
    if ismethod(object_):
        return "method"
    if inspect.isfunction(object_):
        return "function"
    if hasattr(object_, "fget"):
        return "property"
    raise TypeError(
        f"{object_} is detected as not a class, a method, "
        f"a property, nor a function."
    )
197
198
199
def get_name(object_) -> str:
    """Return the `__name__` of `object_`, or of its getter for a property."""
    named = object_.fget if hasattr(object_, "fget") else object_
    return named.__name__
203
204
205
def get_function_name(function):
    """Return the name of the innermost function reachable via `__wrapped__`."""
    innermost = function
    while hasattr(innermost, "__wrapped__"):
        innermost = innermost.__wrapped__
    return innermost.__name__
209
210
211
def get_default_value_for_repr(value):
    """Return a substitute for rendering the default value of a function arg.

    Function and object instances are rendered as <Foo object at 0x00000000>
    which can't be parsed by black. We substitute functions with the function
    name and objects with a rendered version of the constructor like
    `Foo(a=2, b="bar")`.

    Args:
        value: The value to find a better rendering of.

    Returns:
        An object whose `repr()` is the substitute text, or `None` if no
        substitution is needed.
    """

    class ReprWrapper:
        """Object whose `repr()` is exactly the given string."""

        def __init__(self, representation):
            self.representation = representation

        def __repr__(self):
            return self.representation

    if value is inspect.Parameter.empty:
        return None

    if inspect.isfunction(value):
        # Render the function name instead
        return ReprWrapper(value.__name__)

    if inspect.isclass(value):
        # Render classes as module.ClassName to produce a valid python
        # dotted-name expression in the fake signature (black can parse it).
        return ReprWrapper(value.__module__ + "." + value.__name__)

    if (
        repr(value).startswith("<")  # <Foo object at 0x00000000>
        and hasattr(value, "__class__")  # it is an object
        and hasattr(value, "get_config")  # it is a Keras object
    ):
        config = value.get_config()
        init_args = []  # The __init__ arguments to render
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        for p in inspect.signature(value.__class__.__init__).parameters.values():
            if p.name == "self":
                continue
            if p.kind in variadic_kinds:
                # `*args` / `**kwargs` have no single named entry in the
                # config, so looking them up would raise KeyError.
                continue
            if p.kind == inspect.Parameter.POSITIONAL_ONLY:
                # Required positional, render without a name
                init_args.append(repr(config[p.name]))
            elif p.default is inspect.Parameter.empty or p.default != config[p.name]:
                # Keyword arg with non-default value, render
                init_args.append(p.name + "=" + repr(config[p.name]))
            # else don't render that argument
        return ReprWrapper(
            value.__class__.__module__
            + "."
            + value.__class__.__name__
            + "("
            + ", ".join(init_args)
            + ")"
        )

    return None
272
273
274
def get_signature_start(function):
    """Return the qualified name that opens the rendered signature.

    Methods are prefixed with their class name; anything else with its
    module, e.g. 'keras.layers.Dense' for the Dense layer. When the object
    has no `__module__`, a warning is issued and no prefix is used.
    """
    if ismethod(function):
        owner = get_class_from_method(function)
        qualifier = f"{owner.__name__}."
    elif hasattr(function, "__module__"):
        qualifier = f"{function.__module__}."
    else:
        warnings.warn(
            f"function {function} has no module. "
            f"It will not be included in the signature."
        )
        qualifier = ""
    return qualifier + get_function_name(function)
288
289
290
def get_signature_end(function):
    """Render the parenthesised parameter list of `function`.

    Defaults that `get_default_value_for_repr` can improve on (functions,
    classes, objects exposing `get_config`) are replaced by their
    substitute so the text stays parseable. For methods the leading `self`
    is removed.

    Args:
        function: The callable whose parameters are rendered.

    Returns:
        A string of the form `"(param, other=default, ...)"`.
    """
    formatted_params = []
    for param in inspect.signature(function).parameters.values():
        default = get_default_value_for_repr(param.default)
        if default is not None:
            param = inspect.Parameter(
                param.name, param.kind, default=default, annotation=param.annotation
            )
        formatted_params.append(str(param))
    signature_end = "(" + ", ".join(formatted_params) + ")"

    if ismethod(function):
        signature_end = signature_end.replace("(self, ", "(")
        signature_end = signature_end.replace("(self)", "()")
    # work around case-specific bug: these enum reprs are not valid Python.
    # The aggregation default is a VariableAggregation member, so it is
    # spelled with that enum (not VariableSynchronization).
    signature_end = signature_end.replace(
        "synchronization=<VariableSynchronization.AUTO: 0>, aggregation=<VariableAggregationV2.NONE: 0>",
        "synchronization=tf.VariableSynchronization.AUTO, aggregation=tf.VariableAggregation.NONE",
    )
    return signature_end
313
314
315
def get_function_signature(function, override=None):
    """Return the formatted signature of `function`.

    `override`, when given, replaces the qualified name computed by
    `get_signature_start`.
    """
    signature_start = get_signature_start(function) if override is None else override
    return format_signature(signature_start, get_signature_end(function))
322
323
324
def get_class_signature(cls, override=None):
    """Return the formatted constructor signature of `cls`.

    The name is `module.ClassName` unless `override` is given; the
    parameters come from `cls.__init__`.
    """
    if override is not None:
        signature_start = override
    else:
        signature_start = f"{cls.__module__}.{cls.__name__}"
    return format_signature(signature_start, get_signature_end(cls.__init__))
331
332
333
def get_signature(object_, override):
    """Return the signature text for a class, function, method or property.

    For a property (anything with `fget`), a truthy `override` is returned
    as-is; otherwise the getter's signature is rendered.

    Raises:
        ValueError: If `object_` is none of the supported kinds.
    """
    if inspect.isclass(object_):
        return get_class_signature(object_, override)
    if inspect.isfunction(object_) or inspect.ismethod(object_):
        return get_function_signature(object_, override)
    if not hasattr(object_, "fget"):
        raise ValueError(f"Not able to retrieve signature for object {object_}")
    # properties
    return override if override else get_function_signature(object_.fget)
344
345
346
def format_signature(signature_start: str, signature_end: str):
    """Pretty-format a signature so long ones span several lines.

    The parameter list is attached to a placeholder name as long as
    `signature_start`, wrapped in a stub function definition, formatted by
    black with a line length of 90, and the reformatted parameter list is
    appended to the real `signature_start`.
    """
    # A stand-in identifier of the same width keeps black's line breaks
    # where they would fall for the real (possibly dotted) name.
    placeholder_name = "x" * len(signature_start)
    stub_source = "def " + placeholder_name + signature_end + ":\n pass\n"
    formatted_stub = black.format_str(
        stub_source, mode=black.FileMode(line_length=90)
    )
    return signature_start + extract_signature_end(formatted_stub)
358
359
360
def extract_signature_end(function_definition):
    """Return the text from the first "(" through the last ")" inclusive."""
    opening = function_definition.find("(")
    closing = function_definition.rfind(")")
    parameter_list = function_definition[opening : closing + 1]
    return parameter_list
364
365
366
def get_code_blocks(docstring):
    """Replace fenced code blocks in `docstring` with placeholder tokens.

    Args:
        docstring: Text that may contain ``` fenced blocks.

    Returns:
        A `(code_blocks, docstring)` tuple: `code_blocks` maps each token
        to the snippet it stands for (fences included), and `docstring` is
        the text with the snippets replaced by their tokens. A fence that
        is never closed extends to the end of the text.
    """
    fence = "```"
    code_blocks = {}
    tmp = docstring[:]
    while fence in tmp:
        tmp = tmp[tmp.find(fence) :]
        closing = tmp.find(fence, len(fence))
        if closing == -1:
            # Unclosed fence: treat the rest of the text as the snippet
            # rather than cutting it off a few characters in.
            index = len(tmp)
        else:
            index = closing + len(fence)
        snippet = tmp[:index]
        # Place marker in docstring for later reinjection.
        # Print the index with 4 digits so we know the symbol is unique.
        token = f"$KERAS_AUTODOC_CODE_BLOCK_{len(code_blocks):04d}"
        docstring = docstring.replace(snippet, token)
        code_blocks[token] = snippet
        tmp = tmp[index:]
    return code_blocks, docstring
380
381
382
def get_section_end(docstring, section_start):
    """Return the index in `docstring` where the section ends.

    The section starting at `section_start` runs until the first line
    break that is followed by a non-whitespace character (the next
    unindented line) or by the end of the text.

    Args:
        docstring: The full text being parsed.
        section_start: Index of the first character of the section.

    Returns:
        The end index of the section, usable as a slice bound. When no
        boundary can be found the section extends to the end of the text.
    """
    end = re.search(r"\S\n+(\S|$)", docstring[section_start:])
    if end is None:
        # E.g. the text ends in trailing spaces: no boundary matches, so
        # the section is everything that is left.
        return len(docstring)
    section_end = section_start + end.end()
    if section_end == len(docstring):
        return section_end
    # The match consumed the first character of the next line and the
    # line break before it; step back so neither belongs to the section.
    return section_end - 2
390
391
392
def get_examples_block(name, guides_and_examples):
    """Render the Markdown list of guides and examples that use `name`.

    Guides are listed before examples; within each group the incoming
    order (the order in the TOC) is kept. At most `MAX_EXAMPLE_LINKS`
    links are rendered, and only the last dotted component of `name` is
    shown in the heading.
    """
    # Stable sort on "is not a guide" puts guides first and otherwise
    # preserves the incoming order.
    ordered = sorted(guides_and_examples, key=lambda entry: not entry.is_guide)
    shown = ordered[:MAX_EXAMPLE_LINKS]
    short_name = name.rsplit(".", 1)[-1]
    bullets = "\n".join(f"- [{entry.title}]({entry.url})" for entry in shown)
    return f"**Guides and examples using `{short_name}`**\n\n" + bullets + "\n"
411
412
413
def get_google_style_sections_without_code(docstring):
    """Replace every `# Title` section of `docstring` with a token.

    A section starts at a line beginning with "# " and ends where
    `get_section_end` says it does.

    Returns:
        A `(sections, docstring)` tuple: `sections` maps each token to the
        section text it replaced, and `docstring` is the tokenized text.
    """
    header_pattern = re.compile(r"\n# .+?\n")
    sections = {}
    while (header := header_pattern.search(docstring)) is not None:
        # Skip the newline that precedes the "# " header.
        start = header.start() + 1
        stop = get_section_end(docstring, start)
        token = f"KERAS_AUTODOC_GOOGLE_STYLE_SECTION_{len(sections)}"
        sections[token] = docstring[start:stop]
        docstring = insert_in_string(docstring, token, start, stop)
    return sections, docstring
427
428
429
def get_google_style_sections(docstring):
    """Split `docstring` into tokenized text and its `# Title` sections.

    Returns:
        A `(sections, docstring)` tuple: `sections` maps each section
        token to the section's text, and `docstring` is the text with the
        sections replaced by their tokens. Code blocks are intact in both.
    """
    # First, extract code blocks and process them.
    # The parsing is easier if the #, : and other symbols aren't there.
    code_blocks, docstring = get_code_blocks(docstring)
    google_style_sections, docstring = get_google_style_sections_without_code(docstring)
    docstring = reinject_strings(docstring, code_blocks)
    # One reinjection pass per section is enough: the snippets were cut
    # from the original text and therefore contain no tokens themselves.
    google_style_sections = {
        section_token: reinject_strings(section, code_blocks)
        for section_token, section in google_style_sections.items()
    }
    return google_style_sections, docstring
439
440
441
def to_markdown(google_style_section: str) -> str:
    """Render one `# Title` section as Markdown.

    The title (first line, without the leading "# ") becomes a bold
    heading; the dedented body follows. Bodies of argument-like sections
    are turned into a Markdown list.
    """
    list_sections = {
        "Arguments",
        "Attributes",
        "Raises",
        "Call arguments",
        "Returns",
    }
    end_first_line = google_style_section.find("\n")
    section_title = google_style_section[2:end_first_line]
    section_body = remove_indentation(google_style_section[end_first_line:])
    if section_title in list_sections:
        section_body = format_as_markdown_list(section_body)
    if not section_body:
        return f"__{section_title}__\n"
    return f"__{section_title}__\n\n{section_body}\n"
458
459
460
def format_as_markdown_list(section_body):
    """Turn `name: description` entries into a Markdown bullet list.

    Every unindented line of the form `name: ...` becomes
    `- __name__: ...`. Continuation lines indented by four spaces are
    re-indented to two spaces so they render as part of the bullet.

    Args:
        section_body: The dedented body of a section.

    Returns:
        The body formatted as a Markdown list.
    """
    section_body = re.sub(r"\n([^ ].*?):", r"\n- __\1__:", section_body)
    section_body = re.sub(r"^([^ ].*?):", r"- __\1__:", section_body)
    # Switch to 2-space indent so we can render nested lists.
    section_body = section_body.replace("\n    ", "\n  ")
    return section_body
466
467
468
def reinject_strings(target, strings_to_inject):
    """Replace every token in `target` with the string it maps to."""
    restored = target
    for placeholder in strings_to_inject:
        restored = restored.replace(placeholder, strings_to_inject[placeholder])
    return restored
472
473
474
def process_docstring(docstring):
    """Render the `# Title` sections of `docstring` as Markdown.

    Args:
        docstring: Text in which section headers have already been turned
            into `# Title` markers. May be empty.

    Returns:
        The text, terminated by a newline, with every section replaced by
        its `to_markdown` rendering.
    """
    # `endswith` (rather than indexing the last character) also copes
    # with an empty docstring.
    if not docstring.endswith("\n"):
        docstring += "\n"

    google_style_sections, docstring = get_google_style_sections(docstring)
    for token, google_style_section in google_style_sections.items():
        markdown_section = to_markdown(google_style_section)
        docstring = docstring.replace(token, markdown_section)
    return docstring
483
484
485
def get_class_from_method(meth):
    """Return the class that defines `meth`, or `None` if there is none.

    Bound methods are first matched against the `__dict__` of each class
    in the MRO of their instance's class; failing that, and for plain
    functions, the class name is parsed from `__qualname__` and looked up
    in the function's module. Other objects fall back to `__objclass__`.
    """
    if inspect.ismethod(meth):
        instance_type = meth.__self__.__class__
        for klass in inspect.getmro(instance_type):
            if klass.__dict__.get(meth.__name__) is meth:
                return klass
        meth = meth.__func__  # fallback to __qualname__ parsing
    if inspect.isfunction(meth):
        outer_qualname = meth.__qualname__.split(".<locals>", 1)[0]
        owner_name = outer_qualname.rsplit(".", 1)[0]
        candidate = getattr(inspect.getmodule(meth), owner_name, None)
        if isinstance(candidate, type):
            return candidate
    # handle special descriptor objects
    return getattr(meth, "__objclass__", None)
497
498
499
def insert_in_string(target, string_to_insert, start, end):
    """Return `target` with the slice `[start:end]` replaced by `string_to_insert`."""
    return target[:start] + string_to_insert + target[end:]
503
504
505
def remove_indentation(string):
    """Remove the indentation shared by all lines and strip the result.

    The common indentation is the smallest number of leading whitespace
    characters among lines that contain text. Empty and whitespace-only
    lines are ignored when measuring it, so a stray blank-looking line
    does not prevent the rest from being dedented.

    Args:
        string: The text to dedent.

    Returns:
        The dedented text without leading or trailing blank lines.
    """
    lines = string.split("\n")
    leading_spaces = [
        len(line) - len(line.lstrip()) for line in lines if line.strip()
    ]
    if leading_spaces:
        min_leading_spaces = min(leading_spaces)
        string = "\n".join(line[min_leading_spaces:] for line in lines)
    return string.strip()  # Drop leading/closing empty lines
512
513
514
def count_leading_spaces(s):
    """Return the index of the first non-whitespace character of `s`.

    Returns 0 when `s` is empty or contains only whitespace.
    """
    first_visible = re.search(r"\S", s)
    return first_visible.start() if first_visible else 0
519
520