Path: blob/trunk/py/selenium/webdriver/remote/websocket_connection.py
3997 views
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import json
import logging
from ssl import CERT_NONE
from threading import Thread
from time import monotonic, sleep

from websocket import WebSocketApp

from selenium.common import WebDriverException

logger = logging.getLogger(__name__)


class WebSocketConnection:
    """JSON-over-WebSocket transport with request/response matching and event callbacks.

    Commands are sent as JSON objects carrying an incrementing ``id``; incoming
    messages that carry an ``id`` are stored as responses, and incoming messages
    that carry a ``method`` are dispatched to the callbacks registered for it.
    The socket itself runs on a daemon thread started by the constructor.
    """

    # Upper bound on the number of characters of a message written to the debug log.
    _max_log_message_size = 9999

    def __init__(self, url, timeout, interval):
        """Open the connection and wait (up to ``timeout``) for it to report open.

        Args:
            url: WebSocket endpoint; a ``wss://`` prefix selects the TLS code path.
            timeout: Maximum number of seconds to wait for the socket to open and
                for each command response. Must be a non-negative int or float.
            interval: Seconds to sleep between polls while waiting. Must be a
                non-negative int or float.

        Raises:
            WebDriverException: If ``timeout`` or ``interval`` is not a
                non-negative number.
        """
        if not isinstance(timeout, (int, float)) or timeout < 0:
            raise WebDriverException("timeout must be a positive number")
        # The range check must look at ``interval`` itself; testing ``timeout``
        # here let negative intervals through.
        if not isinstance(interval, (int, float)) or interval < 0:
            raise WebDriverException("interval must be a positive number")

        self.url = url
        self.response_wait_timeout = timeout
        self.response_wait_interval = interval

        self.callbacks = {}  # event name -> list of wrapped callbacks
        self.session_id = None
        self._id = 0  # id of the most recently sent command
        self._messages = {}  # command id -> decoded response awaiting pickup
        self._started = False  # flipped to True by the socket's on_open handler

        self._start_ws()
        # Returns without raising if the socket has not opened within the timeout.
        self._wait_until(lambda: self._started)

    def close(self):
        """Close the socket and wait for the socket thread to finish.

        The socket is closed before joining: the thread is blocked in
        ``run_forever()``, so joining first would normally just wait out the
        whole timeout before the close could happen.
        """
        self._ws.close()
        self._ws_thread.join(timeout=self.response_wait_timeout)
        self._started = False
        self._ws = None

    def execute(self, command):
        """Send a command and return its deserialized result.

        Args:
            command: A generator whose first yielded value is the payload
                mapping to send, and which returns the final value after being
                sent the raw ``result`` of the response.

        Returns:
            The value the ``command`` generator returns.

        Raises:
            WebDriverException: If the response contains an ``error`` key, or if
                the generator does not finish after receiving the result.
            KeyError: If no response with the command's id arrived within the
                configured timeout.
        """
        self._id += 1
        payload = self._serialize_command(command)
        payload["id"] = self._id
        if self.session_id:
            payload["sessionId"] = self.session_id

        data = json.dumps(payload)
        logger.debug(f"-> {data}"[: self._max_log_message_size])
        self._ws.send(data)

        # Capture the id locally so the wait is not affected by later increments.
        current_id = self._id
        self._wait_until(lambda: current_id in self._messages)
        response = self._messages.pop(current_id)

        if "error" in response:
            error = response["error"]
            if "message" in response:
                raise WebDriverException(f"{error}: {response['message']}")
            raise WebDriverException(error)

        return self._deserialize_result(response["result"], command)

    def add_callback(self, event, callback):
        """Register ``callback`` for ``event`` and return an id for later removal.

        Args:
            event: Object exposing ``event_class`` (the event name used as the
                registry key) and ``from_json`` (used to convert the raw params).
            callback: Callable invoked with ``event.from_json(params)``.

        Returns:
            The ``id()`` of the internal wrapper, to be passed to
            ``remove_callback``.
        """

        def _callback(params):
            callback(event.from_json(params))

        self.callbacks.setdefault(event.event_class, []).append(_callback)
        return id(_callback)

    on = add_callback

    def remove_callback(self, event, callback_id):
        """Remove the callback registered for ``event`` under ``callback_id``.

        Does nothing if the event has no callbacks or the id is not found.
        """
        event_name = event.event_class
        if event_name in self.callbacks:
            for callback in self.callbacks[event_name]:
                if id(callback) == callback_id:
                    # Returning right after the removal keeps the mutation
                    # from disturbing the iteration.
                    self.callbacks[event_name].remove(callback)
                    return

    def _serialize_command(self, command):
        """Advance the command generator to obtain the payload to send."""
        return next(command)

    def _deserialize_result(self, result, command):
        """Feed ``result`` to the command generator and return its return value.

        Raises:
            WebDriverException: If the generator yields again instead of
                returning.
        """
        try:
            command.send(result)
        except StopIteration as stop:
            return stop.value
        raise WebDriverException("The command's generator function did not exit when expected!")

    def _start_ws(self):
        """Create the ``WebSocketApp`` and run it on a daemon thread."""

        def on_open(ws):
            self._started = True

        def on_message(ws, message):
            self._process_message(message)

        def on_error(ws, error):
            logger.debug(f"error: {error}")
            ws.close()

        def run_socket():
            if self.url.startswith("wss://"):
                # NOTE(review): certificate verification is disabled for TLS
                # connections (CERT_NONE) - confirm this is intended.
                self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True)
            else:
                self._ws.run_forever(suppress_origin=True)

        self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error)
        self._ws_thread = Thread(target=run_socket, daemon=True)
        self._ws_thread.start()

    def _process_message(self, message):
        """Decode an incoming JSON message and route it.

        Messages with an ``id`` are stored for ``execute`` to pick up; messages
        with a ``method`` are handed to every callback registered under that
        name, each on its own daemon thread.
        """
        message = json.loads(message)
        logger.debug(f"<- {message}"[: self._max_log_message_size])

        if "id" in message:
            self._messages[message["id"]] = message

        if "method" in message:
            params = message["params"]
            for callback in self.callbacks.get(message["method"], []):
                Thread(target=callback, args=(params,), daemon=True).start()

    def _wait_until(self, condition):
        """Poll ``condition`` until it returns a truthy value or the timeout expires.

        Args:
            condition: Zero-argument callable to poll.

        Returns:
            The first truthy value returned by ``condition``, or ``None`` if the
            timeout elapsed first (with a timeout of 0 the condition is never
            polled).
        """
        interval = self.response_wait_interval
        # A wall-clock deadline (rather than subtracting the interval from a
        # counter) guarantees termination even when the interval is 0 and keeps
        # time spent inside condition() from extending the wait.
        deadline = monotonic() + self.response_wait_timeout

        while monotonic() < deadline:
            result = condition()
            if result:
                return result
            sleep(interval)
        return None