Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/scripts/render_presets.py
3273 views
1
"""Custom rendering code for keras_hub presets.
2
3
The model metadata is pulled from the library, each preset has a
4
metadata dictionary as follows:
5
6
{
7
'description': Description of the model,
8
'params': Parameter count of the model,
9
'official_name': Name of the model,
10
'path': Relative path of the model on keras.io,
11
}
12
"""
13
14
from hub_master import MODELS_MASTER
15
16
try:
    import keras_hub
except Exception as err:
    # KerasHub is optional for this script: render_table and render_tags
    # check for ``None`` and skip KerasHub-generated content instead of
    # failing outright.
    print(f"Could not import KerasHub. Exception: {err}")
    keras_hub = None
21
22
23
# Markdown header (column titles plus separator row) for the combined table
# of all presets; it includes a "Model API" column linking to the docs page.
TABLE_HEADER = (
    "Preset | Model API | Parameters | Description\n"
    + "-------|-----------|------------|------------\n"
)

# Markdown header for the table rendered on a single model's page, which
# omits the "Model API" column.
TABLE_HEADER_PER_MODEL = (
    "Preset | Parameters | Description\n"
    + "-------|------------|------------\n"
)
32
33
34
def format_param_count(metadata):
    """Format a parameter count for the table.

    Args:
        metadata: Preset metadata dict. Its ``"params"`` entry, when present,
            is the raw parameter count.

    Returns:
        A short string such as ``"1.50K"``, ``"110.00M"`` or ``"7.00B"``.
        Counts below one thousand are returned unscaled. ``"Unknown"`` is
        returned when the count is missing or ``None``.
    """
    try:
        count = metadata["params"]
    except KeyError:
        return "Unknown"
    if count is None:
        return "Unknown"
    # Largest unit first.
    units = ((1e9, "B"), (1e6, "M"), (1e3, "K"))
    for index, (scale, suffix) in enumerate(units):
        if count < scale:
            continue
        text = f"{count / scale:.2f}"
        # Two-decimal rounding can display "1000.00" of a unit
        # (e.g. 999_999 -> "1000.00K"); show the next larger unit instead.
        if index > 0 and float(text) >= 1000:
            larger_scale, larger_suffix = units[index - 1]
            return f"{count / larger_scale:.2f}{larger_suffix}"
        return f"{text}{suffix}"
    return f"{count}"
47
48
49
def format_path(metadata):
    """Return a markdown link to the API docs page for a preset's model.

    Scans ``MODELS_MASTER["children"]`` for the entry whose ``"path"``, with
    surrounding slashes stripped, equals ``metadata["path"]``. It links to
    ``/keras_hub/api/models/<path>`` using that entry's ``"title"`` as the
    link text. Returns ``"-"`` when no entry matches.
    """
    for entry in MODELS_MASTER["children"]:
        slug = entry["path"].strip("/")
        if slug != metadata["path"]:
            continue
        return f"[{entry['title']}](/keras_hub/api/models/{slug})"
    return "-"
58
59
60
def format_preset_link(preset, handle):
    """Return a markdown link from ``preset`` to its Kaggle model page.

    Every ``"kaggle://"`` occurrence in ``handle`` is rewritten to the
    ``https://www.kaggle.com/models/`` prefix to form the link target.
    """
    kaggle_prefix = "https://www.kaggle.com/models/"
    target = handle.replace("kaggle://", kaggle_prefix)
    return "[{}]({})".format(preset, target)
63
64
65
def is_base_class(symbol):
    """Return True if ``symbol`` is one of KerasHub's generic model classes.

    The generic classes are ``Backbone``, ``Tokenizer``, ``Preprocessor``
    and ``Task`` from ``keras_hub.models``. Requires ``keras_hub`` to have
    been imported successfully.
    """
    models = keras_hub.models
    base_classes = (
        models.Backbone,
        models.Tokenizer,
        models.Preprocessor,
        models.Task,
    )
    return symbol in base_classes
72
73
74
def sort_presets(presets):
    """Return preset names ordered by docs path, then by parameter count.

    Args:
        presets: Mapping of preset name to a dict whose ``"metadata"`` entry
            has a ``"path"`` key and, optionally, a ``"params"`` key.

    Returns:
        A list of preset names. Within one path, presets whose parameter
        count is missing or ``None`` come after those with a known count.
        Ties keep the mapping's iteration order, because ``sorted`` is stable.
    """

    def _sort_key(name):
        metadata = presets[name]["metadata"]
        count = metadata.get("params")
        # A missing count must not crash the whole table: format_param_count
        # already tolerates it (rendering "Unknown"), so sort such presets
        # last within their path instead of raising KeyError/TypeError.
        return (metadata["path"], count is None, 0 if count is None else count)

    return sorted(presets.keys(), key=_sort_key)
83
84
85
def render_row(preset, data, add_doc_link=False):
    """Render one markdown table row for a preset.

    Args:
        preset: Preset name, rendered as a link to its Kaggle page.
        data: Preset entry holding ``"metadata"`` and ``"kaggle_handle"``.
        add_doc_link: If True, include a column linking to the model's API
            docs page (used by the combined all-presets table).

    Returns:
        The ``" | "``-separated row, terminated by a newline.
    """
    metadata = data["metadata"]
    cols = [format_preset_link(preset, data["kaggle_handle"])]
    if add_doc_link:
        cols.append(format_path(metadata))
    cols.append(format_param_count(metadata))
    # A raw line break would end the markdown table row early, so fold a
    # multi-line description onto a single line.
    cols.append(" ".join(metadata["description"].splitlines()))
    return " | ".join(cols) + "\n"
97
98
99
def render_all_presets():
    """Build the combined markdown table of every backbone preset.

    Returns:
        ``TABLE_HEADER`` followed by one row for each preset in
        ``keras_hub.models.Backbone.presets``, in ``sort_presets`` order.
        Each row includes the API-docs link column.
    """
    backbone = keras_hub.models.Backbone
    rows = [
        render_row(name, backbone.presets[name], add_doc_link=True)
        for name in sort_presets(backbone.presets)
    ]
    return TABLE_HEADER + "".join(rows)
107
108
109
def render_table(symbol):
    """Render the presets table shown on a single model's docs page.

    Returns:
        ``""`` if KerasHub could not be imported. ``None`` if ``symbol`` is a
        generic base class or has no presets. Otherwise,
        ``TABLE_HEADER_PER_MODEL`` followed by one row per preset, in
        ``sort_presets`` order.
    """
    if keras_hub is None:
        return ""
    if is_base_class(symbol) or len(symbol.presets) == 0:
        return None
    body = "".join(
        render_row(name, symbol.presets[name])
        for name in sort_presets(symbol.presets)
    )
    return TABLE_HEADER_PER_MODEL + body
120
121
122
def render_tags(template):
    """Substitute KerasHub-specific tags in ``template``.

    Currently the only tag is ``{{presets_table}}``, which is replaced by
    the combined presets table. The template is returned unchanged when
    KerasHub could not be imported or the tag is absent.
    """
    tag = "{{presets_table}}"
    if keras_hub is not None and tag in template:
        template = template.replace(tag, render_all_presets())
    return template
130
131