"""Custom rendering code for keras_hub presets.
The model metadata is pulled from the library, each preset has a
metadata dictionary as follows:
{
'description': Description of the model,
'params': Parameter count of the model,
'official_name': Name of the model,
'path': Relative path of the model on keras.io,
}
"""
from hub_master import MODELS_MASTER
try:
import keras_hub
except Exception as e:
print(f"Could not import KerasHub. Exception: {e}")
keras_hub = None
# Header of the combined table listing every backbone preset. Compared with
# the per-model header it has an extra "Model API" column, which holds a link
# to the model's API docs page.
TABLE_HEADER = "\n".join(
    [
        "Preset | Model API | Parameters | Description",
        "-------|-----------|------------|------------",
        "",
    ]
)

# Header of the presets table rendered for a single model class.
TABLE_HEADER_PER_MODEL = "\n".join(
    [
        "Preset | Parameters | Description",
        "-------|------------|------------",
        "",
    ]
)
def format_param_count(metadata):
    """Format a preset's parameter count for the presets table.

    Args:
        metadata: Preset metadata dict; the count is read from its
            ``"params"`` entry.

    Returns:
        The count scaled to billions, millions or thousands with two
        decimals and a ``B``/``M``/``K`` suffix (e.g. ``"109.48M"``), the raw
        count for values below one thousand, or ``"Unknown"`` when the entry
        is missing or ``None``.
    """
    count = metadata.get("params")
    if count is None:
        return "Unknown"
    units = ((1e9, "B"), (1e6, "M"), (1e3, "K"))
    for index, (scale, suffix) in enumerate(units):
        if count < scale:
            continue
        text = f"{count / scale:.2f}"
        # Rounding can push a value just below the next unit up to
        # "1000.00" (e.g. 999_999 -> "1000.00K"); show it in the larger
        # unit instead ("1.00M").
        if text == "1000.00" and index > 0:
            scale, suffix = units[index - 1]
            text = f"{count / scale:.2f}"
        return f"{text}{suffix}"
    return f"{count}"
def format_path(metadata):
    """Return a markdown link to the API docs page of the preset's model.

    ``metadata["path"]`` is compared with the ``path`` of each entry in
    ``MODELS_MASTER["children"]`` (after stripping surrounding slashes). The
    first matching entry yields ``[<title>](/keras_hub/api/models/<path>)``.
    Returns ``"-"`` when no entry matches.
    """
    for entry in MODELS_MASTER["children"]:
        entry_path = entry["path"].strip("/")
        if metadata["path"] != entry_path:
            continue
        return "[{}](/keras_hub/api/models/{})".format(entry["title"], entry_path)
    return "-"
def format_preset_link(preset, handle):
    """Return a markdown link from a preset name to its Kaggle model page.

    Occurrences of ``kaggle://`` in ``handle`` are rewritten to
    ``https://www.kaggle.com/models/`` to build the link target.
    """
    kaggle_url = handle.replace("kaggle://", "https://www.kaggle.com/models/")
    return "[{}]({})".format(preset, kaggle_url)
def is_base_class(symbol):
    """Return True if ``symbol`` is one of KerasHub's generic base classes.

    The base classes are ``Backbone``, ``Tokenizer``, ``Preprocessor`` and
    ``Task`` from ``keras_hub.models``. ``render_table`` skips them.
    """
    models = keras_hub.models
    base_classes = (
        models.Backbone,
        models.Tokenizer,
        models.Preprocessor,
        models.Task,
    )
    return symbol in base_classes
def sort_presets(presets):
    """Return preset names ordered by model path, then by parameter count.

    Args:
        presets: Mapping of preset name to preset data; each value holds a
            ``"metadata"`` dict with ``"path"`` and ``"params"`` entries.

    Returns:
        A list of preset names. Presets whose metadata lacks ``"path"`` or
        ``"params"`` (or has it set to ``None``) are placed after the ones
        that have it, instead of making the sort raise ``KeyError`` or
        ``TypeError``. ``format_param_count`` already tolerates a missing
        ``"params"`` entry, so such presets should not break sorting either.
    """

    def sort_key(name):
        metadata = presets[name]["metadata"]
        path = metadata.get("path")
        params = metadata.get("params")
        # The leading booleans sort missing values last and ensure ``None`` is
        # never compared against a real str/int. For fully populated metadata
        # the order is the same as sorting by (path, params).
        return (path is None, path or "", params is None, params or 0)

    return sorted(presets.keys(), key=sort_key)
def render_row(preset, data, add_doc_link=False):
    """Render one preset as a row of a markdown table.

    Args:
        preset: Preset name, rendered as a link to its Kaggle page.
        data: Preset entry holding ``"metadata"`` and ``"kaggle_handle"``.
        add_doc_link: If True, add a column (after the preset link) that
            links to the model's API docs page.

    Returns:
        The row's cells joined with ``" | "`` and terminated by a newline.
    """
    metadata = data["metadata"]
    # The Kaggle URL is built by format_preset_link; no separate copy needed.
    cols = [format_preset_link(preset, data["kaggle_handle"])]
    if add_doc_link:
        cols.append(format_path(metadata))
    cols.append(format_param_count(metadata))
    cols.append(metadata["description"])
    return " | ".join(cols) + "\n"
def render_all_presets():
    """Render the markdown table of all backbone presets as a string.

    Each row includes a link to the model's API docs page. Like
    ``render_table``, this returns an empty string when KerasHub could not be
    imported, instead of failing on ``None.models``.
    """
    if keras_hub is None:
        return ""
    presets = keras_hub.models.Backbone.presets
    rows = [
        render_row(preset, presets[preset], add_doc_link=True)
        for preset in sort_presets(presets)
    ]
    return TABLE_HEADER + "".join(rows)
def render_table(symbol):
    """Render the markdown presets table for a single model class.

    Returns:
        ``""`` if KerasHub could not be imported; ``None`` if ``symbol`` is
        one of the generic base classes or has no presets; otherwise the
        per-model table header followed by one row per preset, sorted by
        ``sort_presets``.
    """
    if keras_hub is None:
        return ""
    # Base classes are checked first, so their ``presets`` is never read.
    if is_base_class(symbol) or not len(symbol.presets):
        return None
    pieces = [TABLE_HEADER_PER_MODEL]
    for name in sort_presets(symbol.presets):
        pieces.append(render_row(name, symbol.presets[name]))
    return "".join(pieces)
def render_tags(template):
    """Replace the custom KerasHub tags in ``template`` with rendered content.

    The only tag handled here is ``{{presets_table}}``, which is replaced by
    the table of all backbone presets. The template is returned unchanged if
    KerasHub could not be imported or the tag does not occur. The table is
    only rendered when the tag is present.
    """
    tag = "{{presets_table}}"
    if keras_hub is None or tag not in template:
        return template
    return template.replace(tag, render_all_presets())