# Path: blob/trunk/py/selenium/webdriver/remote/websocket_connection.py
# 1864 views
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from ssl import CERT_NONE
from threading import Thread
from time import sleep

from websocket import WebSocketApp  # type: ignore

from selenium.common import WebDriverException

logger = logging.getLogger(__name__)


class WebSocketConnection:
    """Blocking request/response wrapper around a JSON-over-WebSocket endpoint.

    Each outgoing command is tagged with an increasing integer ``id``; incoming
    messages carrying an ``id`` are stored until the matching ``execute`` call
    picks them up. Incoming messages carrying a ``method`` are treated as events
    and dispatched to the callbacks registered for that method name.
    """

    # Upper bound, in seconds, that _wait_until polls before giving up.
    _response_wait_timeout = 30
    # Seconds slept between polls in _wait_until.
    _response_wait_interval = 0.1

    # Debug log lines are truncated to this many characters.
    _max_log_message_size = 9999

    def __init__(self, url):
        """Open a WebSocket to ``url`` on a background thread.

        Waits up to ``_response_wait_timeout`` seconds for the socket's open
        callback to fire. If it never fires, the constructor still returns.

        :param url: WebSocket URL; ``wss://`` URLs skip certificate verification.
        """
        # Maps event name -> list of wrapper callables (see add_callback).
        self.callbacks = {}
        # When set, attached to every outgoing payload as "sessionId".
        self.session_id = None
        self.url = url

        self._id = 0
        # Responses received from the reader thread, keyed by message id.
        self._messages = {}
        self._started = False

        self._start_ws()
        self._wait_until(lambda: self._started)

    def close(self):
        """Close the socket and wait for the reader thread to finish.

        Calling ``close`` again after the connection has been closed is a no-op.
        """
        if self._ws is None:
            return
        # Close first: the reader thread sits in run_forever() until the socket
        # closes, so joining before closing would always wait out the timeout.
        self._ws.close()
        self._ws_thread.join(timeout=self._response_wait_timeout)
        self._started = False
        self._ws = None

    def execute(self, command):
        """Send ``command`` and block until its response arrives.

        :param command: A generator that yields the request payload dict once and,
            when sent the response's ``result``, returns the deserialized value.
        :returns: The value returned by ``command``'s generator.
        :raises WebDriverException: If the response contains an ``error`` or no
            response arrives within ``_response_wait_timeout`` seconds.
        """
        self._id += 1
        # Capture the id once so the payload and the wait always agree.
        current_id = self._id
        payload = self._serialize_command(command)
        payload["id"] = current_id
        if self.session_id:
            payload["sessionId"] = self.session_id

        data = json.dumps(payload)
        logger.debug(f"-> {data}"[: self._max_log_message_size])
        self._ws.send(data)

        if not self._wait_until(lambda: current_id in self._messages):
            raise WebDriverException(
                f"No response received for command id {current_id} "
                f"within {self._response_wait_timeout} seconds"
            )
        response = self._messages.pop(current_id)

        if "error" in response:
            error = response["error"]
            if "message" in response:
                raise WebDriverException(f"{error}: {response['message']}")
            raise WebDriverException(error)
        return self._deserialize_result(response["result"], command)

    def add_callback(self, event, callback):
        """Register ``callback`` for events of type ``event``.

        :param event: Event type exposing ``event_class`` (the method name to
            match) and ``from_json`` (used to build the object passed to callback).
        :param callback: Callable invoked with the parsed event.
        :returns: An id that can be passed to ``remove_callback``.
        """
        event_name = event.event_class
        if event_name not in self.callbacks:
            self.callbacks[event_name] = []

        def _callback(params):
            callback(event.from_json(params))

        self.callbacks[event_name].append(_callback)
        return id(_callback)

    on = add_callback

    def remove_callback(self, event, callback_id):
        """Unregister the callback whose id was returned by ``add_callback``.

        Unknown events or ids are ignored.
        """
        event_name = event.event_class
        if event_name in self.callbacks:
            for callback in self.callbacks[event_name]:
                if id(callback) == callback_id:
                    # Returning right after the removal keeps the in-loop mutation safe.
                    self.callbacks[event_name].remove(callback)
                    return

    def _serialize_command(self, command):
        """Advance the command generator to obtain its request payload."""
        return next(command)

    def _deserialize_result(self, result, command):
        """Send ``result`` into the command generator and return its final value.

        :raises WebDriverException: If the generator yields again instead of returning.
        """
        try:
            _ = command.send(result)
            raise WebDriverException("The command's generator function did not exit when expected!")
        except StopIteration as stop:
            return stop.value

    def _start_ws(self):
        """Create the WebSocketApp and run its event loop on a background thread."""

        def on_open(ws):
            self._started = True

        def on_message(ws, message):
            self._process_message(message)

        def on_error(ws, error):
            logger.debug(f"error: {error}")
            ws.close()

        def run_socket():
            if self.url.startswith("wss://"):
                # Certificate verification is deliberately disabled for wss:// URLs.
                self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True)
            else:
                self._ws.run_forever(suppress_origin=True)

        self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error)
        self._ws_thread = Thread(target=run_socket)
        self._ws_thread.start()

    def _process_message(self, message):
        """Parse an incoming frame, store responses and dispatch events.

        Messages with an ``id`` are stored for ``execute``; messages with a
        ``method`` fan out to each registered callback on its own thread.
        """
        message = json.loads(message)
        logger.debug(f"<- {message}"[: self._max_log_message_size])

        if "id" in message:
            self._messages[message["id"]] = message

        if "method" in message:
            params = message["params"]
            for callback in self.callbacks.get(message["method"], []):
                Thread(target=callback, args=(params,)).start()

    def _wait_until(self, condition):
        """Poll ``condition`` until it returns a truthy value.

        :returns: The truthy value, or ``None`` once roughly
            ``_response_wait_timeout`` seconds of polling have elapsed.
        """
        timeout = self._response_wait_timeout
        interval = self._response_wait_interval

        while timeout > 0:
            result = condition()
            if result:
                return result
            timeout -= interval
            sleep(interval)
        return None