"""Shared event management helpers for generated WebDriver BiDi modules.
``EventConfig``, ``_EventWrapper``, and ``_EventManager`` are emitted
identically into every generated module that exposes events. Rather than
duplicating this logic across those modules, they are defined once here and
copied into generated outputs by Bazel.
"""
from __future__ import annotations
import threading
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from selenium.webdriver.common.bidi.session import Session
@dataclass()
class EventConfig:
    """Static description of a single BiDi event exposed by a generated module.

    Attributes:
        event_key: Identifier for the event (presumably the key it is registered
            under in the module's event-config mapping — confirm with generator).
        bidi_event: Protocol-level BiDi event name, used when subscribing.
        event_class: Python type that the event's params are deserialized into.
    """

    event_key: str
    bidi_event: str
    event_class: type
class _EventWrapper:
"""Wrapper to provide event_class attribute for WebSocketConnection callbacks."""
def __init__(self, bidi_event: str, event_class: type):
self.event_class = bidi_event
self._python_class = event_class
def from_json(self, params: dict) -> Any:
"""Deserialize event params into the wrapped Python dataclass.
Args:
params: Raw BiDi event params with camelCase keys.
Returns:
An instance of the dataclass, or the raw dict on failure.
"""
if self._python_class is None or self._python_class is dict:
return params
try:
if hasattr(self._python_class, "from_json") and callable(self._python_class.from_json):
return self._python_class.from_json(params)
import dataclasses as dc
snake_params = {self._camel_to_snake(k): v for k, v in params.items()}
if dc.is_dataclass(self._python_class):
valid_fields = {f.name for f in dc.fields(self._python_class)}
filtered = {k: v for k, v in snake_params.items() if k in valid_fields}
return self._python_class(**filtered)
return self._python_class(**snake_params)
except Exception:
return params
@staticmethod
def _camel_to_snake(name: str) -> str:
result = [name[0].lower()]
for char in name[1:]:
if char.isupper():
result.extend(["_", char.lower()])
else:
result.append(char)
return "".join(result)
class _EventManager:
    """Manages event subscriptions and callbacks.

    At most one browser-side subscription is kept per BiDi event. It is created
    when the first callback for that event is added, and dropped once the last
    tracked callback is removed.
    """

    def __init__(self, conn, event_configs: dict[str, EventConfig]):
        """Build the manager.

        Args:
            conn: Connection providing ``add_callback`` / ``remove_callback``;
                also passed to ``Session`` for subscribe/unsubscribe calls.
            event_configs: Mapping of user-facing event names to their config.
        """
        self.conn = conn
        self.event_configs = event_configs
        # bidi_event -> {"callbacks": [callback ids], "subscription_id": id or None}
        self.subscriptions: dict = {}
        self._event_wrappers = {}
        self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()}
        self._available_events = ", ".join(sorted(event_configs.keys()))
        # Reentrant: add/remove handlers hold the lock across two public helpers
        # that also acquire it, making each subscribe+track / untrack+unsubscribe
        # pair atomic.
        self._subscription_lock = threading.RLock()
        for config in event_configs.values():
            self._event_wrappers[config.bidi_event] = _EventWrapper(config.bidi_event, config.event_class)

    def validate_event(self, event: str) -> EventConfig:
        """Return the config for ``event``.

        Raises:
            ValueError: If ``event`` is not a known event name.
        """
        event_config = self.event_configs.get(event)
        if event_config is None:
            raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}")
        return event_config

    def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None:
        """Subscribe to a BiDi event if not already subscribed.

        ``contexts`` only takes effect on the first subscription for an event;
        later calls for the same event are no-ops.
        """
        with self._subscription_lock:
            if bidi_event not in self.subscriptions:
                session = Session(self.conn)
                result = session.subscribe([bidi_event], contexts=contexts)
                sub_id = result.get("subscription") if isinstance(result, dict) else None
                self.subscriptions[bidi_event] = {
                    "callbacks": [],
                    "subscription_id": sub_id,
                }

    def unsubscribe_from_event(self, bidi_event: str) -> None:
        """Unsubscribe from a BiDi event if no more callbacks exist.

        Uses the stored subscription id when there is one, otherwise
        unsubscribes by event name.
        """
        with self._subscription_lock:
            entry = self.subscriptions.get(bidi_event)
            if entry is not None and not entry["callbacks"]:
                session = Session(self.conn)
                sub_id = entry.get("subscription_id")
                if sub_id:
                    session.unsubscribe(subscriptions=[sub_id])
                else:
                    session.unsubscribe(events=[bidi_event])
                del self.subscriptions[bidi_event]

    def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None:
        """Record ``callback_id`` under an existing subscription (KeyError if none)."""
        with self._subscription_lock:
            self.subscriptions[bidi_event]["callbacks"].append(callback_id)

    def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None:
        """Forget ``callback_id`` for ``bidi_event``; unknown ids are ignored."""
        with self._subscription_lock:
            entry = self.subscriptions.get(bidi_event)
            if entry and callback_id in entry["callbacks"]:
                entry["callbacks"].remove(callback_id)

    def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int:
        """Register ``callback`` for ``event``, subscribing on first use.

        Returns:
            The callback id returned by ``conn.add_callback``.

        Raises:
            ValueError: If ``event`` is unknown.
            Any exception raised while subscribing. In that case the callback
            is removed from the connection again before re-raising.
        """
        event_config = self.validate_event(event)
        bidi_event = event_config.bidi_event
        event_wrapper = self._event_wrappers.get(bidi_event)
        # Registered before subscribing (presumably so events sent immediately
        # after the subscription is created are not missed).
        callback_id = self.conn.add_callback(event_wrapper, callback)
        try:
            # Atomic: a concurrent remove_event_handler must not be able to drop
            # the subscription entry between subscribing and tracking.
            with self._subscription_lock:
                self.subscribe_to_event(bidi_event, contexts)
                self.add_callback_to_tracking(bidi_event, callback_id)
        except Exception:
            # Don't leave an untracked callback registered on the connection.
            self.conn.remove_callback(event_wrapper, callback_id)
            raise
        return callback_id

    def remove_event_handler(self, event: str, callback_id: int) -> None:
        """Remove a callback and unsubscribe if it was the last one for ``event``.

        Raises:
            ValueError: If ``event`` is unknown.
        """
        event_config = self.validate_event(event)
        bidi_event = event_config.bidi_event
        event_wrapper = self._event_wrappers.get(bidi_event)
        self.conn.remove_callback(event_wrapper, callback_id)
        # Atomic: a concurrent add_event_handler must not be able to register
        # against an entry that is being torn down.
        with self._subscription_lock:
            self.remove_callback_from_tracking(bidi_event, callback_id)
            self.unsubscribe_from_event(bidi_event)

    def clear_event_handlers(self) -> None:
        """Clear all event handlers and their subscriptions.

        Each entry is dropped as soon as it has been torn down. If an
        unsubscribe fails, only the events not yet processed remain tracked.
        """
        with self._subscription_lock:
            if not self.subscriptions:
                return
            session = Session(self.conn)
            for bidi_event, entry in list(self.subscriptions.items()):
                event_wrapper = self._event_wrappers.get(bidi_event)
                # Entries created here are dicts; a bare list of callback ids is
                # also tolerated.
                callbacks = entry["callbacks"] if isinstance(entry, dict) else entry
                if event_wrapper:
                    for callback_id in callbacks:
                        self.conn.remove_callback(event_wrapper, callback_id)
                sub_id = entry.get("subscription_id") if isinstance(entry, dict) else None
                if sub_id:
                    session.unsubscribe(subscriptions=[sub_id])
                else:
                    session.unsubscribe(events=[bidi_event])
                del self.subscriptions[bidi_event]