Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/scripts/docstrings.py
3273 views
1
"""Lightweight fork of Keras-Autodocs.
2
"""
3
4
import warnings
5
import black
6
import re
7
import inspect
8
import importlib
9
import itertools
10
import copy
11
12
import render_presets
13
14
15
class KerasDocumentationGenerator:
    """Render API objects (classes, functions, methods, properties) as Markdown.

    Args:
        project_url: Optional mapping passed through to `make_source_link`,
            indexed there by the top-level module name of the rendered object.
            When falsy, no source link is produced.
    """

    # Google-style section headers and the "# Title" markers they become.
    # Order is preserved from the original sequence of replacements.
    _SECTION_HEADERS = (
        ("Args:", "# Arguments"),
        ("Arguments:", "# Arguments"),
        ("Attributes:", "# Attributes"),
        ("Returns:", "# Returns"),
        ("Raises:", "# Raises"),
        ("Input shape:", "# Input shape"),
        ("Output shape:", "# Output shape"),
        ("Call arguments:", "# Call arguments"),
        ("Example:", "# Example\n"),
        ("Examples:", "# Examples\n"),
    )

    def __init__(self, project_url=None):
        self.project_url = project_url

    def process_docstring(self, docstring):
        """Convert a Google-style docstring into Markdown.

        Section headers are rewritten to `# Title` markers, reference headers
        become bold text, doctest runs (`>>>` until the next blank line) are
        wrapped in fenced python code blocks, and the result is handed to the
        module-level `process_docstring` for section formatting.
        """
        for header, marker in self._SECTION_HEADERS:
            docstring = docstring.replace(header, marker)

        docstring = re.sub(r"\nReference:\n\s*", "\n**Reference**\n\n", docstring)
        docstring = re.sub(r"\nReferences:\n\s*", "\n**References**\n\n", docstring)

        # Fix typo: un-indent doctest prompts that carry a leading space.
        docstring = docstring.replace("\n >>> ", "\n>>> ")

        usable_lines = []
        doctest_lines = []

        def flush_doctest():
            # Emit the pending doctest run as one fenced block plus a blank line.
            usable_lines.append("```python")
            usable_lines.extend(doctest_lines)
            usable_lines.append("```")
            usable_lines.append("")

        for line in docstring.split("\n"):
            if doctest_lines:
                # Inside a doctest run: a blank (or spaces-only) line ends it.
                if not line or set(line) == {" "}:
                    flush_doctest()
                    doctest_lines = []
                else:
                    doctest_lines.append(line)
            elif line.startswith(">>>"):
                doctest_lines.append(line)
            else:
                usable_lines.append(line)
        if doctest_lines:
            flush_doctest()
        docstring = "\n".join(usable_lines)

        # Resolves to the module-level function, not this method.
        return process_docstring(docstring)

    def process_signature(self, signature):
        """Shorten `tensorflow.keras` to `tf.keras` and drop `*args, **kwargs`."""
        signature = signature.replace("tensorflow.keras", "tf.keras")
        signature = signature.replace("*args, **kwargs", "")
        return signature

    def render(self, element):
        """Render `element`, given either as a dotted path string or an object.

        For a dotted path that resolves to a method, the displayed signature
        keeps only the last two components (`Class.method`).
        """
        if isinstance(element, str):
            object_ = import_object(element)
            if ismethod(object_):
                # we remove the modules when displaying the methods
                signature_override = ".".join(element.split(".")[-2:])
            else:
                signature_override = element
        else:
            signature_override = None
            object_ = element
        return self.render_from_object(object_, signature_override, element)

    def render_from_object(self, object_, signature_override: str, element):
        """Build the Markdown block for `object_`.

        Args:
            object_: The object to document.
            signature_override: Text to use as the signature start, or `None`
                to derive it from the object.
            element: The value originally passed to `render` (a dotted path
                string or the object itself).

        Returns:
            The source link (if any), title, signature, docstring and optional
            preset table joined by blank lines, followed by a `----` rule.
        """
        subblocks = []
        source_link = make_source_link(object_, self.project_url)
        if source_link is not None:
            subblocks.append(source_link)
        signature = get_signature(object_, signature_override)
        signature = self.process_signature(signature)
        subblocks.append(f"### `{get_name(object_)}` {get_type(object_)}\n")
        subblocks.append(code_snippet(signature))

        docstring = inspect.getdoc(object_)
        if docstring:
            docstring = self.process_docstring(docstring)
            subblocks.append(docstring)
        # Render preset table for KerasCV and KerasHub.
        # `element` is only a dotted path when `render` received a string;
        # objects passed directly have no `endswith` and carry no path to
        # resolve the owning class from, so they are skipped here.
        if isinstance(element, str) and element.endswith("from_preset"):
            table = render_presets.render_table(import_object(element.rsplit(".", 1)[0]))
            if table is not None:
                subblocks.append(table)
        return "\n\n".join(subblocks) + "\n\n----\n\n"
104
105
106
def ismethod(function):
    """Return True when `get_class_from_method` finds an owning class."""
    owner = get_class_from_method(function)
    return owner is not None
108
109
110
def import_object(string: str):
    """Resolve a dotted path to the object it names.

    Each growing prefix of the path is first tried as a module import; once a
    prefix is not importable, the component is looked up as an attribute of
    the object resolved so far. This lets paths point at functions, classes
    or methods, e.g. `'keras.layers.Dense.get_weights'`.
    """
    resolved = None
    components = string.split(".")
    for depth, component in enumerate(components, start=1):
        dotted_prefix = ".".join(components[:depth])
        try:
            resolved = importlib.import_module(dotted_prefix)
        except ModuleNotFoundError:
            # Nothing was importable at all: the path's root is wrong.
            assert resolved is not None, f"Failed to import path {string}"
            resolved = getattr(resolved, component)
    return resolved
126
127
128
def make_source_link(cls, project_url):
    """Build an HTML/Markdown "[source]" link for `cls`.

    Args:
        cls: The object to link to. Objects without `__module__` get no link.
        project_url: Mapping from top-level module name to the base URL of
            that project's source tree. The URL must end with "/" and its
            last path component is read as the version (a leading "v" is
            dropped). When falsy, no link is produced.

    Returns:
        The link markup, or `None` when no link can be produced.

    Raises:
        RuntimeError: If the version in the URL differs from the imported
            package version (checked for every project whose URL does not
            contain "keras-rs").
    """
    if not hasattr(cls, "__module__"):
        return None
    if not project_url:
        return None

    base_module = cls.__module__.split(".")[0]
    project_url = project_url[base_module]
    assert project_url.endswith("/"), f"{base_module} not found"
    project_url_version = project_url.split("/")[-2].removeprefix("v")
    module_version = importlib.import_module(base_module).__version__
    if ".dev" in module_version:
        # Compare dev builds by their release part only. The slice must be
        # taken from the module version itself; slicing the URL version made
        # the comparison below pass for unrelated releases.
        module_version = module_version[: module_version.find(".dev")]
    # TODO: Remove keras-rs condition, this is just a temporary thing.
    if "keras-rs" not in project_url and module_version != project_url_version:
        raise RuntimeError(
            f"For project {base_module}, URL {project_url} "
            f"has version number {project_url_version} which does not match the "
            f"current imported package version {module_version}"
        )
    path = cls.__module__.replace(".", "/")
    if base_module in ("tf_keras",):
        path = path.replace("/src/", "/")
    # getsourcelines returns (lines, first_line_number); link to that line.
    line = inspect.getsourcelines(cls)[-1]
    return (
        f'<span style="float:right;">'
        f"[[source]]({project_url}{path}.py#L{line})"
        f"</span>"
    )
157
158
159
def code_snippet(snippet):
    """Wrap `snippet` in a fenced python code block ending with a newline."""
    opening_fence = "```python"
    closing_fence = "```"
    return f"{opening_fence}\n{snippet}\n{closing_fence}\n"
161
162
163
def get_type(object_) -> str:
    """Classify `object_` as "class", "method", "function" or "property".

    Checks run in that order; anything exposing `fget` counts as a property.

    Raises:
        TypeError: If `object_` matches none of the four categories.
    """
    if inspect.isclass(object_):
        return "class"
    if ismethod(object_):
        return "method"
    if inspect.isfunction(object_):
        return "function"
    if hasattr(object_, "fget"):
        return "property"
    raise TypeError(
        f"{object_} is detected as not a class, a method, "
        f"a property, nor a function."
    )
177
178
179
def get_name(object_) -> str:
    """Return the display name: the getter's name for properties, else `__name__`."""
    named = object_.fget if hasattr(object_, "fget") else object_
    return named.__name__
183
184
185
def get_function_name(function):
    """Return the `__name__` of the innermost function in a `__wrapped__` chain."""
    innermost = function
    while hasattr(innermost, "__wrapped__"):
        innermost = innermost.__wrapped__
    return innermost.__name__
189
190
191
def get_default_value_for_repr(value):
    """Return a substitute for rendering the default value of a function arg.

    Function and object instances are rendered as <Foo object at 0x00000000>
    which can't be parsed by black. We substitute functions with the function
    name and objects with a rendered version of the constructor like
    `Foo(a=2, b="bar")`.

    Args:
        value: The value to find a better rendering of.

    Returns:
        Another value or `None` if no substitution is needed.
    """

    class ReprWrapper:
        """Object whose `repr` is exactly the given text."""

        def __init__(self, representation):
            self.representation = representation

        def __repr__(self):
            return self.representation

    if value is inspect._empty:
        return None

    if inspect.isfunction(value):
        # Render the function name instead
        return ReprWrapper(value.__name__)

    if (
        repr(value).startswith("<")  # <Foo object at 0x00000000>
        and hasattr(value, "__class__")  # it is an object
        and hasattr(value, "get_config")  # it is a Keras object
    ):
        config = value.get_config()
        init_args = []  # The __init__ arguments to render
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        for p in inspect.signature(value.__class__.__init__).parameters.values():
            if p.name == "self":
                continue
            # `*args` / `**kwargs` and anything the config does not report
            # have no value to render; looking them up would raise KeyError.
            if p.kind in variadic_kinds or p.name not in config:
                continue
            if p.kind == inspect.Parameter.POSITIONAL_ONLY:
                # Positional-only, render without a name
                init_args.append(repr(config[p.name]))
            elif p.default is inspect._empty or p.default != config[p.name]:
                # Required arg, or keyword arg with non-default value: render
                init_args.append(p.name + "=" + repr(config[p.name]))
            # else don't render that argument
        return ReprWrapper(
            value.__class__.__module__
            + "."
            + value.__class__.__name__
            + "("
            + ", ".join(init_args)
            + ")"
        )

    return None
247
248
249
def get_signature_start(function):
    """Return the qualified name that opens a rendered signature.

    Methods are prefixed with their owning class name, everything else with
    its module; for the Dense layer this yields 'keras.layers.Dense'. If the
    object has no `__module__`, a warning is issued and no prefix is used.
    """
    if ismethod(function):
        owner = get_class_from_method(function)
        prefix = f"{owner.__name__}."
    else:
        try:
            prefix = f"{function.__module__}."
        except AttributeError:
            prefix = ""
            warnings.warn(
                f"function {function} has no module. "
                f"It will not be included in the signature."
            )
    return f"{prefix}{get_function_name(function)}"
263
264
265
def get_signature_end(function):
    """Return the parenthesized parameter list of `function` as a string.

    Defaults that `get_default_value_for_repr` can improve on are swapped for
    their substitute before rendering. For methods, the leading `self`
    parameter is removed.
    """
    formatted_params = []
    for param in inspect.signature(function).parameters.values():
        default = get_default_value_for_repr(param.default)
        if default is not None:
            param = inspect.Parameter(
                param.name, param.kind, default=default, annotation=param.annotation
            )
        formatted_params.append(str(param))
    signature_end = "(" + ", ".join(formatted_params) + ")"

    if ismethod(function):
        signature_end = signature_end.replace("(self, ", "(")
        signature_end = signature_end.replace("(self)", "()")
    # work around case-specific bug: these enum reprs are not valid Python.
    # The `aggregation` default is a member of the aggregation enum, so it is
    # rendered as `tf.VariableAggregation.NONE` (it was previously rendered
    # with the synchronization enum's name).
    signature_end = signature_end.replace(
        "synchronization=<VariableSynchronization.AUTO: 0>, aggregation=<VariableAggregationV2.NONE: 0>",
        "synchronization=tf.VariableSynchronization.AUTO, aggregation=tf.VariableAggregation.NONE",
    )
    return signature_end
288
289
290
def get_function_signature(function, override=None):
    """Return the formatted signature of `function`.

    `override`, when not `None`, replaces the computed qualified name.
    """
    signature_start = get_signature_start(function) if override is None else override
    return format_signature(signature_start, get_signature_end(function))
297
298
299
def get_class_signature(cls, override=None):
    """Return the formatted constructor signature of `cls`.

    The name defaults to `module.ClassName` unless `override` is given; the
    parameters are taken from `cls.__init__`.
    """
    signature_start = f"{cls.__module__}.{cls.__name__}" if override is None else override
    return format_signature(signature_start, get_signature_end(cls.__init__))
306
307
308
def get_signature(object_, override):
    """Dispatch to the signature builder matching the kind of `object_`.

    Properties (objects exposing `fget`) render as `override` when one is
    given, otherwise as the signature of their getter.

    Raises:
        ValueError: If `object_` is not a class, function, method or property.
    """
    if inspect.isclass(object_):
        return get_class_signature(object_, override)
    if inspect.isfunction(object_) or inspect.ismethod(object_):
        return get_function_signature(object_, override)
    if not hasattr(object_, "fget"):
        raise ValueError(f"Not able to retrieve signature for object {object_}")
    # properties
    return override if override else get_function_signature(object_.fget)
319
320
321
def format_signature(signature_start: str, signature_end: str):
    """Pretty-format a signature so long ones wrap across several lines.

    The parameter list is embedded in a throwaway function definition whose
    name has the same length as `signature_start`, formatted with black at a
    line length of 90, and the resulting parameter list is re-attached to
    the real name.
    """
    # Same-width placeholder so black wraps exactly as it would the real name.
    placeholder_name = "x" * len(signature_start)
    dummy_source = f"def {placeholder_name + signature_end}:\n pass\n"
    formatted_source = black.format_str(dummy_source, mode=black.FileMode(line_length=90))
    return signature_start + extract_signature_end(formatted_source)
333
334
335
def extract_signature_end(function_definition):
    """Return the text from the first "(" through the last ")" inclusive."""
    first_open = function_definition.find("(")
    last_close = function_definition.rfind(")")
    return function_definition[first_open : last_close + 1]
339
340
341
def get_code_blocks(docstring):
    """Swap each fenced code block in `docstring` for a placeholder token.

    Returns:
        A `(code_blocks, docstring)` pair: a dict mapping each token to the
        fenced snippet it replaced, and the docstring with tokens in place.
    """
    fence = "```"
    extracted = {}
    remaining = docstring[:]
    while fence in remaining:
        # Jump to the next opening fence, then measure up to and including
        # the closing fence that follows it.
        remaining = remaining[remaining.find(fence) :]
        snippet_length = remaining[len(fence) :].find(fence) + 2 * len(fence)
        snippet = remaining[:snippet_length]
        # Place marker in docstring for later reinjection.
        # Print the index with 4 digits so we know the symbol is unique.
        token = f"$KERAS_AUTODOC_CODE_BLOCK_{len(extracted):04d}"
        docstring = docstring.replace(snippet, token)
        extracted[token] = snippet
        remaining = remaining[snippet_length:]
    return extracted, docstring
355
356
357
def get_section_end(docstring, section_start):
    """Return the index in `docstring` where the section at `section_start` ends.

    A section ends at the first place, from `section_start` on, where a
    non-whitespace character is followed by newline(s) and then either an
    unindented character or the end of the text. When the section runs to
    the end of the docstring, `len(docstring)` is returned; otherwise the
    returned index excludes the last newline before the following text.
    """
    regex_indented_sections_end = re.compile(r"\S\n+(\S|$)")
    end = regex_indented_sections_end.search(docstring[section_start:])
    if end is None:
        # No boundary found (e.g. the last line ends with trailing spaces):
        # the section extends to the end of the docstring.
        return len(docstring)
    section_end = section_start + end.end()
    if section_end == len(docstring):
        return section_end
    # Step back over the matched unindented character and one newline.
    return section_end - 2
365
366
367
def get_google_style_sections_without_code(docstring):
    """Swap each `# Title` section of `docstring` for a placeholder token.

    Sections are located by a line starting with "# " and extend to the
    index reported by `get_section_end`.

    Returns:
        A `(sections, docstring)` pair: a dict mapping each token to the
        section text it replaced, and the docstring with tokens in place.
    """
    regex_indented_sections_start = re.compile(r"\n# .+?\n")
    google_style_sections = {}
    for i in itertools.count():
        match = regex_indented_sections_start.search(docstring)
        if match is None:
            break
        section_start = match.start() + 1  # skip the newline before "# "
        section_end = get_section_end(docstring, section_start)
        google_style_section = docstring[section_start:section_end]
        # Zero-pad the index so no token is a prefix of another one
        # (`..._1` vs `..._10`); tokens are later substituted back with plain
        # `str.replace`, which would otherwise corrupt the longer token.
        token = f"KERAS_AUTODOC_GOOGLE_STYLE_SECTION_{i:04d}"
        google_style_sections[token] = google_style_section
        docstring = insert_in_string(docstring, token, section_start, section_end)
    return google_style_sections, docstring
381
382
383
def get_google_style_sections(docstring):
    """Extract the `# Title` sections of `docstring`, leaving tokens behind.

    Returns:
        A `(sections, docstring)` pair: a dict mapping placeholder tokens to
        section text, and the docstring with each section replaced by its
        token. Code blocks are intact in both.
    """
    # First, extract code blocks and process them.
    # The parsing is easier if the #, : and other symbols aren't there.
    code_blocks, docstring = get_code_blocks(docstring)
    google_style_sections, docstring = get_google_style_sections_without_code(docstring)
    docstring = reinject_strings(docstring, code_blocks)
    # One reinjection pass per section is enough; build a fresh dict rather
    # than assigning into the one being iterated.
    google_style_sections = {
        section_token: reinject_strings(section, code_blocks)
        for section_token, section in google_style_sections.items()
    }
    return google_style_sections, docstring
393
394
395
def to_markdown(google_style_section: str) -> str:
    """Render one `# Title` section as Markdown with a bold title.

    The title is the first line without its leading "# ". The body is
    dedented, and for argument-like sections it is turned into a bullet list.
    """
    title_end = google_style_section.find("\n")
    title = google_style_section[2:title_end]
    body = remove_indentation(google_style_section[title_end:])
    list_sections = ("Arguments", "Attributes", "Raises", "Call arguments", "Returns")
    if title in list_sections:
        body = format_as_markdown_list(body)
    if not body:
        return f"__{title}__\n"
    return f"__{title}__\n\n{body}\n"
412
413
414
def format_as_markdown_list(section_body):
    """Turn `name: description` entries into a Markdown bullet list.

    Every unindented `name:` at the start of the text or of a line becomes
    `- __name__:`. Continuation lines indented by four spaces are re-indented
    to two spaces so they nest under their bullet.
    """
    section_body = re.sub(r"\n([^ ].*?):", r"\n- __\1__:", section_body)
    section_body = re.sub(r"^([^ ].*?):", r"- __\1__:", section_body)
    # Switch to 2-space indent so we can render nested lists.
    # (Search and replacement were identical before, making this a no-op.)
    section_body = section_body.replace("\n    ", "\n  ")
    return section_body
420
421
422
def reinject_strings(target, strings_to_inject):
    """Replace every token found in `target` with the string it stands for.

    `strings_to_inject` maps tokens to replacements; they are applied one
    after another in the mapping's iteration order.
    """
    restored = target
    for token in strings_to_inject:
        restored = restored.replace(token, strings_to_inject[token])
    return restored
426
427
428
def process_docstring(docstring):
    """Convert the `# Title` sections of `docstring` to Markdown.

    A trailing newline is appended when missing, each section is extracted,
    rendered with `to_markdown`, and substituted back in place.
    """
    # `endswith` also covers the empty string, where indexing [-1] would raise.
    if not docstring.endswith("\n"):
        docstring += "\n"

    google_style_sections, docstring = get_google_style_sections(docstring)
    for token, google_style_section in google_style_sections.items():
        markdown_section = to_markdown(google_style_section)
        docstring = docstring.replace(token, markdown_section)
    return docstring
437
438
439
def get_class_from_method(meth):
    """Return the class that `meth` belongs to, or `None` if there is none.

    Bound methods are first searched for by identity in the MRO of their
    instance's class, then unwrapped to the underlying function. Functions
    are resolved by looking up the class part of their `__qualname__` in
    their module. Anything else falls back to `__objclass__`.
    """
    if inspect.ismethod(meth):
        bound_name = meth.__name__
        for klass in inspect.getmro(meth.__self__.__class__):
            if klass.__dict__.get(bound_name) is meth:
                return klass
        meth = meth.__func__  # fallback to __qualname__ parsing
    if inspect.isfunction(meth):
        outer_qualname = meth.__qualname__.split(".<locals>", 1)[0]
        owner_name = outer_qualname.rsplit(".", 1)[0]
        owner = getattr(inspect.getmodule(meth), owner_name, None)
        if isinstance(owner, type):
            return owner
    return getattr(meth, "__objclass__", None)  # handle special descriptor objects
451
452
453
def insert_in_string(target, string_to_insert, start, end):
    """Return `target` with the slice `[start:end]` replaced by `string_to_insert`."""
    return target[:start] + string_to_insert + target[end:]
457
458
459
def remove_indentation(string):
    """Strip the common leading whitespace from every line of `string`.

    The common margin is the smallest indentation among lines that contain
    text. Blank and whitespace-only lines are ignored when computing it, so a
    stray line of spaces cannot prevent the rest from being dedented. Leading
    and trailing empty lines are dropped from the result.
    """
    lines = string.split("\n")
    indents = [len(line) - len(line.lstrip()) for line in lines if line.strip()]
    if indents:
        margin = min(indents)
        string = "\n".join(line[margin:] for line in lines)
    return string.strip()  # Drop leading/closing empty lines
466
467
468
def count_leading_spaces(s):
    """Return the index of the first non-whitespace character of `s`.

    Returns 0 when `s` is empty or contains only whitespace.
    """
    first_visible = re.search(r"\S", s)
    return first_visible.start() if first_visible else 0
473
474