"""
Generate Python WebDriver BiDi command modules from CDDL specification.
This generator reads CDDL (Concise Data Definition Language) specification files
and produces Python type definitions and command classes that conform to the
WebDriver BiDi protocol.
Usage:
python generate_bidi.py <cddl_file> <output_dir> <spec_version>
Example:
python generate_bidi.py local.cddl ./selenium/webdriver/common/bidi 1.0
"""
import argparse
import importlib.util
import logging
import re
import sys
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from textwrap import indent as tw_indent
from typing import Any
__version__ = "1.0.0"
# Logging is configured at import time (presumably so progress messages show
# when run as a script).
# NOTE(review): basicConfig at import also affects any program importing this
# module; consider moving it under the script entry point.
log_level = logging.INFO
logging.basicConfig(level=log_level)
logger = logging.getLogger("generate_bidi")
# Comment banner that heads every generated module (via _MODULE_HEADER_COMMENTS).
SHARED_HEADER = """# DO NOT EDIT THIS FILE!
#
# This file is generated from the WebDriver BiDi specification. If you need to make
# changes, edit the generator and regenerate all of the modules."""
# Per-module header: the shared banner plus the module name. The doubled braces
# survive the f-string as a literal "{}" placeholder, which
# CddlModule.generate_code fills with .format(self.name).
_MODULE_HEADER_COMMENTS = f"""{SHARED_HEADER}
#
# WebDriver BiDi module: {{}}
"""
# First import of every generated module; makes all annotations lazily evaluated.
_MODULE_HEADER_IMPORTS = "from __future__ import annotations\n\n"
def indent(s: str, n: int) -> str:
    """Prefix every line of *s* that is not whitespace-only with *n* spaces."""
    padding = " " * n
    return tw_indent(s, padding)
def _docstring_text(custom: str | None, fallback_name: str, fallback_desc: str = "") -> str:
"""Select the appropriate raw docstring text (no triple-quotes).
Priority: manifest custom string > CDDL description (if different from name) > 'ClassName.'
"""
if custom:
return custom.strip()
if fallback_desc and fallback_desc != fallback_name:
return fallback_desc
return f"{fallback_name}."
def _emit_docstring(text: str, indent_width: int) -> str:
r"""Produce a PEP 257-compliant docstring block with a trailing newline.
Single-line output: <indent>\"\"\"text.\"\"\"\n
Multi-line output:
<indent>\"\"\"
<indent>First line.
<indent>
<indent>Continuation.
<indent>\"\"\"
The opening and closing triple-quotes each occupy their own line for
multi-line strings so that inspect.getdoc() / Sphinx dedent cleanly.
"""
prefix = " " * indent_width
stripped = text.strip()
lines = stripped.splitlines()
if len(lines) <= 1:
return f'{prefix}"""{stripped}"""\n'
content = "\n".join(f"{prefix}{line}" if line.strip() else "" for line in lines)
return f'{prefix}"""\n{content}\n{prefix}"""\n'
def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]:
"""Load enhancement manifest from a Python file.
Args:
manifest_path: Path to Python file containing ENHANCEMENTS dict
Returns:
Dictionary with enhancement rules, or empty dict if no manifest provided
"""
if not manifest_path:
return {}
manifest_file = Path(manifest_path)
if not manifest_file.exists():
logger.warning(f"Enhancement manifest not found: {manifest_path}")
return {}
try:
spec = importlib.util.spec_from_file_location("bidi_enhancements", manifest_file)
if spec is None or spec.loader is None:
logger.warning(f"Could not load manifest: {manifest_path}")
return {}
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
enhancements = getattr(module, "ENHANCEMENTS", {})
dataclass_methods = getattr(module, "DATACLASS_METHOD_TEMPLATES", {})
method_docstrings = getattr(module, "DATACLASS_METHOD_DOCSTRINGS", {})
logger.info(f"Loaded enhancement manifest from: {manifest_path}")
logger.debug(f"Enhancements for modules: {list(enhancements.keys())}")
return {
"enhancements": enhancements,
"dataclass_methods": dataclass_methods,
"method_docstrings": method_docstrings,
}
except Exception as e:
logger.error(f"Failed to load enhancement manifest: {e}", exc_info=True)
return {}
class CddlType(Enum):
    """CDDL primitive type names mapped to Python type annotations.

    Several CDDL names share one Python type (tstr/text -> str and
    uint/int/nint -> int). Enum turns later members with an equal value into
    aliases of the first one, and iterating the class skips aliases. Lookups
    therefore go through ``__members__``, which includes them.
    """

    TSTR = "str"
    TEXT = "str"
    UINT = "int"
    INT = "int"
    NINT = "int"
    BOOL = "bool"
    NULL = "None"
    ANY = "Any"

    @classmethod
    def get_annotation(cls, cddl_type: str) -> str:
        """Get the Python type annotation for a CDDL type expression.

        Args:
            cddl_type: A CDDL type name (case-insensitive), an array such as
                ``[+ tstr]`` / ``[* tstr]``, or an inline map ``{...}``.

        Returns:
            The annotation string. Unknown names map to ``"Any"``, arrays to
            ``list[<inner>]`` and inline maps to ``dict[str, Any]``.
        """
        cddl_type = cddl_type.strip().lower()
        # __members__ (unlike iterating cls) also yields alias members such as
        # TEXT, INT and NINT.
        for member_name, member in cls.__members__.items():
            if member_name.lower() == cddl_type:
                return member.value
        if cddl_type.startswith("["):
            # "[+ T]" (one or more) and "[* T]" (zero or more) both become list[T].
            inner = cddl_type.strip("[]+* ")
            inner_type = cls.get_annotation(inner)
            return f"list[{inner_type}]"
        if cddl_type.startswith("{"):
            return "dict[str, Any]"
        return "Any"
@dataclass
class CddlCommand:
"""Represents a CDDL command definition."""
module: str
name: str
params: dict[str, str] = field(default_factory=dict)
required_params: set[str] = field(default_factory=set)
result: str | None = None
description: str = ""
def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str:
"""Generate Python method code for this command.
Args:
enhancements: Dictionary with enhancement rules for this method
"""
enhancements = enhancements or {}
method_name = self._camel_to_snake(self.name)
params_to_use = self.params
if "params_override" in enhancements:
params_to_use = enhancements["params_override"]
param_strs = []
param_names = []
for param_name, param_type in params_to_use.items():
if param_type in ["bool", "str", "int"]:
python_type = param_type
else:
python_type = CddlType.get_annotation(param_type)
snake_param = self._camel_to_snake(param_name)
param_names.append((param_name, snake_param))
param_strs.append(f"{snake_param}: {python_type} | None = None")
if param_strs:
single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):"
if len(single_line_signature) > 120:
body = f" def {method_name}(\n"
body += " self,\n"
for i, param_str in enumerate(param_strs):
if i < len(param_strs) - 1:
body += f" {param_str},\n"
else:
body += f" {param_str},\n"
body += " ):\n"
else:
param_list = "self, " + ", ".join(param_strs)
body = f" def {method_name}({param_list}):\n"
else:
body = f" def {method_name}(self):\n"
docstring = enhancements.get("docstring") or self.description or f"Execute {self.module}.{self.name}."
body += _emit_docstring(docstring, 8)
if self.required_params:
method_snake = self._camel_to_snake(self.name)
for param_name, snake_param in param_names:
if param_name in self.required_params:
body += f" if {snake_param} is None:\n"
msg = f"{method_snake}() missing required argument:"
error_message = f"{msg} {snake_param!r}"
body += f" raise TypeError({error_message!r})\n"
body += "\n"
if "validate" in enhancements:
validate_func = enhancements["validate"]
param_args = ", ".join(f"{snake}={snake}" for _, snake in param_names)
body += f" {validate_func}({param_args})\n"
body += "\n"
if "transform" in enhancements:
transform_spec = enhancements["transform"]
if isinstance(transform_spec, dict):
transform_func = transform_spec.get("func")
result_param = transform_spec.get("result_param", "params")
input_params = [
transform_spec.get(k) for k in ["allowed", "destination_folder"] if transform_spec.get(k)
]
if transform_func and result_param:
body += f" {result_param} = None\n"
param_args = ", ".join(input_params)
body += f" {result_param} = {transform_func}({param_args})\n"
body += "\n"
else:
transform_func = transform_spec
if self.name == "setDownloadBehavior":
body += " download_behavior = None\n"
body += f" download_behavior = {transform_func}(allowed, destination_folder)\n"
body += "\n"
if "preprocess" in enhancements:
preprocess_rules = enhancements["preprocess"]
for param_name, preprocess_type in preprocess_rules.items():
snake_param = self._camel_to_snake(param_name)
if preprocess_type == "check_serialize_method":
body += f" if {snake_param} and hasattr({snake_param}, 'to_bidi_dict'):\n"
body += f" {snake_param} = {snake_param}.to_bidi_dict()\n"
body += "\n"
body += " params = {\n"
if "transform" in enhancements and isinstance(enhancements["transform"], dict):
transform_spec = enhancements["transform"]
result_param = transform_spec.get("result_param")
if result_param == "download_behavior":
body += ' "downloadBehavior": download_behavior,\n'
for cddl_param_name in self.params:
if cddl_param_name not in ["downloadBehavior"]:
snake_name = self._camel_to_snake(cddl_param_name)
body += f' "{cddl_param_name}": {snake_name},\n'
else:
for param_name, snake_param in param_names:
body += f' "{param_name}": {snake_param},\n'
body += " }\n"
body += " params = {k: v for k, v in params.items() if v is not None}\n"
body += f' cmd = command_builder("{self.module}.{self.name}", params)\n'
body += " result = self._conn.execute(cmd)\n"
if "extract_field" in enhancements:
extract_field = enhancements["extract_field"]
extract_property = enhancements.get("extract_property")
deserialize_rules = enhancements.get("deserialize", {})
if extract_property:
body += f' if result and "{extract_field}" in result:\n'
body += f' items = result.get("{extract_field}", [])\n'
body += " return [\n"
body += f' item.get("{extract_property}")\n'
body += " for item in items\n"
body += " if isinstance(item, dict)\n"
body += " ]\n"
body += " return []\n"
elif extract_field in deserialize_rules:
type_name = deserialize_rules[extract_field]
body += f' if result and "{extract_field}" in result:\n'
body += f' items = result.get("{extract_field}", [])\n'
body += " return [\n"
body += f" {type_name}(\n"
body += self._generate_field_args(extract_field, type_name)
body += " )\n"
body += " for item in items\n"
body += " if isinstance(item, dict)\n"
body += " ]\n"
body += " return []\n"
else:
body += f' if result and "{extract_field}" in result:\n'
body += f' extracted = result.get("{extract_field}")\n'
body += " return extracted\n"
body += " return result\n"
elif "deserialize" in enhancements:
deserialize_rules = enhancements["deserialize"]
for response_field, type_name in deserialize_rules.items():
body += f' if result and "{response_field}" in result:\n'
body += f' items = result.get("{response_field}", [])\n'
body += " return [\n"
body += f" {type_name}(\n"
body += self._generate_field_args(response_field, type_name)
body += " )\n"
body += " for item in items\n"
body += " if isinstance(item, dict)\n"
body += " ]\n"
body += " return []\n"
else:
body += " return result\n"
return body
def _generate_field_args(self, response_field: str, type_name: str) -> str:
"""Generate constructor arguments for deserializing response objects.
For now, this handles ClientWindowInfo and Info specifically.
Could be extended to be more generic.
"""
if type_name == "ClientWindowInfo":
return (
' active=item.get("active"),\n'
' client_window=item.get("clientWindow"),\n'
' height=item.get("height"),\n'
' state=item.get("state"),\n'
' width=item.get("width"),\n'
' x=item.get("x"),\n'
' y=item.get("y")\n'
)
elif type_name == "Info":
return (
' children=_deserialize_info_list(item.get("children", [])),\n'
' client_window=item.get("clientWindow"),\n'
' context=item.get("context"),\n'
' original_opener=item.get("originalOpener"),\n'
' url=item.get("url"),\n'
' user_context=item.get("userContext"),\n'
' parent=item.get("parent")\n'
)
return ""
@staticmethod
def _camel_to_snake(name: str) -> str:
"""Convert camelCase to snake_case."""
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
@dataclass
class CddlTypeDefinition:
"""Represents a CDDL type definition."""
module: str
name: str
fields: dict[str, str] = field(default_factory=dict)
description: str = ""
def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str:
"""Generate Python dataclass code for this type.
Args:
enhancements: Dictionary containing dataclass_methods and method_docstrings
"""
enhancements = enhancements or {}
dataclass_methods = enhancements.get("dataclass_methods", {})
method_docstrings = enhancements.get("method_docstrings", {})
class_name = self.name
code = "@dataclass\n"
code += f"class {class_name}:\n"
class_docstrings = enhancements.get("class_docstrings", {})
class_doc = _docstring_text(class_docstrings.get(class_name), class_name, self.description)
code += _emit_docstring(class_doc, 4)
code += "\n"
if not self.fields:
code += " pass\n"
else:
for field_name, field_type in self.fields.items():
python_type = self._get_python_type(field_type)
snake_name = CddlCommand._camel_to_snake(field_name)
literal_match = re.match(r'^"([^"]+)"$', field_type.strip())
if literal_match:
literal_value = literal_match.group(1)
code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n'
elif python_type.startswith("list["):
type_annotation = python_type.replace(" | None", "")
code += f" {snake_name}: {type_annotation} = field(default_factory=list)\n"
elif python_type.startswith("dict["):
type_annotation = python_type.replace(" | None", "")
code += f" {snake_name}: {type_annotation} = field(default_factory=dict)\n"
else:
code += f" {snake_name}: {python_type} = None\n"
if class_name in dataclass_methods:
code += "\n"
methods_dict = dataclass_methods[class_name]
docstrings_dict = method_docstrings.get(class_name, {})
for method_name in methods_dict:
method_impl = methods_dict[method_name]
docstring = docstrings_dict.get(method_name, "")
code += f" def {method_name}(self):\n"
if docstring:
code += f' """{docstring}"""\n'
code += f" {method_impl}\n"
code += "\n"
return code
@staticmethod
def _get_python_type(cddl_type: str) -> str:
"""Convert CDDL type to Python type annotation using Python 3.10+ syntax."""
cddl_type = cddl_type.strip().lower()
type_mapping = {
"tstr": "str",
"text": "str",
"uint": "int",
"int": "int",
"nint": "int",
"bool": "bool",
"null": "None",
}
for cddl, python in type_mapping.items():
if cddl_type == cddl:
return f"{python} | None"
if cddl_type.startswith("["):
inner = cddl_type.strip("[]+ ")
inner_type = CddlTypeDefinition._get_python_type(inner)
if " | None" in inner_type:
inner_base = inner_type.replace(" | None", "")
return f"list[{inner_base} | None] | None"
return f"list[{inner_type}] | None"
if cddl_type.startswith("{"):
return "dict[str, Any] | None"
return "Any | None"
@dataclass
class CddlEnum:
"""Represents a CDDL enum definition (string union)."""
module: str
name: str
values: list[str] = field(default_factory=list)
description: str = ""
def to_python_class(self, enhancements: dict[str, Any] | None = None) -> str:
"""Generate Python enum class code.
Generates a simple class with string constants to match the existing
pattern in the codebase (e.g., ClientWindowState).
"""
enhancements = enhancements or {}
class_name = self.name
class_docstrings = enhancements.get("class_docstrings", {})
class_doc = _docstring_text(class_docstrings.get(class_name), class_name, self.description)
code = f"class {class_name}:\n"
code += _emit_docstring(class_doc, 4)
code += "\n"
for value in self.values:
const_name = self._value_to_const_name(value)
code += f' {const_name} = "{value}"\n'
return code
@staticmethod
def _value_to_const_name(value: str) -> str:
"""Convert enum string value to constant name.
Examples:
"none" -> "NONE"
"portrait-primary" -> "PORTRAIT_PRIMARY"
"interactive" -> "INTERACTIVE"
"""
const_name = value.replace("-", "_")
return const_name.upper()
@dataclass
class CddlEvent:
    """Represents a CDDL event definition (a message sent by the browser)."""

    module: str
    name: str
    method: str
    params_type: str
    description: str = ""

    def to_python_dataclass(self) -> str:
        """Render an alias that binds this event's name to its params type.

        The emitted line looks up the unqualified params type with
        ``globals().get`` in the generated module and falls back to ``dict``
        when that type is absent. ``NavigationInfo`` is emitted as
        ``BaseNavigationInfo``.

        Returns:
            A ``# Event: <method>`` comment line followed by the alias line.
        """
        target = self.params_type.rpartition(".")[2]
        if target == "NavigationInfo":
            target = "BaseNavigationInfo"
        header = f"# Event: {self.method}\n"
        alias = f"{self.name} = globals().get('{target}', dict) # Fallback to dict if type not defined\n"
        return header + alias
@dataclass
class CddlModule:
    """Represents a CDDL module (e.g., script, network, browsingContext).

    Holds the commands, map types, enums and events the parser attributed to
    one BiDi module, and renders them into the source of one Python module.
    """

    name: str
    commands: list[CddlCommand] = field(default_factory=list)
    types: list[CddlTypeDefinition] = field(default_factory=list)
    enums: list[CddlEnum] = field(default_factory=list)
    events: list[CddlEvent] = field(default_factory=list)

    @staticmethod
    def _convert_method_to_event_name(method_suffix: str) -> str:
        """Convert a BiDi method suffix to a snake_case event name.

        Examples:
            "contextCreated" -> "context_created"
            "navigationStarted" -> "navigation_started"
            "userPromptOpened" -> "user_prompt_opened"
        """
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()

    def _event_name_for(self, event_def: CddlEvent) -> str | None:
        """Return the snake_case key of a ``module.eventName`` event, or None for other shapes."""
        method_parts = event_def.method.split(".")
        if len(method_parts) != 2:
            return None
        return self._convert_method_to_event_name(method_parts[1])

    def generate_code(self, enhancements: dict[str, Any] | None = None) -> str:
        """Generate the Python source for this module.

        Sections are emitted in this order:
        1. header comments, optional module docstring and imports;
        2. manifest-requested helper functions;
        3. enums, aliases, dataclasses and extra manifest types;
        4. ``EVENT_NAME_MAPPING`` (modules with events only);
        5. the ``Info`` deserializer (browsingContext only);
        6. the module class;
        7. event aliases and the ``EVENT_CONFIGS`` registry (modules with
           events only).

        Args:
            enhancements: Module-level enhancement rules. Besides per-command
                entries keyed by snake_case method name, this reads
                ``module_docstring``, ``class_docstrings``,
                ``command_docstrings``, ``aliases``, ``exclude_types``,
                ``exclude_methods``, ``extra_dataclasses``,
                ``extra_type_aliases``, ``extra_methods``, ``extra_init_code``,
                ``extra_events`` and ``event_type_aliases``.

        Returns:
            The complete module source text.
        """
        enhancements = enhancements or {}
        code = self._render_header(enhancements)
        code += self._render_helper_functions(enhancements)
        code += self._render_type_definitions(enhancements)
        if self.events:
            code += self._render_event_name_mapping(enhancements)
        if self.name == "browsingContext":
            code += self._render_info_deserializer()
        class_name = module_name_to_class_name(self.name)
        code += self._render_module_class(class_name, enhancements)
        if self.events:
            code += self._render_event_registrations(class_name, enhancements)
        return self._ensure_field_import(code)

    def _render_header(self, enhancements: dict[str, Any]) -> str:
        """Render the header comments, optional module docstring and import lines."""
        code = _MODULE_HEADER_COMMENTS.format(self.name)
        module_docstring = enhancements.get("module_docstring", "")
        if module_docstring:
            code += _emit_docstring(module_docstring, 0) + "\n"
        code += _MODULE_HEADER_IMPORTS

        stdlib_imports = []
        if self.events:
            stdlib_imports.append("from collections.abc import Callable")
        if self.commands or self.types or self.events:
            stdlib_imports.append("from dataclasses import dataclass")
        stdlib_imports.append("from typing import Any")

        local_imports = []
        if self.commands:
            local_imports.append("from selenium.webdriver.common.bidi.common import command_builder")
        if self.events:
            local_imports.append(
                "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager"
            )

        code += "\n".join(stdlib_imports) + "\n"
        if local_imports:
            code += "\n" + "\n".join(local_imports) + "\n"
        code += "\n"
        return code

    def _render_helper_functions(self, enhancements: dict[str, Any]) -> str:
        """Render the module-level validate/transform helpers that commands reference."""
        helper_funcs_to_add: set[tuple[str, str]] = set()
        for cmd in self.commands:
            method_enhancements = enhancements.get(cmd._camel_to_snake(cmd.name), {})
            if "validate" in method_enhancements:
                helper_funcs_to_add.add(("validate", method_enhancements["validate"]))
            transform_spec = method_enhancements.get("transform")
            if isinstance(transform_spec, dict) and "func" in transform_spec:
                helper_funcs_to_add.add(("transform", transform_spec["func"]))

        code = ""
        # Only these two helpers have templates; other names are skipped.
        for func_type, func_name in sorted(helper_funcs_to_add):
            if func_type == "validate" and func_name == "validate_download_behavior":
                code += """def validate_download_behavior(
allowed: bool | None,
destination_folder: str | None,
user_contexts: Any | None = None,
) -> None:
\"\"\"Validate download behavior parameters.
Args:
allowed: Whether downloads are allowed
destination_folder: Destination folder for downloads
user_contexts: Optional list of user contexts
Raises:
ValueError: If parameters are invalid
\"\"\"
if allowed is True and not destination_folder:
raise ValueError("destination_folder is required when allowed=True")
if allowed is False and destination_folder:
raise ValueError("destination_folder should not be provided when allowed=False")
"""
            elif func_type == "transform" and func_name == "transform_download_params":
                code += """def transform_download_params(
allowed: bool | None,
destination_folder: str | None,
) -> dict[str, Any] | None:
\"\"\"Transform download parameters into download_behavior object.
Args:
allowed: Whether downloads are allowed
destination_folder: Destination folder for downloads (accepts str or
pathlib.Path; will be coerced to str)
Returns:
Dictionary representing the download_behavior object, or None if allowed is None
\"\"\"
if allowed is True:
return {
"type": "allowed",
# Coerce pathlib.Path (or any path-like) to str so the BiDi
# protocol always receives a plain JSON string.
"destinationFolder": str(destination_folder) if destination_folder is not None else None,
}
elif allowed is False:
return {"type": "denied"}
else: # None — reset to browser default (sent as JSON null)
return None
"""
        return code

    def _render_type_definitions(self, enhancements: dict[str, Any]) -> str:
        """Render enums, aliases, dataclasses, extra manifest classes and extra type aliases."""
        exclude_types = set(enhancements.get("exclude_types", []))
        extra_dataclasses = enhancements.get("extra_dataclasses", [])
        # A hand-written manifest class replaces the generated one of the same name.
        for extra_cls in extra_dataclasses:
            match = re.search(r"class\s+(\w+)\s*:", extra_cls)
            if match:
                exclude_types.add(match.group(1))

        code = ""
        for enum_def in self.enums:
            if enum_def.name in exclude_types:
                continue
            code += enum_def.to_python_class(enhancements)
            code += "\n\n"
        for alias, target in enhancements.get("aliases", {}).items():
            code += f"{alias} = {target}\n\n"
        for type_def in self.types:
            if type_def.name in exclude_types:
                continue
            code += type_def.to_python_dataclass(enhancements)
            code += "\n\n"
        for extra_cls in extra_dataclasses:
            code += extra_cls
            code += "\n\n"
        for extra_alias in enhancements.get("extra_type_aliases", []):
            code += extra_alias
            code += "\n\n"
        return code

    def _render_event_name_mapping(self, enhancements: dict[str, Any]) -> str:
        """Render EVENT_NAME_MAPPING from snake_case event keys to BiDi method names."""
        code = "# BiDi Event Name to Parameter Type Mapping\n"
        code += "EVENT_NAME_MAPPING = {\n"
        for event_def in self.events:
            event_name = self._event_name_for(event_def)
            if event_name is not None:
                code += f' "{event_name}": "{event_def.method}",\n'
        for extra_evt in enhancements.get("extra_events", []):
            code += f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n'
        code += "}\n\n"
        return code

    @staticmethod
    def _render_info_deserializer() -> str:
        """Render the recursive _deserialize_info_list helper used by browsingContext."""
        code = """def _deserialize_info_list(items: list) -> list | None:
\"\"\"Recursively deserialize a list of dicts to Info objects.
Args:
items: List of dicts from the API response
Returns:
List of Info objects with properly nested children, or None if empty
\"\"\"
if not items or not isinstance(items, list):
return None
result = []
for item in items:
if isinstance(item, dict):
# Recursively deserialize children only if the key exists in response
children_list = None
if "children" in item:
children_list = _deserialize_info_list(item.get("children", []))
info = Info(
children=children_list,
client_window=item.get("clientWindow"),
context=item.get("context"),
original_opener=item.get("originalOpener"),
url=item.get("url"),
user_context=item.get("userContext"),
parent=item.get("parent"),
)
result.append(info)
return result if result else None
"""
        code += "\n\n"
        return code

    def _render_module_class(self, class_name: str, enhancements: dict[str, Any]) -> str:
        """Render the module class: docstring, __init__, command methods, extras and event API."""
        class_docstrings = enhancements.get("class_docstrings", {})
        module_class_doc = _docstring_text(
            class_docstrings.get(class_name),
            class_name,
            f"WebDriver BiDi {self.name} module.",
        )
        code = f"class {class_name}:\n"
        code += _emit_docstring(module_class_doc, 4)
        code += "\n"
        if self.events:
            code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n"
        # The script module's constructor additionally accepts a driver.
        if self.name == "script":
            code += " def __init__(self, conn, driver=None) -> None:\n"
            code += " self._conn = conn\n"
            code += " self._driver = driver\n"
        else:
            code += " def __init__(self, conn) -> None:\n"
            code += " self._conn = conn\n"
        if self.events:
            code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n"
        for init_line in enhancements.get("extra_init_code", []):
            code += f" {init_line}\n"
        code += "\n"

        # Generated commands are skipped when excluded explicitly or when a
        # hand-written extra method with the same name replaces them.
        extra_methods = enhancements.get("extra_methods", [])
        exclude_methods = set(enhancements.get("exclude_methods", []))
        for extra_method in extra_methods:
            match = re.search(r"def\s+(\w+)\s*\(", extra_method)
            if match:
                exclude_methods.add(match.group(1))

        if self.commands:
            command_docstrings = enhancements.get("command_docstrings", {})
            for command in self.commands:
                method_name_snake = command._camel_to_snake(command.name)
                if method_name_snake in exclude_methods:
                    continue
                method_enhancements = enhancements.get(method_name_snake, {})
                if method_name_snake in command_docstrings and "docstring" not in method_enhancements:
                    method_enhancements = {**method_enhancements, "docstring": command_docstrings[method_name_snake]}
                code += command.to_python_method(method_enhancements)
                code += "\n"
        elif not self.events and not extra_methods:
            code += " pass\n"
        for extra_method in extra_methods:
            code += extra_method
            code += "\n"
        if self.events:
            code += """
def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int:
\"\"\"Add an event handler.
Args:
event: The event to subscribe to.
callback: The callback function to execute on event.
contexts: The context IDs to subscribe to (optional).
Returns:
The callback ID.
\"\"\"
return self._event_manager.add_event_handler(event, callback, contexts)
def remove_event_handler(self, event: str, callback_id: int) -> None:
\"\"\"Remove an event handler.
Args:
event: The event to unsubscribe from.
callback_id: The callback ID.
\"\"\"
return self._event_manager.remove_event_handler(event, callback_id)
def clear_event_handlers(self) -> None:
\"\"\"Clear all event handlers.\"\"\"
return self._event_manager.clear_event_handlers()
"""
        return code

    def _render_event_registrations(self, class_name: str, enhancements: dict[str, Any]) -> str:
        """Render event info aliases and the code that fills <class_name>.EVENT_CONFIGS."""
        code = "\n# Event Info Type Aliases\n"
        event_type_aliases = enhancements.get("event_type_aliases", {})
        for event_def in self.events:
            event_name = self._event_name_for(event_def)
            if event_name is None:
                continue
            if event_name in event_type_aliases:
                code += f"# Event: {event_def.method}\n"
                code += f"{event_def.name} = {event_type_aliases[event_name]}\n"
            else:
                code += event_def.to_python_dataclass()
            code += "\n"

        code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n"
        code += "_globals = globals()\n"
        code += f"{class_name}.EVENT_CONFIGS = {{\n"
        for event_def in self.events:
            event_name = self._event_name_for(event_def)
            if event_name is None:
                continue
            getter = f'_globals.get("{event_def.name}", dict)'
            condition = f'_globals.get("{event_def.name}")'
            event_class = f"{getter} if {condition} else dict"
            single_line = f' "{event_name}": EventConfig("{event_name}", "{event_def.method}", {event_class}),'
            if len(single_line) > 120:
                code += f' "{event_name}": EventConfig(\n'
                code += f' "{event_name}",\n'
                code += f' "{event_def.method}",\n'
                code += f" {event_class},\n"
                code += " ),\n"
            else:
                code += single_line + "\n"
        for extra_evt in enhancements.get("extra_events", []):
            ek = extra_evt["event_key"]
            be = extra_evt["bidi_event"]
            ec = extra_evt["event_class"]
            code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n'
        code += "}\n"
        return code

    @staticmethod
    def _ensure_field_import(code: str) -> str:
        """Add ``field`` to the dataclasses import when the generated code calls ``field(``."""
        if "field(" not in code:
            return code
        dataclass_import_pattern = r"from dataclasses import dataclass\n"
        if re.search(dataclass_import_pattern, code):
            return re.sub(
                dataclass_import_pattern,
                "from dataclasses import dataclass, field\n",
                code,
                count=1,
            )
        if "from dataclasses import" not in code:
            return code.replace(
                "from typing import Any\n",
                "from dataclasses import field\nfrom typing import Any\n",
            )
        return code
class CddlParser:
"""Parse CDDL specification files."""
def __init__(self, cddl_path: str):
"""Initialize parser with CDDL file path."""
self.cddl_path = Path(cddl_path)
self.content = ""
self.modules: dict[str, CddlModule] = {}
self.definitions: dict[str, str] = {}
self.event_names: set[str] = set()
self._read_file()
def _read_file(self) -> None:
"""Read and preprocess CDDL file."""
if not self.cddl_path.exists():
raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}")
with open(self.cddl_path, encoding="utf-8") as f:
self.content = f.read()
logger.info(f"Loaded CDDL file: {self.cddl_path}")
def parse(self) -> dict[str, CddlModule]:
    """Parse the loaded CDDL text into modules keyed by module name.

    Comments are stripped and rules are split out first. The extraction passes
    then run in a fixed order: event names must be known before events and
    commands are separated. If nothing yields a module, a single empty module
    named after the CDDL file's stem is created and a warning is logged.
    """
    uncommented = self._remove_comments(self.content)
    self._extract_definitions(uncommented)
    passes = (
        self._extract_event_names,
        self._extract_types,
        self._extract_events,
        self._extract_commands,
    )
    for run_pass in passes:
        run_pass()
    if not self.modules:
        fallback_name = self.cddl_path.stem
        self.modules[fallback_name] = CddlModule(name=fallback_name)
        logger.warning(f"No modules found in CDDL, creating default: {fallback_name}")
    return self.modules
def _remove_comments(self, content: str) -> str:
"""Remove comments from CDDL content."""
lines = content.split("\n")
cleaned = []
for line in lines:
if ";" in line and not line.strip().startswith(";"):
line = line[: line.index(";")]
elif line.strip().startswith(";"):
continue
cleaned.append(line)
return "\n".join(cleaned)
def _extract_definitions(self, content: str) -> None:
"""Extract CDDL definitions (type definitions, commands, etc.)."""
pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)?\s*=|\Z)"
for match in re.finditer(pattern, content, re.DOTALL):
name = match.group(1).strip()
definition = match.group(2).strip()
self.definitions[name] = definition
logger.debug(f"Extracted definition: {name}")
def _extract_event_names(self) -> None:
"""Extract event names from event union definitions.
Event union definitions follow pattern:
module.ModuleEvent = (
module.EventName1 //
module.EventName2 //
...
)
"""
for def_name, def_content in self.definitions.items():
if "Event" in def_name and re.search(r"\w+\.\w+", def_content):
event_refs = re.findall(r"(\w+\.\w+)", def_content)
for event_ref in event_refs:
self.event_names.add(event_ref)
logger.debug(f"Identified event: {event_ref} (from {def_name})")
def _extract_types(self) -> None:
    """Register enum and map-type definitions with their modules.

    Only dotted ``module.TypeName`` definitions are considered, and bodies
    containing ``method:`` are skipped. A module entry is created for every
    remaining definition, even when it yields neither an enum with values
    nor a type with fields.
    """
    for def_name, def_content in self.definitions.items():
        # Undotted names carry no module; "method:" bodies are commands/events.
        if "." not in def_name or "method:" in def_content:
            continue
        module_name, type_name = def_name.rsplit(".", 1)
        if module_name not in self.modules:
            self.modules[module_name] = CddlModule(name=module_name)
        module = self.modules[module_name]

        if self._is_enum_definition(def_content):
            values = self._extract_enum_values(def_content)
            if values:
                enum_def = CddlEnum(
                    module=module_name,
                    name=type_name,
                    values=values,
                    description=f"{type_name}",
                )
                module.enums.append(enum_def)
                logger.debug(f"Found enum: {def_name} with {len(values)} values")
            # An enum-shaped definition is never also parsed as a map type.
            continue

        fields = self._extract_type_fields(def_content)
        if fields:
            type_def = CddlTypeDefinition(
                module=module_name,
                name=type_name,
                fields=fields,
                description=f"{type_name}",
            )
            module.types.append(type_def)
            logger.debug(f"Found type: {def_name} with {len(fields)} fields")
def _is_enum_definition(self, definition: str) -> bool:
"""Check if a definition is an enum (string union with /).
Enums are defined as: "value1" / "value2" / "value3"
"""
clean_def = definition.strip()
if "{" in clean_def or "}" in clean_def:
return False
return " / " in clean_def and '"' in clean_def
def _extract_enum_values(self, enum_definition: str) -> list[str]:
"""Extract individual values from an enum definition.
Enums are defined as: "value1" / "value2" / "value3"
Can span multiple lines.
"""
values = []
parts = enum_definition.split("/")
for part in parts:
part = part.strip()
match = re.search(r'"([^"]*)"', part)
if match:
value = match.group(1)
values.append(value)
logger.debug(f"Extracted enum value: {value}")
return values
@staticmethod
def _normalize_cddl_type(field_type: str) -> str:
"""Normalize a CDDL type expression to a simple Python-compatible form.
Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and
replaces interval/constraint expressions with their base types so that
the caller can safely check for nested struct syntax.
Examples:
'(float .ge 0.0) .default 1.0' -> 'float'
'(float .ge 0.0) / null' -> 'float / null'
'(0.0...360.0) / null' -> 'float / null'
'-90.0..90.0' -> 'float'
'float / null .default null' -> 'float / null'
"""
result = field_type
result = re.sub(r"\s*\.default\s+\S+", "", result)
result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result)
result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result)
result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result)
return result.strip()
def _extract_type_fields(self, type_definition: str) -> dict[str, str]:
"""Extract fields from a type definition block."""
fields = {}
clean_def = type_definition.strip()
if clean_def.startswith("{"):
clean_def = clean_def[1:]
if clean_def.endswith("}"):
clean_def = clean_def[:-1]
for line in clean_def.split("\n"):
line = line.strip()
if not line or "Extensible" in line or line.startswith("//"):
continue
match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line)
if not match:
match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line)
if match:
field_name = match.group(1).strip()
field_type = match.group(2).strip()
normalized_type = self._normalize_cddl_type(field_type)
if "{" not in normalized_type and "(" not in normalized_type:
fields[field_name] = normalized_type
logger.debug(f"Extracted field {field_name}: {normalized_type}")
return fields
def _extract_events(self) -> None:
    """Populate module event lists from the parsed CDDL definitions.

    A definition is treated as an event when:
    1. its name is listed in ``self.event_names`` (e.g. via BrowsingContextEvent), and
    2. it contains ``method: "..."`` followed by ``params: ...``.

    Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType)

    For each match whose method contains a ".", a ``CddlEvent`` is appended to
    ``self.modules[<method prefix before the first ".">]``. That ``CddlModule``
    is created first if it does not exist. The event name is the part of the
    definition name after its last ".".
    """
    signature = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)")
    candidates = (
        (name, content) for name, content in self.definitions.items() if name in self.event_names
    )
    for def_name, def_content in candidates:
        hit = signature.search(def_content)
        if not hit:
            continue
        method, params_type = hit.group(1), hit.group(2)
        if "." not in method:
            continue
        module_name = method.split(".", 1)[0]
        if module_name not in self.modules:
            self.modules[module_name] = CddlModule(name=module_name)
        _, event_name = def_name.rsplit(".", 1)
        self.modules[module_name].events.append(
            CddlEvent(
                module=module_name,
                name=event_name,
                method=method,
                params_type=params_type,
                description=f"Event: {method}",
            )
        )
        logger.debug(f"Found event: {def_name} (method={method}, params={params_type})")
def _extract_commands(self) -> None:
    """Populate module command lists from the parsed CDDL definitions.

    Every definition that is not a known event is scanned for all
    ``method: "module.command", params: TypeName`` pairs. Each pair whose
    method contains a "." yields a ``CddlCommand`` in the module named by the
    method prefix; that module is created on first use. Parameters and the
    required-name set come from ``self._extract_parameters_and_required``.
    """
    pair_re = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)")
    for def_name, def_content in self.definitions.items():
        if def_name in self.event_names:
            continue
        for found in pair_re.finditer(def_content):
            method, params_type = found.groups()
            if "." not in method:
                continue
            module_name, command_name = method.split(".", 1)
            if module_name not in self.modules:
                self.modules[module_name] = CddlModule(name=module_name)
            params, required_params = self._extract_parameters_and_required(params_type)
            self.modules[module_name].commands.append(
                CddlCommand(
                    module=module_name,
                    name=command_name,
                    params=params,
                    required_params=required_params,
                    description=f"Execute {method}",
                )
            )
            logger.debug(f"Found command: {method} with params {params_type}")
def _extract_parameters(self, params_type: str, _seen: set[str] | None = None) -> dict[str, str]:
    """Return only the parameter mapping for *params_type*.

    Thin wrapper over ``_extract_parameters_and_required`` that discards the
    required-name set. It handles struct types (``{...}``) and top-level named
    unions (``TypeA / TypeB``); for unions, the fields of every alternative
    are merged.
    """
    return self._extract_parameters_and_required(params_type, _seen)[0]
def _extract_parameters_and_required(
    self, params_type: str, _seen: set[str] | None = None
) -> tuple[dict[str, str], set[str]]:
    """Collect the parameters of *params_type* and the names of required ones.

    Two definition shapes are understood:

    * A top-level union of named types (``TypeA / TypeB``, no leading ``{``
      and no ``//``). Each alternative is resolved recursively and the fields
      are merged. Requiredness of the alternatives is discarded, so the
      returned required set is empty.
    * A struct body. Each ``name: type`` line becomes a parameter, and lines
      not starting with ``?`` are recorded as required. Fields whose
      normalized type still contains ``{`` or ``(`` are skipped.

    Unknown type names, and names already visited in this resolution (tracked
    in *_seen* to break cycles), yield empty results.

    Returns:
        Tuple of (params dict, required_params set)
    """
    visited = set() if _seen is None else _seen
    found_params: dict[str, str] = {}
    required_names: set[str] = set()

    if params_type in visited:
        return found_params, required_names
    visited.add(params_type)

    if params_type not in self.definitions:
        logger.debug(f"Parameter type not found: {params_type}")
        return found_params, required_names
    body = self.definitions[params_type].strip()

    looks_like_union = not body.startswith("{") and "/" in body and "//" not in body
    if looks_like_union:
        options = [piece.strip() for piece in body.split("/") if piece.strip()]
        if all(re.match(r"^[\w.]+$", option) for option in options):
            for option in options:
                option_params, _ = self._extract_parameters_and_required(option, visited)
                found_params.update(option_params)
            return found_params, required_names

    if body.startswith("{"):
        body = body[1:]
    if body.endswith("}"):
        body = body[:-1]

    for raw_line in body.split("\n"):
        text = raw_line.strip()
        if not text or "Extensible" in text:
            continue
        optional = text.startswith("?")
        found = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", text) or re.match(
            r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", text
        )
        if found is None:
            continue
        name = found.group(1).strip()
        simplified = self._normalize_cddl_type(found.group(2).strip())
        if "{" in simplified or "(" in simplified:
            continue
        found_params[name] = simplified
        if not optional:
            required_names.add(name)
        logger.debug(
            f"Extracted param {name}: {simplified} "
            f"(required={not optional}) from {params_type}"
        )
    return found_params, required_names
def module_name_to_class_name(module_name: str) -> str:
    """Return the PascalCase class name for a BiDi module name.

    snake_case names are split on underscores and each part is passed through
    ``str.capitalize`` (``browsing_context`` -> ``BrowsingContext``). Any other
    name, e.g. camelCase ``browsingContext``, only has its first character
    upper-cased. An empty name yields an empty string.
    """
    if not module_name:
        return ""
    if "_" not in module_name:
        return module_name[:1].upper() + module_name[1:]
    return "".join(map(str.capitalize, module_name.split("_")))
def module_name_to_filename(module_name: str) -> str:
    """Convert module name to Python filename (snake_case).

    Handles both camelCase (browsingContext) and snake_case (browsing_context).

    Special cases:
    - browsingContext -> browsing_context
    - webExtension -> webextension (kept as a single word)

    Names that already contain an underscore are returned unchanged. Other
    names are split at word boundaries and lower-cased (``userPrompt`` ->
    ``user_prompt``, ``HTTPServer`` -> ``http_server``).
    """
    camel_to_snake_map = {
        "browsingContext": "browsing_context",
        "webExtension": "webextension",
    }
    if module_name in camel_to_snake_map:
        return camel_to_snake_map[module_name]
    if "_" in module_name:
        return module_name
    # First insert "_" before each capitalised word ("Prompt", "Server"),
    # then between a lowercase letter/digit and a following capital, so
    # acronym runs like "HTTPServer" split as "HTTP_Server".
    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None:
    """Write the package ``__init__.py`` that re-exports every module class.

    The file contains the shared header and a ``from __future__`` import. For
    each module name (in sorted order) it then has one import line for the
    module's class, followed by an ``__all__`` list naming the same classes.

    Args:
        output_path: Directory that receives ``__init__.py``.
        modules: Parsed modules keyed by their CDDL module name.
    """
    init_path = output_path / "__init__.py"
    ordered_names = sorted(modules.keys())
    class_names = [module_name_to_class_name(name) for name in ordered_names]
    import_lines = "".join(
        f"from selenium.webdriver.common.bidi.{module_name_to_filename(name)} import {cls}\n"
        for name, cls in zip(ordered_names, class_names)
    )
    all_entries = "".join(f' "{cls}",\n' for cls in class_names)
    code = f"""{SHARED_HEADER}
from __future__ import annotations
"""
    code += import_lines + "\n__all__ = [\n" + all_entries + "]\n"
    with open(init_path, "w", encoding="utf-8") as f:
        f.write(code)
    logger.info(f"Generated: {init_path}")
def generate_common_file(output_path: Path) -> None:
    """Generate common.py file with shared utilities.

    Writes ``common.py`` into *output_path*, overwriting any existing file.
    It contains the Apache license header, a module docstring and a
    ``command_builder`` generator function. The written path is then logged.

    Args:
        output_path: Directory that receives ``common.py``.
    """
    common_path = output_path / "common.py"
    # The concatenated literals below are written verbatim as common.py.
    code = (
        # License header.
        "# Licensed to the Software Freedom Conservancy (SFC) under one\n"
        "# or more contributor license agreements. See the NOTICE file\n"
        "# distributed with this work for additional information\n"
        "# regarding copyright ownership. The SFC licenses this file\n"
        "# to you under the Apache License, Version 2.0 (the\n"
        '# "License"); you may not use this file except in compliance\n'
        "# with the License. You may obtain a copy of the License at\n"
        "#\n"
        "# http://www.apache.org/licenses/LICENSE-2.0\n"
        "#\n"
        "# Unless required by applicable law or agreed to in writing,\n"
        "# software distributed under the License is distributed on an\n"
        '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n'
        "# KIND, either express or implied. See the License for the\n"
        "# specific language governing permissions and limitations\n"
        "# under the License.\n"
        "\n"
        '"""Common utilities for BiDi command construction."""\n'
        "\n"
        "from __future__ import annotations\n"
        "\n"
        "from collections.abc import Generator\n"
        "from typing import Any\n"
        "\n"
        "\n"
        # command_builder yields {"method": ..., "params": ...}, substituting an
        # empty dict for params=None. It returns whatever value is sent back
        # into the generator.
        "def command_builder(\n"
        " method: str, params: dict[str, Any] | None = None\n"
        ") -> Generator[dict[str, Any], Any, Any]:\n"
        ' """Build a BiDi command generator.\n'
        "\n"
        " Args:\n"
        ' method: The BiDi method name (e.g., "session.status", "browser.close")\n'
        " params: The parameters for the command. If omitted, an empty\n"
        " dictionary is sent.\n"
        "\n"
        " Yields:\n"
        " A dictionary representing the BiDi command\n"
        "\n"
        " Returns:\n"
        " The result from the BiDi command execution\n"
        ' """\n'
        " if params is None:\n"
        " params = {}\n"
        ' result = yield {"method": method, "params": params}\n'
        " return result\n"
    )
    with open(common_path, "w", encoding="utf-8") as f:
        f.write(code)
    logger.info(f"Generated: {common_path}")
def generate_console_file(output_path: Path) -> None:
    """Generate console.py file with Console enum helper.

    Writes ``console.py`` into *output_path*, overwriting any existing file.
    It contains the Apache license header and a ``Console`` Enum with members
    ALL="all", LOG="log" and ERROR="error". The written path is then logged.

    Args:
        output_path: Directory that receives ``console.py``.
    """
    console_path = output_path / "console.py"
    # The concatenated literals below are written verbatim as console.py.
    code = (
        # License header.
        "# Licensed to the Software Freedom Conservancy (SFC) under one\n"
        "# or more contributor license agreements. See the NOTICE file\n"
        "# distributed with this work for additional information\n"
        "# regarding copyright ownership. The SFC licenses this file\n"
        "# to you under the Apache License, Version 2.0 (the\n"
        '# "License"); you may not use this file except in compliance\n'
        "# with the License. You may obtain a copy of the License at\n"
        "#\n"
        "# http://www.apache.org/licenses/LICENSE-2.0\n"
        "#\n"
        "# Unless required by applicable law or agreed to in writing,\n"
        "# software distributed under the License is distributed on an\n"
        '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n'
        "# KIND, either express or implied. See the License for the\n"
        "# specific language governing permissions and limitations\n"
        "# under the License.\n"
        "\n"
        "from enum import Enum\n"
        "\n"
        "\n"
        # Console enum: string values for the three members.
        "class Console(Enum):\n"
        ' ALL = "all"\n'
        ' LOG = "log"\n'
        ' ERROR = "error"\n'
    )
    with open(console_path, "w", encoding="utf-8") as f:
        f.write(code)
    logger.info(f"Generated: {console_path}")
def main(
    cddl_file: str,
    output_dir: str,
    spec_version: str = "1.0",
    enhancements_manifest: str | None = None,
) -> None:
    """Main entry point.

    Parses the CDDL file and renders every module. It then refreshes
    *output_dir*:

    - stale ``*.py`` files are removed (``cdp.py`` and names starting with
      ``_`` are kept);
    - each module is written;
    - ``__init__.py``, ``common.py``, ``console.py`` and the ``py.typed``
      marker are (re)created.

    All module sources are rendered in memory before anything in
    *output_dir* is deleted. An exception raised while rendering therefore
    leaves the existing package untouched instead of half-removed.

    Args:
        cddl_file: Path to CDDL specification file
        output_dir: Output directory for generated modules
        spec_version: BiDi spec version (only logged here)
        enhancements_manifest: Path to enhancement manifest Python file
    """
    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"WebDriver BiDi Code Generator v{__version__}")
    logger.info(f"Input CDDL: {cddl_file}")
    logger.info(f"Output directory: {output_path}")
    logger.info(f"Spec version: {spec_version}")

    # The loader result is truth-tested below and may therefore be falsy;
    # normalize to a dict so the .get() calls further down are always safe.
    manifest = load_enhancements_manifest(enhancements_manifest) or {}
    if manifest:
        logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}")

    parser = CddlParser(cddl_file)
    modules = parser.parse()
    logger.info(f"Parsed {len(modules)} modules")

    # Render every module before touching the output directory.
    rendered: list[tuple[Path, str]] = []
    for module_name, module in sorted(modules.items()):
        filename = module_name_to_filename(module_name)
        module_path = output_path / f"{filename}.py"
        module_enhancements = manifest.get("enhancements", {}).get(module_name, {})
        full_module_enhancements = {
            **module_enhancements,
            "dataclass_methods": manifest.get("dataclass_methods", {}),
            "method_docstrings": manifest.get("method_docstrings", {}),
        }
        rendered.append((module_path, module.generate_code(full_module_enhancements)))

    # Remove previously generated modules so deleted spec modules don't linger.
    preserved_python_files = {"py.typed", "cdp.py"}
    for file_path in output_path.glob("*.py"):
        if file_path.name not in preserved_python_files and not file_path.name.startswith("_"):
            file_path.unlink()
            logger.debug(f"Removed: {file_path}")

    for module_path, code in rendered:
        with open(module_path, "w", encoding="utf-8") as f:
            f.write(code)
        logger.info(f"Generated: {module_path}")

    generate_init_file(output_path, modules)
    generate_common_file(output_path)
    generate_console_file(output_path)
    py_typed_path = output_path / "py.typed"
    py_typed_path.touch()
    logger.info(f"Generated type marker: {py_typed_path}")
    logger.info("Code generation complete!")
if __name__ == "__main__":
    # Command-line interface: positional CDDL file, output dir, optional spec version.
    parser = argparse.ArgumentParser(description="Generate Python WebDriver BiDi modules from CDDL specification")
    parser.add_argument("cddl_file", help="Path to CDDL specification file")
    parser.add_argument("output_dir", help="Output directory for generated Python modules")
    parser.add_argument("spec_version", nargs="?", default="1.0", help="BiDi spec version (default: 1.0)")
    parser.add_argument(
        "--enhancements-manifest",
        default=None,
        help="Path to enhancement manifest Python file (optional)",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging")
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger("generate_bidi").setLevel(logging.DEBUG)

    try:
        main(
            args.cddl_file,
            args.output_dir,
            args.spec_version,
            args.enhancements_manifest,
        )
    except Exception as e:
        # Top-level boundary: log the traceback and signal failure via exit status.
        logger.exception(f"Generation failed: {e}")
        sys.exit(1)
    sys.exit(0)