Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/scripts/render_presets.py
7789 views
1
"""Custom rendering code for keras_hub presets.
2
3
The model metadata is pulled from the library; each preset has a
metadata dictionary as follows:

{
    'description': Description of the model,
    'params': Parameter count of the model,
    'official_name': Name of the model,
    'path': Relative path of the model on keras.io,
}
12
"""
13
14
from hub_master import MODELS_MASTER
15
16
try:
17
import keras_hub
18
except Exception as e:
19
print(f"Could not import KerasHub. Exception: {e}")
20
keras_hub = None
21
22
23
# Header rows of the table that lists all backbone presets; it carries a
# "Model API" column in addition to the per-model columns.
TABLE_HEADER = (
    "Preset | Model API | Parameters | Description\n"
    "-------|-----------|------------|------------\n"
)

# Header rows of the table rendered for a single model (no "Model API"
# column).
TABLE_HEADER_PER_MODEL = (
    "Preset | Parameters | Description\n"
    "-------|------------|------------\n"
)
31
32
33
def format_param_count(metadata):
    """Format a preset's parameter count for display in a table.

    Args:
        metadata: Preset metadata dictionary; the count is read from its
            `"params"` key.

    Returns:
        `"Unknown"` when the `"params"` key is missing or its value is
        `None`. Otherwise the count as a string: scaled to two decimals
        with a `B`, `M` or `K` suffix for counts of at least 1e9, 1e6 or
        1e3 respectively, or the unscaled count when it is below 1e3.
    """
    try:
        count = metadata["params"]
    except KeyError:
        return "Unknown"
    # A present-but-empty count cannot be compared against the thresholds
    # below; report it like a missing one instead of raising TypeError.
    if count is None:
        return "Unknown"
    # Thresholds are checked from largest to smallest so the first match
    # picks the biggest applicable unit.
    for scale, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if count >= scale:
            return f"{(count / scale):.2f}{suffix}"
    return f"{count}"
46
47
48
def format_path(metadata):
    """Return a markdown link to the docs page matching the preset's path.

    Each entry of `MODELS_MASTER["children"]` is compared, with the
    slashes stripped from both ends of its path, against
    `metadata["path"]`. The first match yields a link whose text is the
    entry's title; `"-"` is returned when nothing matches.
    """
    for child in MODELS_MASTER["children"]:
        path = child["path"].strip("/")
        if metadata["path"] != path:
            continue
        text = child["title"]
        link = f"/keras_hub/api/models/{path}"
        return f"[{text}]({link})"
    return "-"
57
58
59
def format_preset_link(preset, handle):
    """Return a markdown link with `preset` as text and `handle` as target.

    Every `kaggle://` occurrence in `handle` is replaced by
    `https://www.kaggle.com/models/` to build the link target; a handle
    without that prefix is used unchanged.
    """
    kaggle_scheme = "kaggle://"
    kaggle_models_url = "https://www.kaggle.com/models/"
    url = handle.replace(kaggle_scheme, kaggle_models_url)
    return f"[{preset}]({url})"
62
63
64
def is_base_class(symbol):
    """Return whether `symbol` is one of the `keras_hub.models` base classes.

    The classes checked are `Backbone`, `Tokenizer`, `Preprocessor` and
    `Task`.
    """
    models = keras_hub.models
    base_classes = (
        models.Backbone,
        models.Tokenizer,
        models.Preprocessor,
        models.Task,
    )
    return symbol in base_classes
71
72
73
def sort_presets(presets):
    """Return the preset names ordered by path, then by parameter count.

    Args:
        presets: Mapping from preset name to preset data. Each value holds
            a `"metadata"` dict with a `"path"` entry and, normally, a
            `"params"` entry.

    Returns:
        A list of the preset names in sorted order. Presets whose
        parameter count is missing or `None` are ordered as if it were 0.

    Raises:
        KeyError: If a preset has no `"metadata"` or its metadata has no
            `"path"`.
    """

    def sort_key(name):
        metadata = presets[name]["metadata"]
        # `format_param_count` tolerates metadata without a count, so
        # sorting must not raise on it either.
        params = metadata.get("params")
        return metadata["path"], 0 if params is None else params

    return sorted(presets, key=sort_key)
82
83
84
def render_row(preset, data, add_doc_link=False):
    """Renders a row for a preset in a markdown table.

    Args:
        preset: Preset name, used as the text of the link in the first
            column.
        data: Preset data. `data["kaggle_handle"]` supplies the link
            target and `data["metadata"]` the remaining columns.
        add_doc_link: When true, insert the column produced by
            `format_path` right after the preset link.

    Returns:
        The columns joined with `" | "` and terminated by a newline.
    """
    metadata = data["metadata"]
    # The handle-to-URL conversion happens inside `format_preset_link`, so
    # no separate URL needs to be computed here.
    cols = [format_preset_link(preset, data["kaggle_handle"])]
    if add_doc_link:
        cols.append(format_path(metadata))
    cols.append(format_param_count(metadata))
    cols.append(metadata["description"])
    return " | ".join(cols) + "\n"
96
97
98
def render_all_presets():
    """Renders the markdown table for backbone presets as a string.

    The table starts with `TABLE_HEADER` and has one row per entry of
    `keras_hub.models.Backbone.presets`, in `sort_presets` order, each
    rendered with the docs-link column enabled.
    """
    backbone = keras_hub.models.Backbone
    rows = [
        render_row(name, backbone.presets[name], add_doc_link=True)
        for name in sort_presets(backbone.presets)
    ]
    return TABLE_HEADER + "".join(rows)
106
107
108
def render_table(symbol):
    """Render the markdown presets table for a single symbol.

    Returns:
        An empty string when KerasHub could not be imported, `None` when
        `symbol` is a base class or has no presets, and otherwise
        `TABLE_HEADER_PER_MODEL` followed by one row per preset in
        `sort_presets` order.
    """
    if keras_hub is None:
        return ""

    if is_base_class(symbol) or len(symbol.presets) == 0:
        return None

    rows = [
        render_row(name, symbol.presets[name])
        for name in sort_presets(symbol.presets)
    ]
    return TABLE_HEADER_PER_MODEL + "".join(rows)
119
120
121
def render_tags(template):
    """Replaces all custom KerasHub tags with rendered content.

    The only tag handled is `{{presets_table}}`, which is replaced by the
    output of `render_all_presets`. The template is returned unchanged
    when KerasHub could not be imported or the tag is absent.
    """
    tag = "{{presets_table}}"
    if keras_hub is not None and tag in template:
        return template.replace(tag, render_all_presets())
    return template
129
130