Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
SeleniumHQ
GitHub Repository: SeleniumHQ/Selenium
Path: blob/trunk/py/generate_bidi.py
10191 views
1
# Licensed to the Software Freedom Conservancy (SFC) under one
2
# or more contributor license agreements. See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership. The SFC licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License. You may obtain a copy of the License at
8
#
9
# http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied. See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17
18
19
"""
20
Generate Python WebDriver BiDi command modules from CDDL specification.
21
22
This generator reads CDDL (Concise Data Definition Language) specification files
23
and produces Python type definitions and command classes that conform to the
24
WebDriver BiDi protocol.
25
26
Usage:
27
python generate_bidi.py <cddl_file> <output_dir> <spec_version>
28
29
Example:
30
python generate_bidi.py local.cddl ./selenium/webdriver/common/bidi 1.0
31
"""
32
33
import argparse
34
import importlib.util
35
import logging
36
import re
37
import sys
38
from dataclasses import dataclass, field
39
from enum import Enum
40
from pathlib import Path
41
from textwrap import indent as tw_indent
42
from typing import Any
43
44
__version__ = "1.0.0"
45
46
# Logging setup
47
log_level = logging.INFO
48
logging.basicConfig(level=log_level)
49
logger = logging.getLogger("generate_bidi")
50
51
# File headers
52
SHARED_HEADER = """# DO NOT EDIT THIS FILE!
53
#
54
# This file is generated from the WebDriver BiDi specification. If you need to make
55
# changes, edit the generator and regenerate all of the modules."""
56
57
# Split header: comments section and imports section are separate so a
58
# module_docstring can be injected between them (before imports → real __doc__).
59
_MODULE_HEADER_COMMENTS = f"""{SHARED_HEADER}
60
#
61
# WebDriver BiDi module: {{}}
62
"""
63
_MODULE_HEADER_IMPORTS = "from __future__ import annotations\n\n"
64
65
66
def indent(s: str, n: int) -> str:
    """Prefix every non-blank line of ``s`` with ``n`` spaces.

    Whitespace-only lines are left untouched (``textwrap.indent`` default).
    """
    padding = " " * n
    return tw_indent(s, padding)
69
70
71
def _docstring_text(custom: str | None, fallback_name: str, fallback_desc: str = "") -> str:
72
"""Select the appropriate raw docstring text (no triple-quotes).
73
74
Priority: manifest custom string > CDDL description (if different from name) > 'ClassName.'
75
"""
76
if custom:
77
return custom.strip()
78
if fallback_desc and fallback_desc != fallback_name:
79
return fallback_desc
80
return f"{fallback_name}."
81
82
83
def _emit_docstring(text: str, indent_width: int) -> str:
84
r"""Produce a PEP 257-compliant docstring block with a trailing newline.
85
86
Single-line output: <indent>\"\"\"text.\"\"\"\n
87
Multi-line output:
88
<indent>\"\"\"
89
<indent>First line.
90
<indent>
91
<indent>Continuation.
92
<indent>\"\"\"
93
The opening and closing triple-quotes each occupy their own line for
94
multi-line strings so that inspect.getdoc() / Sphinx dedent cleanly.
95
"""
96
prefix = " " * indent_width
97
stripped = text.strip()
98
lines = stripped.splitlines()
99
if len(lines) <= 1:
100
return f'{prefix}"""{stripped}"""\n'
101
content = "\n".join(f"{prefix}{line}" if line.strip() else "" for line in lines)
102
return f'{prefix}"""\n{content}\n{prefix}"""\n'
103
104
105
def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]:
106
"""Load enhancement manifest from a Python file.
107
108
Args:
109
manifest_path: Path to Python file containing ENHANCEMENTS dict
110
111
Returns:
112
Dictionary with enhancement rules, or empty dict if no manifest provided
113
"""
114
if not manifest_path:
115
return {}
116
117
manifest_file = Path(manifest_path)
118
if not manifest_file.exists():
119
logger.warning(f"Enhancement manifest not found: {manifest_path}")
120
return {}
121
122
try:
123
spec = importlib.util.spec_from_file_location("bidi_enhancements", manifest_file)
124
if spec is None or spec.loader is None:
125
logger.warning(f"Could not load manifest: {manifest_path}")
126
return {}
127
128
module = importlib.util.module_from_spec(spec)
129
spec.loader.exec_module(module)
130
131
enhancements = getattr(module, "ENHANCEMENTS", {})
132
dataclass_methods = getattr(module, "DATACLASS_METHOD_TEMPLATES", {})
133
method_docstrings = getattr(module, "DATACLASS_METHOD_DOCSTRINGS", {})
134
135
logger.info(f"Loaded enhancement manifest from: {manifest_path}")
136
logger.debug(f"Enhancements for modules: {list(enhancements.keys())}")
137
138
return {
139
"enhancements": enhancements,
140
"dataclass_methods": dataclass_methods,
141
"method_docstrings": method_docstrings,
142
}
143
except Exception as e:
144
logger.error(f"Failed to load enhancement manifest: {e}", exc_info=True)
145
return {}
146
147
148
class CddlType(Enum):
    """CDDL primitive type names mapped to Python annotation strings.

    Several members share a value (``TEXT`` and ``TSTR`` are both ``"str"``;
    ``INT`` and ``NINT`` repeat ``UINT``'s ``"int"``), which makes the later
    ones Enum aliases. Iterating the class skips aliases, so lookups go
    through ``__members__``, which includes them.
    """

    TSTR = "str"  # text string
    TEXT = "str"  # text (alias)
    UINT = "int"  # unsigned integer
    INT = "int"  # signed integer
    NINT = "int"  # negative integer
    BOOL = "bool"  # boolean
    NULL = "None"  # null
    ANY = "Any"  # any type

    @classmethod
    def get_annotation(cls, cddl_type: str) -> str:
        """Return the Python annotation string for a CDDL type expression.

        Handles primitive names (case-insensitive), arrays such as ``[+ T]`` or
        ``[* T]`` (mapped to ``list[T]``) and maps starting with ``{`` (mapped
        to ``dict[str, Any]``). Anything else maps to ``"Any"``.
        """
        cddl_type = cddl_type.strip().lower()

        # __members__ includes aliases (TEXT, INT, NINT), unlike iteration.
        member = cls.__members__.get(cddl_type.upper())
        if member is not None:
            return member.value

        if cddl_type.startswith("["):
            # Drop brackets and the +/* occurrence indicators before mapping the element.
            return f"list[{cls.get_annotation(cddl_type.strip('[]+* '))}]"

        if cddl_type.startswith("{"):
            return "dict[str, Any]"

        # Unknown or named (module-qualified) types.
        return "Any"
181
182
183
@dataclass
184
class CddlCommand:
185
"""Represents a CDDL command definition."""
186
187
module: str
188
name: str
189
params: dict[str, str] = field(default_factory=dict)
190
required_params: set[str] = field(default_factory=set)
191
result: str | None = None
192
description: str = ""
193
194
def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str:
195
"""Generate Python method code for this command.
196
197
Args:
198
enhancements: Dictionary with enhancement rules for this method
199
"""
200
enhancements = enhancements or {}
201
method_name = self._camel_to_snake(self.name)
202
203
# Build parameter list with type hints
204
# Check if there's a params_override for user-friendly named arguments
205
params_to_use = self.params
206
if "params_override" in enhancements:
207
params_to_use = enhancements["params_override"]
208
209
param_strs = []
210
param_names = [] # Keep track of parameter names for later use
211
for param_name, param_type in params_to_use.items():
212
if param_type in ["bool", "str", "int"]:
213
python_type = param_type
214
else:
215
python_type = CddlType.get_annotation(param_type)
216
snake_param = self._camel_to_snake(param_name)
217
param_names.append((param_name, snake_param))
218
param_strs.append(f"{snake_param}: {python_type} | None = None")
219
220
if param_strs:
221
# Check if full signature would exceed line length limit (120 chars)
222
single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):"
223
if len(single_line_signature) > 120:
224
# Format parameters on multiple lines
225
body = f" def {method_name}(\n"
226
body += " self,\n"
227
for i, param_str in enumerate(param_strs):
228
if i < len(param_strs) - 1:
229
body += f" {param_str},\n"
230
else:
231
body += f" {param_str},\n"
232
body += " ):\n"
233
else:
234
param_list = "self, " + ", ".join(param_strs)
235
body = f" def {method_name}({param_list}):\n"
236
else:
237
body = f" def {method_name}(self):\n"
238
docstring = enhancements.get("docstring") or self.description or f"Execute {self.module}.{self.name}."
239
body += _emit_docstring(docstring, 8)
240
241
# Add automatic validation for required parameters
242
# (This is used unless there's no required_params, in which case all params are optional)
243
if self.required_params:
244
method_snake = self._camel_to_snake(self.name)
245
for param_name, snake_param in param_names:
246
if param_name in self.required_params:
247
body += f" if {snake_param} is None:\n"
248
msg = f"{method_snake}() missing required argument:"
249
error_message = f"{msg} {snake_param!r}"
250
body += f" raise TypeError({error_message!r})\n"
251
body += "\n"
252
253
# Add validation if specified in enhancements (for additional business logic validation)
254
if "validate" in enhancements:
255
validate_func = enhancements["validate"]
256
# Build parameter list for validation function
257
param_args = ", ".join(f"{snake}={snake}" for _, snake in param_names)
258
body += f" {validate_func}({param_args})\n"
259
body += "\n"
260
261
# Add transformation and preprocessing
262
# First, check if any transform is needed
263
if "transform" in enhancements:
264
transform_spec = enhancements["transform"]
265
if isinstance(transform_spec, dict):
266
# New format with explicit function and result parameter
267
transform_func = transform_spec.get("func")
268
result_param = transform_spec.get("result_param", "params")
269
input_params = [
270
transform_spec.get(k) for k in ["allowed", "destination_folder"] if transform_spec.get(k)
271
]
272
273
if transform_func and result_param:
274
body += f" {result_param} = None\n"
275
param_args = ", ".join(input_params)
276
body += f" {result_param} = {transform_func}({param_args})\n"
277
body += "\n"
278
else:
279
# Legacy format for backward compatibility
280
transform_func = transform_spec
281
if self.name == "setDownloadBehavior":
282
body += " download_behavior = None\n"
283
body += f" download_behavior = {transform_func}(allowed, destination_folder)\n"
284
body += "\n"
285
286
# Add preprocessing for serialization (check for to_bidi_dict method)
287
if "preprocess" in enhancements:
288
preprocess_rules = enhancements["preprocess"]
289
for param_name, preprocess_type in preprocess_rules.items():
290
snake_param = self._camel_to_snake(param_name)
291
if preprocess_type == "check_serialize_method":
292
body += f" if {snake_param} and hasattr({snake_param}, 'to_bidi_dict'):\n"
293
body += f" {snake_param} = {snake_param}.to_bidi_dict()\n"
294
body += "\n"
295
296
# Build params dict
297
body += " params = {\n"
298
299
# If there's a transform with a result parameter, map it to the BiDi protocol name
300
if "transform" in enhancements and isinstance(enhancements["transform"], dict):
301
transform_spec = enhancements["transform"]
302
result_param = transform_spec.get("result_param")
303
304
# Map the result parameter to the original CDDL parameter name
305
if result_param == "download_behavior":
306
body += ' "downloadBehavior": download_behavior,\n'
307
# Add remaining parameters that weren't part of the transform
308
for cddl_param_name in self.params:
309
if cddl_param_name not in ["downloadBehavior"]:
310
snake_name = self._camel_to_snake(cddl_param_name)
311
body += f' "{cddl_param_name}": {snake_name},\n'
312
else:
313
# Standard parameter mapping from CDDL
314
for param_name, snake_param in param_names:
315
body += f' "{param_name}": {snake_param},\n'
316
317
body += " }\n"
318
body += " params = {k: v for k, v in params.items() if v is not None}\n"
319
body += f' cmd = command_builder("{self.module}.{self.name}", params)\n'
320
body += " result = self._conn.execute(cmd)\n"
321
322
# Add response handling for extraction/deserialization
323
if "extract_field" in enhancements:
324
extract_field = enhancements["extract_field"]
325
extract_property = enhancements.get("extract_property")
326
327
# Check if we also need to deserialize the extracted field
328
deserialize_rules = enhancements.get("deserialize", {})
329
330
if extract_property:
331
# Extract property from list items
332
body += f' if result and "{extract_field}" in result:\n'
333
body += f' items = result.get("{extract_field}", [])\n'
334
body += " return [\n"
335
body += f' item.get("{extract_property}")\n'
336
body += " for item in items\n"
337
body += " if isinstance(item, dict)\n"
338
body += " ]\n"
339
body += " return []\n"
340
elif extract_field in deserialize_rules:
341
# Extract field and deserialize to typed objects
342
type_name = deserialize_rules[extract_field]
343
body += f' if result and "{extract_field}" in result:\n'
344
body += f' items = result.get("{extract_field}", [])\n'
345
body += " return [\n"
346
body += f" {type_name}(\n"
347
body += self._generate_field_args(extract_field, type_name)
348
body += " )\n"
349
body += " for item in items\n"
350
body += " if isinstance(item, dict)\n"
351
body += " ]\n"
352
body += " return []\n"
353
else:
354
# Simple field extraction (return the value directly, not wrapped in result dict)
355
body += f' if result and "{extract_field}" in result:\n'
356
body += f' extracted = result.get("{extract_field}")\n'
357
body += " return extracted\n"
358
body += " return result\n"
359
elif "deserialize" in enhancements:
360
# Deserialize response to typed objects (legacy, without extract_field)
361
deserialize_rules = enhancements["deserialize"]
362
for response_field, type_name in deserialize_rules.items():
363
body += f' if result and "{response_field}" in result:\n'
364
body += f' items = result.get("{response_field}", [])\n'
365
body += " return [\n"
366
body += f" {type_name}(\n"
367
body += self._generate_field_args(response_field, type_name)
368
body += " )\n"
369
body += " for item in items\n"
370
body += " if isinstance(item, dict)\n"
371
body += " ]\n"
372
body += " return []\n"
373
else:
374
# No special response handling, just return the result
375
body += " return result\n"
376
377
return body
378
379
def _generate_field_args(self, response_field: str, type_name: str) -> str:
380
"""Generate constructor arguments for deserializing response objects.
381
382
For now, this handles ClientWindowInfo and Info specifically.
383
Could be extended to be more generic.
384
"""
385
if type_name == "ClientWindowInfo":
386
return (
387
' active=item.get("active"),\n'
388
' client_window=item.get("clientWindow"),\n'
389
' height=item.get("height"),\n'
390
' state=item.get("state"),\n'
391
' width=item.get("width"),\n'
392
' x=item.get("x"),\n'
393
' y=item.get("y")\n'
394
)
395
elif type_name == "Info":
396
return (
397
' children=_deserialize_info_list(item.get("children", [])),\n'
398
' client_window=item.get("clientWindow"),\n'
399
' context=item.get("context"),\n'
400
' original_opener=item.get("originalOpener"),\n'
401
' url=item.get("url"),\n'
402
' user_context=item.get("userContext"),\n'
403
' parent=item.get("parent")\n'
404
)
405
# For other types, return empty
406
return ""
407
408
@staticmethod
409
def _camel_to_snake(name: str) -> str:
410
"""Convert camelCase to snake_case."""
411
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
412
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
413
414
415
@dataclass
416
class CddlTypeDefinition:
417
"""Represents a CDDL type definition."""
418
419
module: str
420
name: str
421
fields: dict[str, str] = field(default_factory=dict)
422
description: str = ""
423
424
def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str:
425
"""Generate Python dataclass code for this type.
426
427
Args:
428
enhancements: Dictionary containing dataclass_methods and method_docstrings
429
"""
430
enhancements = enhancements or {}
431
dataclass_methods = enhancements.get("dataclass_methods", {})
432
method_docstrings = enhancements.get("method_docstrings", {})
433
434
# Generate class name from type name (keep it as-is, don't split on underscores)
435
class_name = self.name
436
code = "@dataclass\n"
437
code += f"class {class_name}:\n"
438
class_docstrings = enhancements.get("class_docstrings", {})
439
class_doc = _docstring_text(class_docstrings.get(class_name), class_name, self.description)
440
code += _emit_docstring(class_doc, 4)
441
code += "\n"
442
443
if not self.fields:
444
code += " pass\n"
445
else:
446
for field_name, field_type in self.fields.items():
447
# Convert CDDL type to Python type
448
python_type = self._get_python_type(field_type)
449
snake_name = CddlCommand._camel_to_snake(field_name)
450
451
# Check if the CDDL field type is a quoted string literal (e.g., type: "key")
452
# These are discriminant fields: auto-populate and exclude from __init__
453
# so callers don't need to pass them as positional or keyword arguments.
454
literal_match = re.match(r'^"([^"]+)"$', field_type.strip())
455
if literal_match:
456
literal_value = literal_match.group(1)
457
code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n'
458
# Check if this field is a list type (using lowercase 'list[' from Python 3.10+ syntax)
459
elif python_type.startswith("list["):
460
# Remove the trailing ' | None' from list types since default_factory=list ensures non-None
461
type_annotation = python_type.replace(" | None", "")
462
code += f" {snake_name}: {type_annotation} = field(default_factory=list)\n"
463
# Check if this field is a dict type (using lowercase 'dict[' from Python 3.10+ syntax)
464
elif python_type.startswith("dict["):
465
# Remove the trailing ' | None' from dict types since default_factory=dict ensures non-None
466
type_annotation = python_type.replace(" | None", "")
467
code += f" {snake_name}: {type_annotation} = field(default_factory=dict)\n"
468
else:
469
code += f" {snake_name}: {python_type} = None\n"
470
471
# Add custom methods if defined for this class
472
if class_name in dataclass_methods:
473
code += "\n"
474
methods_dict = dataclass_methods[class_name]
475
docstrings_dict = method_docstrings.get(class_name, {})
476
477
for method_name in methods_dict:
478
method_impl = methods_dict[method_name]
479
docstring = docstrings_dict.get(method_name, "")
480
code += f" def {method_name}(self):\n"
481
if docstring:
482
code += f' """{docstring}"""\n'
483
code += f" {method_impl}\n"
484
code += "\n"
485
486
return code
487
488
@staticmethod
489
def _get_python_type(cddl_type: str) -> str:
490
"""Convert CDDL type to Python type annotation using Python 3.10+ syntax."""
491
cddl_type = cddl_type.strip().lower()
492
493
# Handle basic types
494
type_mapping = {
495
"tstr": "str",
496
"text": "str",
497
"uint": "int",
498
"int": "int",
499
"nint": "int",
500
"bool": "bool",
501
"null": "None",
502
}
503
504
for cddl, python in type_mapping.items():
505
if cddl_type == cddl:
506
# Use Python 3.10+ union syntax: type | None
507
return f"{python} | None"
508
509
# Handle arrays
510
if cddl_type.startswith("["):
511
inner = cddl_type.strip("[]+ ")
512
inner_type = CddlTypeDefinition._get_python_type(inner)
513
# Remove " | None" from inner type since it might be wrapped
514
if " | None" in inner_type:
515
inner_base = inner_type.replace(" | None", "")
516
return f"list[{inner_base} | None] | None"
517
return f"list[{inner_type}] | None"
518
519
# Handle maps/dicts
520
if cddl_type.startswith("{"):
521
return "dict[str, Any] | None"
522
523
# Default to Any for unknown/complex types
524
return "Any | None"
525
526
527
@dataclass
528
class CddlEnum:
529
"""Represents a CDDL enum definition (string union)."""
530
531
module: str
532
name: str
533
values: list[str] = field(default_factory=list)
534
description: str = ""
535
536
def to_python_class(self, enhancements: dict[str, Any] | None = None) -> str:
537
"""Generate Python enum class code.
538
539
Generates a simple class with string constants to match the existing
540
pattern in the codebase (e.g., ClientWindowState).
541
"""
542
enhancements = enhancements or {}
543
class_name = self.name
544
class_docstrings = enhancements.get("class_docstrings", {})
545
class_doc = _docstring_text(class_docstrings.get(class_name), class_name, self.description)
546
code = f"class {class_name}:\n"
547
code += _emit_docstring(class_doc, 4)
548
code += "\n"
549
550
for value in self.values:
551
# Convert value to UPPER_SNAKE_CASE constant name
552
const_name = self._value_to_const_name(value)
553
code += f' {const_name} = "{value}"\n'
554
555
return code
556
557
@staticmethod
558
def _value_to_const_name(value: str) -> str:
559
"""Convert enum string value to constant name.
560
561
Examples:
562
"none" -> "NONE"
563
"portrait-primary" -> "PORTRAIT_PRIMARY"
564
"interactive" -> "INTERACTIVE"
565
"""
566
# Replace hyphens with underscores
567
const_name = value.replace("-", "_")
568
# Convert to uppercase
569
return const_name.upper()
570
571
572
@dataclass
class CddlEvent:
    """Represents a CDDL event definition (an incoming message from the browser)."""

    module: str  # BiDi module that emits the event
    name: str  # Python name bound to the event's payload type
    method: str  # fully qualified BiDi method, e.g. "browsingContext.contextCreated"
    params_type: str  # CDDL payload type, possibly module-qualified
    description: str = ""

    def to_python_dataclass(self) -> str:
        """Generate a module-level alias binding ``self.name`` to its payload type.

        Despite the method name, no dataclass is emitted. The alias is resolved
        with ``globals().get(...)`` when the generated code runs, and it falls
        back to ``dict`` if the payload type was not generated.
        """
        # Drop any module prefix: "browsingContext.Info" -> "Info".
        _, _, type_name = self.params_type.rpartition(".")

        # NavigationInfo payloads are bound to BaseNavigationInfo instead.
        if type_name == "NavigationInfo":
            type_name = "BaseNavigationInfo"

        header = f"# Event: {self.method}\n"
        alias = f"{self.name} = globals().get('{type_name}', dict) # Fallback to dict if type not defined\n"
        return header + alias
603
604
605
@dataclass
606
class CddlModule:
607
"""Represents a CDDL module (e.g., script, network, browsing_context)."""
608
609
name: str
610
commands: list[CddlCommand] = field(default_factory=list)
611
types: list[CddlTypeDefinition] = field(default_factory=list)
612
enums: list[CddlEnum] = field(default_factory=list)
613
events: list[CddlEvent] = field(default_factory=list)
614
615
@staticmethod
616
def _convert_method_to_event_name(method_suffix: str) -> str:
617
"""Convert BiDi method suffix to friendly event name.
618
619
Examples:
620
"contextCreated" -> "context_created"
621
"navigationStarted" -> "navigation_started"
622
"userPromptOpened" -> "user_prompt_opened"
623
"""
624
# Convert camelCase to snake_case
625
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix)
626
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
627
628
def generate_code(self, enhancements: dict[str, Any] | None = None) -> str:
629
"""Generate Python code for this module.
630
631
Args:
632
enhancements: Dictionary with module-level enhancements
633
"""
634
enhancements = enhancements or {}
635
module_docstring = enhancements.get("module_docstring", "")
636
code = _MODULE_HEADER_COMMENTS.format(self.name)
637
if module_docstring:
638
code += _emit_docstring(module_docstring, 0) + "\n"
639
code += _MODULE_HEADER_IMPORTS
640
641
# Collect needed imports to avoid duplicates
642
needs_command_builder = bool(self.commands)
643
needs_dataclass = self.commands or self.types or self.events
644
needs_callable = self.events
645
646
stdlib_imports = []
647
local_imports = []
648
649
# Add imports (field import will be added conditionally after code generation)
650
if needs_callable:
651
stdlib_imports.append("from collections.abc import Callable")
652
if needs_dataclass:
653
stdlib_imports.append("from dataclasses import dataclass")
654
stdlib_imports.append("from typing import Any")
655
656
if needs_command_builder:
657
local_imports.append("from selenium.webdriver.common.bidi.common import command_builder")
658
if self.events:
659
local_imports.append(
660
"from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager"
661
)
662
663
code += "\n".join(stdlib_imports) + "\n"
664
if local_imports:
665
code += "\n" + "\n".join(local_imports) + "\n"
666
667
code += "\n"
668
669
# Add helper function definitions from enhancements
670
# Collect all referenced helper functions (validate, transform)
671
helper_funcs_to_add = set()
672
for cmd in self.commands:
673
method_name_snake = cmd._camel_to_snake(cmd.name)
674
method_enhancements = enhancements.get(method_name_snake, {})
675
if "validate" in method_enhancements:
676
helper_funcs_to_add.add(("validate", method_enhancements["validate"]))
677
if "transform" in method_enhancements and isinstance(method_enhancements["transform"], dict):
678
transform_spec = method_enhancements["transform"]
679
if "func" in transform_spec:
680
helper_funcs_to_add.add(("transform", transform_spec["func"]))
681
682
# Generate helper functions if needed
683
if helper_funcs_to_add:
684
for func_type, func_name in sorted(helper_funcs_to_add):
685
if func_type == "validate" and func_name == "validate_download_behavior":
686
code += """def validate_download_behavior(
687
allowed: bool | None,
688
destination_folder: str | None,
689
user_contexts: Any | None = None,
690
) -> None:
691
\"\"\"Validate download behavior parameters.
692
693
Args:
694
allowed: Whether downloads are allowed
695
destination_folder: Destination folder for downloads
696
user_contexts: Optional list of user contexts
697
698
Raises:
699
ValueError: If parameters are invalid
700
\"\"\"
701
if allowed is True and not destination_folder:
702
raise ValueError("destination_folder is required when allowed=True")
703
if allowed is False and destination_folder:
704
raise ValueError("destination_folder should not be provided when allowed=False")
705
706
707
"""
708
elif func_type == "transform" and func_name == "transform_download_params":
709
code += """def transform_download_params(
710
allowed: bool | None,
711
destination_folder: str | None,
712
) -> dict[str, Any] | None:
713
\"\"\"Transform download parameters into download_behavior object.
714
715
Args:
716
allowed: Whether downloads are allowed
717
destination_folder: Destination folder for downloads (accepts str or
718
pathlib.Path; will be coerced to str)
719
720
Returns:
721
Dictionary representing the download_behavior object, or None if allowed is None
722
\"\"\"
723
if allowed is True:
724
return {
725
"type": "allowed",
726
# Coerce pathlib.Path (or any path-like) to str so the BiDi
727
# protocol always receives a plain JSON string.
728
"destinationFolder": str(destination_folder) if destination_folder is not None else None,
729
}
730
elif allowed is False:
731
return {"type": "denied"}
732
else: # None — reset to browser default (sent as JSON null)
733
return None
734
735
736
"""
737
738
# Generate enums first (excluding those in exclude_types)
739
exclude_types = set(enhancements.get("exclude_types", []))
740
741
# Also exclude any types that have extra_dataclasses overrides
742
# Extract class names from extra_dataclasses strings
743
for extra_cls in enhancements.get("extra_dataclasses", []):
744
# Match "class ClassName:" pattern
745
match = re.search(r"class\s+(\w+)\s*:", extra_cls)
746
if match:
747
exclude_types.add(match.group(1))
748
749
for enum_def in self.enums:
750
if enum_def.name in exclude_types:
751
continue
752
code += enum_def.to_python_class(enhancements)
753
code += "\n\n"
754
755
# Emit module-level aliases from enhancement manifest (e.g. LogLevel = Level)
756
for alias, target in enhancements.get("aliases", {}).items():
757
code += f"{alias} = {target}\n\n"
758
759
# Generate type dataclasses, skipping any overridden by extra_dataclasses
760
for type_def in self.types:
761
if type_def.name in exclude_types:
762
continue
763
code += type_def.to_python_dataclass(enhancements)
764
code += "\n\n"
765
766
# Emit extra dataclasses from enhancement manifest (non-CDDL additions)
767
for extra_cls in enhancements.get("extra_dataclasses", []):
768
code += extra_cls
769
code += "\n\n"
770
771
# Emit extra type aliases from enhancement manifest (e.g., union types for events)
772
for extra_alias in enhancements.get("extra_type_aliases", []):
773
code += extra_alias
774
code += "\n\n"
775
776
# NOTE: Don't generate event type aliases here - they reference types that may not be defined yet
777
# They will be generated after the class definition instead
778
779
# Generate EVENT_NAME_MAPPING for the module (before the module class)
780
if self.events:
781
# Generate EVENT_NAME_MAPPING for the module
782
code += "# BiDi Event Name to Parameter Type Mapping\n"
783
code += "EVENT_NAME_MAPPING = {\n"
784
for event_def in self.events:
785
# Convert method name to user-friendly event name
786
# e.g., "browsingContext.contextCreated" -> "context_created"
787
method_parts = event_def.method.split(".")
788
if len(method_parts) == 2:
789
event_name = self._convert_method_to_event_name(method_parts[1])
790
code += f' "{event_name}": "{event_def.method}",\n'
791
# Extra events not in the CDDL spec (e.g. Chromium-specific events)
792
for extra_evt in enhancements.get("extra_events", []):
793
code += f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n'
794
code += "}\n\n"
795
796
# Add custom method function definitions before the class (for browsingContext)
797
if self.name == "browsingContext":
798
# Add helper function for recursive Info deserialization
799
code += """def _deserialize_info_list(items: list) -> list | None:
800
\"\"\"Recursively deserialize a list of dicts to Info objects.
801
802
Args:
803
items: List of dicts from the API response
804
805
Returns:
806
List of Info objects with properly nested children, or None if empty
807
\"\"\"
808
if not items or not isinstance(items, list):
809
return None
810
811
result = []
812
for item in items:
813
if isinstance(item, dict):
814
# Recursively deserialize children only if the key exists in response
815
children_list = None
816
if "children" in item:
817
children_list = _deserialize_info_list(item.get("children", []))
818
info = Info(
819
children=children_list,
820
client_window=item.get("clientWindow"),
821
context=item.get("context"),
822
original_opener=item.get("originalOpener"),
823
url=item.get("url"),
824
user_context=item.get("userContext"),
825
parent=item.get("parent"),
826
)
827
result.append(info)
828
return result if result else None
829
830
831
"""
832
code += "\n\n"
833
834
# EventConfig, _EventWrapper, and _EventManager are imported from
835
# selenium.webdriver.common.bidi._event_manager (see import section above)
836
# rather than being duplicated inline in every generated module.
837
if False: # placeholder to preserve indentation structure
838
pass
839
840
# Generate class
841
# Convert module name (camelCase or snake_case) to proper class name (PascalCase)
842
class_name = module_name_to_class_name(self.name)
843
class_docstrings = enhancements.get("class_docstrings", {})
844
module_class_doc = _docstring_text(
845
class_docstrings.get(class_name),
846
class_name,
847
f"WebDriver BiDi {self.name} module.",
848
)
849
code += f"class {class_name}:\n"
850
code += _emit_docstring(module_class_doc, 4)
851
code += "\n"
852
853
# Add EVENT_CONFIGS dict if there are events
854
if self.events:
855
code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined
856
857
if self.name == "script":
858
code += " def __init__(self, conn, driver=None) -> None:\n"
859
code += " self._conn = conn\n"
860
code += " self._driver = driver\n"
861
else:
862
code += " def __init__(self, conn) -> None:\n"
863
code += " self._conn = conn\n"
864
865
# Initialize _event_manager if there are events
866
if self.events:
867
code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n"
868
869
# Append extra init code from enhancements (e.g. self.intercepts = [])
870
for init_line in enhancements.get("extra_init_code", []):
871
code += f" {init_line}\n"
872
873
code += "\n"
874
875
# Generate command methods
876
exclude_methods = enhancements.get("exclude_methods", [])
877
878
# Automatically exclude methods that are defined in extra_methods
879
# to prevent generating duplicates
880
if "extra_methods" in enhancements:
881
for extra_method in enhancements["extra_methods"]:
882
# Extract method name from "def method_name("
883
match = re.search(r"def\s+(\w+)\s*\(", extra_method)
884
if match:
885
exclude_methods = list(exclude_methods) + [match.group(1)]
886
887
if self.commands:
888
command_docstrings = enhancements.get("command_docstrings", {})
889
for command in self.commands:
890
# Get method-specific enhancements
891
# Convert command name to snake_case to match enhancement manifest keys
892
method_name_snake = command._camel_to_snake(command.name)
893
if method_name_snake in exclude_methods:
894
continue
895
method_enhancements = enhancements.get(method_name_snake, {})
896
# Inject command_docstrings entry if no per-method docstring is set
897
if method_name_snake in command_docstrings and "docstring" not in method_enhancements:
898
method_enhancements = {**method_enhancements, "docstring": command_docstrings[method_name_snake]}
899
code += command.to_python_method(method_enhancements)
900
code += "\n"
901
elif not self.events and not enhancements.get("extra_methods", []):
902
code += " pass\n"
903
904
# Emit extra methods from enhancement manifest
905
for extra_method in enhancements.get("extra_methods", []):
906
code += extra_method
907
code += "\n"
908
909
# Add delegating event handler methods if events are present
910
if self.events:
911
code += """
912
def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int:
913
\"\"\"Add an event handler.
914
915
Args:
916
event: The event to subscribe to.
917
callback: The callback function to execute on event.
918
contexts: The context IDs to subscribe to (optional).
919
920
Returns:
921
The callback ID.
922
\"\"\"
923
return self._event_manager.add_event_handler(event, callback, contexts)
924
925
def remove_event_handler(self, event: str, callback_id: int) -> None:
926
\"\"\"Remove an event handler.
927
928
Args:
929
event: The event to unsubscribe from.
930
callback_id: The callback ID.
931
\"\"\"
932
return self._event_manager.remove_event_handler(event, callback_id)
933
934
def clear_event_handlers(self) -> None:
935
\"\"\"Clear all event handlers.\"\"\"
936
return self._event_manager.clear_event_handlers()
937
"""
938
939
# Generate event info type aliases AFTER the class definition
940
# This ensures all types are available when we create the aliases
941
if self.events:
942
code += "\n# Event Info Type Aliases\n"
943
# Check for explicit event_type_aliases in the enhancement manifest
944
event_type_aliases = enhancements.get("event_type_aliases", {})
945
for event_def in self.events:
946
# Convert method name to user-friendly event name
947
method_parts = event_def.method.split(".")
948
if len(method_parts) == 2:
949
event_name = self._convert_method_to_event_name(method_parts[1])
950
# Check if there's an explicit alias defined in the enhancement manifest
951
if event_name in event_type_aliases:
952
# Use the alias directly
953
type_name = event_type_aliases[event_name]
954
code += f"# Event: {event_def.method}\n"
955
code += f"{event_def.name} = {type_name}\n"
956
else:
957
# Fall back to the original behavior
958
code += event_def.to_python_dataclass()
959
code += "\n"
960
961
# Now populate EVENT_CONFIGS after the aliases are defined
962
code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n"
963
# Use globals() to look up types dynamically to handle missing types gracefully
964
code += "_globals = globals()\n"
965
code += f"{class_name}.EVENT_CONFIGS = {{\n"
966
for event_def in self.events:
967
# Convert method name to user-friendly event name
968
method_parts = event_def.method.split(".")
969
if len(method_parts) == 2:
970
event_name = self._convert_method_to_event_name(method_parts[1])
971
# Try to get event class from globals, default to dict if not found
972
getter = f'_globals.get("{event_def.name}", dict)'
973
condition = f'_globals.get("{event_def.name}")'
974
event_class = f"{getter} if {condition} else dict"
975
976
# Build the entry line and check if it exceeds 120 chars
977
single_line = (
978
f' "{event_name}": EventConfig("{event_name}", "{event_def.method}", {event_class}),'
979
)
980
981
if len(single_line) > 120:
982
# Break into multiple lines
983
code += f' "{event_name}": EventConfig(\n'
984
code += f' "{event_name}",\n'
985
code += f' "{event_def.method}",\n'
986
code += f" {event_class},\n"
987
code += " ),\n"
988
else:
989
code += single_line + "\n"
990
# Extra events not in the CDDL spec
991
for extra_evt in enhancements.get("extra_events", []):
992
ek = extra_evt["event_key"]
993
be = extra_evt["bidi_event"]
994
ec = extra_evt["event_class"]
995
code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n'
996
code += "}\n"
997
998
# Check if field() is actually used in the generated code
999
# If so, add the field import after the dataclass import
1000
if "field(" in code:
1001
# Find where to insert the field import
1002
# It should go after "from dataclasses import dataclass" line
1003
dataclass_import_pattern = r"from dataclasses import dataclass\n"
1004
if re.search(dataclass_import_pattern, code):
1005
code = re.sub(
1006
dataclass_import_pattern,
1007
"from dataclasses import dataclass, field\n",
1008
code,
1009
count=1,
1010
)
1011
elif "from dataclasses import" not in code:
1012
# If there's no dataclasses import yet, add field import after typing
1013
code = code.replace(
1014
"from typing import Any\n",
1015
"from dataclasses import field\nfrom typing import Any\n",
1016
)
1017
1018
return code
1019
1020
1021
class CddlParser:
    """Parse a CDDL specification file into per-module enum, type, event and command models.

    The parser is regex/line based rather than a full CDDL grammar: it extracts
    top-level ``Name = Definition`` pairs and classifies the namespaced ones
    (``module.Name``) as enums, struct-like types, events or commands.
    """

    # Matches the ``method: "...", params: module.Type`` pair shared by command and
    # event definitions. Group 1 is the method name, group 2 the params type reference.
    _METHOD_PARAMS_PATTERN = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)")

    # Matches a top-level ``Name = Definition``. The terminating lookahead accepts the same
    # multi-segment names as the capture group; allowing only one optional segment there
    # would let a following ``a.b.C = ...`` definition be swallowed by its predecessor.
    _DEFINITION_PATTERN = re.compile(r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)*\s*=|\Z)", re.DOTALL)

    def __init__(self, cddl_path: str):
        """Initialize the parser and immediately read the CDDL file at ``cddl_path``.

        Raises:
            FileNotFoundError: If ``cddl_path`` does not exist.
        """
        self.cddl_path = Path(cddl_path)
        self.content = ""
        self.modules: dict[str, CddlModule] = {}
        self.definitions: dict[str, str] = {}
        self.event_names: set[str] = set()  # Names of definitions that are events
        self._read_file()

    def _read_file(self) -> None:
        """Read the CDDL file (UTF-8) into ``self.content``.

        Raises:
            FileNotFoundError: If ``self.cddl_path`` does not exist.
        """
        if not self.cddl_path.exists():
            raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}")

        with open(self.cddl_path, encoding="utf-8") as f:
            self.content = f.read()

        logger.info(f"Loaded CDDL file: {self.cddl_path}")

    def parse(self) -> dict[str, CddlModule]:
        """Parse the loaded CDDL content and return the modules keyed by module name.

        If nothing module-shaped is found, a single empty module named after the
        file stem is created so callers always get at least one module.
        """
        content = self._remove_comments(self.content)

        self._extract_definitions(content)
        # Event names must be known before commands are extracted, because events
        # share the ``method:/params:`` shape and are excluded from commands by name.
        self._extract_event_names()
        self._extract_types()
        self._extract_events()
        self._extract_commands()

        if not self.modules:
            module_name = self.cddl_path.stem
            default_module = CddlModule(name=module_name)
            self.modules[module_name] = default_module
            logger.warning(f"No modules found in CDDL, creating default: {module_name}")

        return self.modules

    def _remove_comments(self, content: str) -> str:
        """Strip CDDL ``;`` comments from ``content``.

        Lines that are entirely a comment are dropped; other lines are cut at the
        first ``;``. NOTE(review): a ``;`` inside a quoted string would also be
        treated as the start of a comment.
        """
        cleaned = []
        for line in content.split("\n"):
            if line.strip().startswith(";"):
                continue
            if ";" in line:
                line = line[: line.index(";")]
            cleaned.append(line)
        return "\n".join(cleaned)

    def _extract_definitions(self, content: str) -> None:
        """Populate ``self.definitions`` with every top-level ``Name = Definition`` pair.

        A definition may span several lines; it ends at the next line that starts at
        column 0 with ``Name =`` (dotted names included) or at the end of ``content``.
        """
        for match in self._DEFINITION_PATTERN.finditer(content):
            name = match.group(1).strip()
            definition = match.group(2).strip()
            self.definitions[name] = definition
            logger.debug(f"Extracted definition: {name}")

    def _extract_event_names(self) -> None:
        """Collect event definition names from event union definitions.

        Event union definitions follow the pattern::

            module.ModuleEvent = (
              module.EventName1 //
              module.EventName2 //
              ...
            )

        Every ``module.Name`` reference in such a union is added to ``self.event_names``.
        """
        for def_name, def_content in self.definitions.items():
            # Heuristic: the definition name *contains* "Event" and its body references
            # at least one module-qualified name. Handles both single-item unions (no //)
            # and multi-item (// separated) unions.
            if "Event" in def_name and re.search(r"\w+\.\w+", def_content):
                event_refs = re.findall(r"(\w+\.\w+)", def_content)
                for event_ref in event_refs:
                    self.event_names.add(event_ref)
                    logger.debug(f"Identified event: {event_ref} (from {def_name})")

    def _extract_types(self) -> None:
        """Register enum and struct-like type definitions on their modules.

        * ``module.EnumName = "value1" / "value2" / ...`` becomes a ``CddlEnum``.
        * Any other namespaced definition becomes a ``CddlTypeDefinition`` when at
          least one field can be extracted from it.

        Un-namespaced names (e.g. ``EmptyParams``, ``Extensible``) and definitions
        containing ``method:`` (commands/events) are skipped.
        """
        for def_name, def_content in self.definitions.items():
            # Only "module.TypeName" definitions belong to a module.
            if "." not in def_name:
                continue

            # Definitions with a "method:" field are commands or events, not types.
            if "method:" in def_content:
                continue

            module_name, type_name = def_name.rsplit(".", 1)

            if module_name not in self.modules:
                self.modules[module_name] = CddlModule(name=module_name)

            if self._is_enum_definition(def_content):
                values = self._extract_enum_values(def_content)
                if values:
                    enum_def = CddlEnum(
                        module=module_name,
                        name=type_name,
                        values=values,
                        description=type_name,
                    )
                    self.modules[module_name].enums.append(enum_def)
                    logger.debug(f"Found enum: {def_name} with {len(values)} values")
            else:
                fields = self._extract_type_fields(def_content)
                if fields:  # Only create type if it has fields
                    type_def = CddlTypeDefinition(
                        module=module_name,
                        name=type_name,
                        fields=fields,
                        description=type_name,
                    )
                    self.modules[module_name].types.append(type_def)
                    logger.debug(f"Found type: {def_name} with {len(fields)} fields")

    def _is_enum_definition(self, definition: str) -> bool:
        """Return True if ``definition`` looks like a string-union enum.

        Enums are defined as: "value1" / "value2" / "value3". Anything containing
        braces is treated as a struct type instead.
        """
        clean_def = definition.strip()

        if "{" in clean_def or "}" in clean_def:
            return False

        # Requires the spaced union operator and at least one quoted value.
        return " / " in clean_def and '"' in clean_def

    def _extract_enum_values(self, enum_definition: str) -> list[str]:
        """Return the quoted values of an enum definition, in order.

        Enums are defined as: "value1" / "value2" / "value3" and may span multiple
        lines. The first quoted string in each ``/``-separated part is taken.
        """
        values = []

        for part in enum_definition.split("/"):
            part = part.strip()

            # search (not match) so the quote may appear anywhere in the part
            match = re.search(r'"([^"]*)"', part)
            if match:
                value = match.group(1)
                values.append(value)
                logger.debug(f"Extracted enum value: {value}")

        return values

    @staticmethod
    def _normalize_cddl_type(field_type: str) -> str:
        """Normalize a CDDL type expression to a simple Python-compatible form.

        Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and
        replaces interval/constraint expressions with their base types so that
        the caller can safely check for nested struct syntax.

        Examples:
            '(float .ge 0.0) .default 1.0' -> 'float'
            '(float .ge 0.0) / null' -> 'float / null'
            '(0.0...360.0) / null' -> 'float / null'
            '-90.0..90.0' -> 'float'
            'float / null .default null' -> 'float / null'
        """
        result = field_type
        # Remove trailing .default <value> annotations
        result = re.sub(r"\s*\.default\s+\S+", "", result)
        # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType
        result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result)
        # Replace parenthesised numeric interval types: (0.0...360.0) -> float
        result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result)
        # Replace bare numeric interval types: -90.0..90.0 -> float
        result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result)
        return result.strip()

    def _extract_type_fields(self, type_definition: str) -> dict[str, str]:
        """Extract ``{field_name: normalized_type}`` from a ``{ ... }`` type definition.

        Only single-line ``[?] name: type`` entries are recognised. ``Extensible``
        lines, ``//`` choice lines and fields whose type is itself a nested group or
        struct are skipped.
        """
        fields = {}

        clean_def = type_definition.strip()
        if clean_def.startswith("{"):
            clean_def = clean_def[1:]
        if clean_def.endswith("}"):
            clean_def = clean_def[:-1]

        for line in clean_def.split("\n"):
            line = line.strip()
            if not line or "Extensible" in line or line.startswith("//"):
                continue

            # Match pattern: [?] fieldName: type  (trailing comma stripped)
            match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line)
            if not match:
                match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line)

            if match:
                field_name = match.group(1).strip()
                field_type = match.group(2).strip()
                normalized_type = self._normalize_cddl_type(field_type)

                # Skip lines that are part of nested definitions
                if "{" not in normalized_type and "(" not in normalized_type:
                    fields[field_name] = normalized_type
                    logger.debug(f"Extracted field {field_name}: {normalized_type}")

        return fields

    def _extract_events(self) -> None:
        """Register event definitions on their modules.

        Events are definitions that:
            1. Are listed in an event union (e.g., BrowsingContextEvent)
            2. Have method: "..." and params: ... fields

        Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType)
        """
        for def_name, def_content in self.definitions.items():
            if def_name not in self.event_names:
                continue

            match = self._METHOD_PARAMS_PATTERN.search(def_content)
            if not match:
                continue

            method = match.group(1)  # e.g., "browsingContext.contextCreated"
            params_type = match.group(2)  # e.g., "browsingContext.Info"
            if "." not in method:
                continue

            module_name, _ = method.split(".", 1)
            if module_name not in self.modules:
                self.modules[module_name] = CddlModule(name=module_name)

            # Event name comes from the definition name (e.g., browsingContext.ContextCreated)
            _, event_name = def_name.rsplit(".", 1)

            event = CddlEvent(
                module=module_name,
                name=event_name,
                method=method,
                params_type=params_type,
                description=f"Event: {method}",
            )

            self.modules[module_name].events.append(event)
            logger.debug(f"Found event: {def_name} (method={method}, params={params_type})")

    def _extract_commands(self) -> None:
        """Register command definitions on their modules.

        Commands follow the pattern ``module.Command = (method: "...", params: ...)``.
        Definitions already identified as events are skipped since they share that shape.
        """
        for def_name, def_content in self.definitions.items():
            if def_name in self.event_names:
                continue

            for match in self._METHOD_PARAMS_PATTERN.finditer(def_content):
                method = match.group(1)  # e.g., "session.new"
                params_type = match.group(2)  # e.g., "session.NewParameters"
                if "." not in method:
                    continue

                module_name, command_name = method.split(".", 1)
                if module_name not in self.modules:
                    self.modules[module_name] = CddlModule(name=module_name)

                params, required_params = self._extract_parameters_and_required(params_type)

                cmd = CddlCommand(
                    module=module_name,
                    name=command_name,
                    params=params,
                    required_params=required_params,
                    description=f"Execute {method}",
                )

                self.modules[module_name].commands.append(cmd)
                logger.debug(f"Found command: {method} with params {params_type}")

    def _extract_parameters(self, params_type: str, _seen: set[str] | None = None) -> dict[str, str]:
        """Extract parameters from a parameter type definition.

        Handles both struct types ({...}) and top-level union types (TypeA / TypeB),
        merging all fields from each alternative as optional parameters.
        """
        params, _ = self._extract_parameters_and_required(params_type, _seen)
        return params

    def _extract_parameters_and_required(
        self, params_type: str, _seen: set[str] | None = None
    ) -> tuple[dict[str, str], set[str]]:
        """Extract parameters and track which are required from a parameter type definition.

        ``_seen`` guards against revisiting a type (e.g. recursive unions); an already
        seen or unknown type contributes nothing.

        Returns:
            Tuple of (params dict, required_params set)
        """
        params = {}
        required = set()

        if _seen is None:
            _seen = set()
        if params_type in _seen:
            return params, required
        _seen.add(params_type)

        if params_type not in self.definitions:
            logger.debug(f"Parameter type not found: {params_type}")
            return params, required

        definition = self.definitions[params_type]

        # Handle a top-level type alias that is a union of other named types, e.g.
        # session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest
        # (single "/" between names, not the "//" used for command unions).
        stripped = definition.strip()
        if not stripped.startswith("{") and "/" in stripped and "//" not in stripped:
            alternatives = [a.strip() for a in stripped.split("/") if a.strip()]
            all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives)
            if all_named:
                # Merge parameters from every alternative but never mark them required:
                # a caller only has to satisfy one alternative.
                for alt_type in alternatives:
                    alt_params, _ = self._extract_parameters_and_required(alt_type, _seen)
                    params.update(alt_params)
                return params, required

        # Remove the outer curly braces, then parse each line for "key: type" patterns
        clean_def = stripped
        if clean_def.startswith("{"):
            clean_def = clean_def[1:]
        if clean_def.endswith("}"):
            clean_def = clean_def[:-1]

        for line in clean_def.split("\n"):
            line = line.strip()
            if not line or "Extensible" in line:
                continue

            # A leading "?" marks the parameter optional
            is_optional = line.startswith("?")

            match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line)
            if not match:
                match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line)

            if match:
                param_name = match.group(1).strip()
                param_type = match.group(2).strip()
                normalized_type = self._normalize_cddl_type(param_type)

                # Skip lines that are part of nested definitions
                if "{" not in normalized_type and "(" not in normalized_type:
                    params[param_name] = normalized_type
                    if not is_optional:
                        required.add(param_name)
                    logger.debug(
                        f"Extracted param {param_name}: {normalized_type} "
                        f"(required={not is_optional}) from {params_type}"
                    )

        return params, required
1444
1445
1446
def module_name_to_class_name(module_name: str) -> str:
    """Return the PascalCase class name for a BiDi module name.

    Both spellings are accepted:
    ``browsing_context`` (snake_case) -> ``BrowsingContext`` and
    ``browsingContext`` (camelCase) -> ``BrowsingContext``.
    An empty name yields an empty string.
    """
    if not module_name:
        return ""
    if "_" not in module_name:
        # camelCase: only the leading character needs upper-casing.
        return module_name[:1].upper() + module_name[1:]
    # snake_case: capitalize each underscore-separated word and join them.
    return "".join(map(str.capitalize, module_name.split("_")))
1457
1458
1459
def module_name_to_filename(module_name: str) -> str:
    """Convert a BiDi module name to its Python filename stem (snake_case).

    Handles both camelCase (browsingContext) and snake_case (browsing_context).
    Special cases:
        - browsingContext -> browsing_context
        - webExtension -> webextension  (not "web_extension")

    Names already containing "_" are returned unchanged; any other name goes
    through a generic camelCase -> snake_case conversion, e.g.
    "myModuleName" -> "my_module_name".
    """
    # Explicit mappings for known camelCase names. "webExtension" deliberately
    # differs from what the generic conversion below would produce.
    camel_to_snake_map = {
        "browsingContext": "browsing_context",
        "webExtension": "webextension",
    }

    if module_name in camel_to_snake_map:
        return camel_to_snake_map[module_name]

    if "_" in module_name:
        # Already snake_case
        return module_name

    # Generic camelCase -> snake_case: first split before capitalised words,
    # then split any remaining lower/digit -> upper boundary.
    # Uses the module-level ``re`` import.
    s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", module_name)
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
1486
1487
1488
def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None:
    """Write the package ``__init__.py`` that re-exports every generated module class.

    Each module's class is imported from ``selenium.webdriver.common.bidi.<filename>``
    and listed in ``__all__``; both sections follow sorted module-name order.

    Args:
        output_path: Directory receiving the generated files.
        modules: Parsed CDDL modules keyed by module name (only the keys are used).
    """
    ordered_names = sorted(modules.keys())

    parts = [f"{SHARED_HEADER}\n\nfrom __future__ import annotations\n\n"]
    for name in ordered_names:
        parts.append(
            f"from selenium.webdriver.common.bidi.{module_name_to_filename(name)} "
            f"import {module_name_to_class_name(name)}\n"
        )

    parts.append("\n__all__ = [\n")
    parts.extend(f' "{module_name_to_class_name(name)}",\n' for name in ordered_names)
    parts.append("]\n")

    init_path = output_path / "__init__.py"
    with open(init_path, "w", encoding="utf-8") as handle:
        handle.write("".join(parts))

    logger.info(f"Generated: {init_path}")
1513
1514
1515
def generate_common_file(output_path: Path) -> None:
    """Write ``common.py`` containing the shared ``command_builder`` helper.

    The emitted module must be importable Python, so the function body it contains
    uses regular 4/8-space indentation. A flat single-space layout would put
    ``if params is None:`` and its body on the same level, which raises
    IndentationError.

    Args:
        output_path: Directory receiving the generated file.
    """
    common_path = output_path / "common.py"

    code = (
        "# Licensed to the Software Freedom Conservancy (SFC) under one\n"
        "# or more contributor license agreements. See the NOTICE file\n"
        "# distributed with this work for additional information\n"
        "# regarding copyright ownership. The SFC licenses this file\n"
        "# to you under the Apache License, Version 2.0 (the\n"
        '# "License"); you may not use this file except in compliance\n'
        "# with the License. You may obtain a copy of the License at\n"
        "#\n"
        "# http://www.apache.org/licenses/LICENSE-2.0\n"
        "#\n"
        "# Unless required by applicable law or agreed to in writing,\n"
        "# software distributed under the License is distributed on an\n"
        '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n'
        "# KIND, either express or implied. See the License for the\n"
        "# specific language governing permissions and limitations\n"
        "# under the License.\n"
        "\n"
        '"""Common utilities for BiDi command construction."""\n'
        "\n"
        "from __future__ import annotations\n"
        "\n"
        "from collections.abc import Generator\n"
        "from typing import Any\n"
        "\n"
        "\n"
        "def command_builder(\n"
        "    method: str, params: dict[str, Any] | None = None\n"
        ") -> Generator[dict[str, Any], Any, Any]:\n"
        '    """Build a BiDi command generator.\n'
        "\n"
        "    Args:\n"
        '        method: The BiDi method name (e.g., "session.status", "browser.close")\n'
        "        params: The parameters for the command. If omitted, an empty\n"
        "            dictionary is sent.\n"
        "\n"
        "    Yields:\n"
        "        A dictionary representing the BiDi command\n"
        "\n"
        "    Returns:\n"
        "        The result from the BiDi command execution\n"
        '    """\n'
        "    if params is None:\n"
        "        params = {}\n"
        '    result = yield {"method": method, "params": params}\n'
        "    return result\n"
    )

    with open(common_path, "w", encoding="utf-8") as f:
        f.write(code)

    logger.info(f"Generated: {common_path}")
1571
1572
1573
def generate_console_file(output_path: Path) -> None:
    """Write ``console.py`` containing the ``Console`` enum (ALL / LOG / ERROR).

    The enum body is emitted with standard 4-space indentation (PEP 8), matching
    the other generated helper modules.

    Args:
        output_path: Directory receiving the generated file.
    """
    console_path = output_path / "console.py"

    code = (
        "# Licensed to the Software Freedom Conservancy (SFC) under one\n"
        "# or more contributor license agreements. See the NOTICE file\n"
        "# distributed with this work for additional information\n"
        "# regarding copyright ownership. The SFC licenses this file\n"
        "# to you under the Apache License, Version 2.0 (the\n"
        '# "License"); you may not use this file except in compliance\n'
        "# with the License. You may obtain a copy of the License at\n"
        "#\n"
        "# http://www.apache.org/licenses/LICENSE-2.0\n"
        "#\n"
        "# Unless required by applicable law or agreed to in writing,\n"
        "# software distributed under the License is distributed on an\n"
        '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n'
        "# KIND, either express or implied. See the License for the\n"
        "# specific language governing permissions and limitations\n"
        "# under the License.\n"
        "\n"
        "from enum import Enum\n"
        "\n"
        "\n"
        "class Console(Enum):\n"
        '    ALL = "all"\n'
        '    LOG = "log"\n'
        '    ERROR = "error"\n'
    )

    with open(console_path, "w", encoding="utf-8") as f:
        f.write(code)

    logger.info(f"Generated: {console_path}")
1608
1609
1610
def main(
    cddl_file: str,
    output_dir: str,
    spec_version: str = "1.0",
    enhancements_manifest: str | None = None,
) -> None:
    """Generate the BiDi Python package from a CDDL specification.

    Parses ``cddl_file`` first. It then deletes existing ``*.py`` files in
    ``output_dir``, except ``cdp.py`` and ``_``-prefixed files. Finally it writes
    one module per CDDL module, plus ``__init__.py``, ``common.py``, ``console.py``
    and the ``py.typed`` marker.

    Args:
        cddl_file: Path to CDDL specification file
        output_dir: Output directory for generated modules (created if missing)
        spec_version: BiDi spec version (only logged here)
        enhancements_manifest: Path to enhancement manifest Python file
    """
    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    logger.info(f"WebDriver BiDi Code Generator v{__version__}")
    logger.info(f"Input CDDL: {cddl_file}")
    logger.info(f"Output directory: {output_path}")
    logger.info(f"Spec version: {spec_version}")

    # Load enhancement manifest. A falsy result is normalised to {} so the
    # unconditional ``manifest.get(...)`` lookups below cannot fail on None.
    manifest = load_enhancements_manifest(enhancements_manifest) or {}
    if manifest:
        logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}")

    # Parse before touching the output directory, so a parse failure leaves
    # previously generated files in place.
    parser = CddlParser(cddl_file)
    modules = parser.parse()

    logger.info(f"Parsed {len(modules)} modules")

    # Clean up existing generated files.
    # Keep static helper modules that are staged by Bazel (for example cdp.py)
    # as part of create-bidi-src.extra_srcs.
    preserved_python_files = {"py.typed", "cdp.py"}
    for file_path in output_path.glob("*.py"):
        if file_path.name not in preserved_python_files and not file_path.name.startswith("_"):
            file_path.unlink()
            logger.debug(f"Removed: {file_path}")

    # Generate module files using snake_case filenames
    for module_name, module in sorted(modules.items()):
        filename = module_name_to_filename(module_name)
        module_path = output_path / f"{filename}.py"

        # Module-specific enhancements plus the manifest-wide dataclass methods/docstrings
        module_enhancements = manifest.get("enhancements", {}).get(module_name, {})
        full_module_enhancements = {
            **module_enhancements,
            "dataclass_methods": manifest.get("dataclass_methods", {}),
            "method_docstrings": manifest.get("method_docstrings", {}),
        }

        # Render first: if generation raises, no truncated/empty module file is left behind.
        module_code = module.generate_code(full_module_enhancements)
        with open(module_path, "w", encoding="utf-8") as f:
            f.write(module_code)
        logger.info(f"Generated: {module_path}")

    generate_init_file(output_path, modules)
    generate_common_file(output_path)
    generate_console_file(output_path)

    # Create py.typed marker
    py_typed_path = output_path / "py.typed"
    py_typed_path.touch()
    logger.info(f"Generated type marker: {py_typed_path}")

    logger.info("Code generation complete!")
1686
1687
1688
if __name__ == "__main__":
    # Command-line entry point: parse arguments, then run main() and map the outcome to an exit code.
    arg_parser = argparse.ArgumentParser(description="Generate Python WebDriver BiDi modules from CDDL specification")
    arg_parser.add_argument("cddl_file", help="Path to CDDL specification file")
    arg_parser.add_argument("output_dir", help="Output directory for generated Python modules")
    arg_parser.add_argument("spec_version", nargs="?", default="1.0", help="BiDi spec version (default: 1.0)")
    arg_parser.add_argument(
        "--enhancements-manifest",
        default=None,
        help="Path to enhancement manifest Python file (optional)",
    )
    arg_parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging")

    cli_args = arg_parser.parse_args()

    if cli_args.verbose:
        # ``logger`` is the module's "generate_bidi" logger.
        logger.setLevel(logging.DEBUG)

    try:
        main(
            cli_args.cddl_file,
            cli_args.output_dir,
            cli_args.spec_version,
            cli_args.enhancements_manifest,
        )
    except Exception as e:
        # logger.exception == logger.error(..., exc_info=True)
        logger.exception(f"Generation failed: {e}")
        sys.exit(1)
    sys.exit(0)
1732
1733