Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
SeleniumHQ
GitHub Repository: SeleniumHQ/Selenium
Path: blob/trunk/py/selenium/webdriver/remote/websocket_connection.py
3997 views
1
# Licensed to the Software Freedom Conservancy (SFC) under one
2
# or more contributor license agreements. See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership. The SFC licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License. You may obtain a copy of the License at
8
#
9
# http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied. See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17
18
import json
19
import logging
20
from ssl import CERT_NONE
21
from threading import Thread
22
from time import sleep
23
24
from websocket import WebSocketApp
25
26
from selenium.common import WebDriverException
27
28
# Module-level logger, named after this module, used for debug output of
# WebSocket traffic and errors.
logger = logging.getLogger(name=__name__)
29
30
31
class WebSocketConnection:
    """JSON command/event channel over a WebSocket connection.

    A daemon thread runs the ``WebSocketApp`` event loop. :meth:`execute`
    tags each outgoing command with an incrementing ``id`` and blocks until a
    message with that ``id`` arrives. Incoming messages that carry a
    ``method`` key are treated as events and passed to the callbacks
    registered for that method name, each on its own daemon thread.
    """

    # Maximum number of characters of a frame written to the debug log.
    _max_log_message_size = 9999

    def __init__(self, url, timeout, interval):
        """Start the WebSocket thread and wait for the connection to open.

        :param url: WebSocket URL. ``wss://`` URLs are opened without
            certificate verification (see ``_start_ws``).
        :param timeout: seconds to wait for the connection to open and for
            each command response. Must be a non-negative int/float.
        :param interval: polling interval, in seconds, used while waiting.
            Must be a strictly positive int/float.
        :raises WebDriverException: if ``timeout`` or ``interval`` is invalid.

        Construction does not fail if the socket has not opened within
        ``timeout``: ``_wait_until`` simply gives up and returns.
        """
        if not isinstance(timeout, (int, float)) or timeout < 0:
            raise WebDriverException("timeout must be a positive number")
        # A zero interval would never use up the timeout budget in
        # ``_wait_until`` (busy-wait forever), and a negative one makes
        # ``sleep`` fail, so only strictly positive values are accepted.
        if not isinstance(interval, (int, float)) or interval <= 0:
            raise WebDriverException("interval must be a positive number")

        self.url = url
        self.response_wait_timeout = timeout
        self.response_wait_interval = interval

        # event method name -> list of wrapper callables (see add_callback)
        self.callbacks = {}
        # When set, added as "sessionId" to every outgoing command payload.
        self.session_id = None
        # Last command id issued by execute().
        self._id = 0
        # Responses received from the socket, keyed by their "id" field.
        self._messages = {}
        # Set to True by the on_open handler running on the socket thread.
        self._started = False

        self._start_ws()
        self._wait_until(lambda: self._started)

    def close(self):
        """Close the WebSocket and wait (up to ``timeout``) for its thread to exit.

        Calling ``close()`` again after a successful close is a no-op.
        """
        if self._ws is None:
            return
        # Close the socket first: ``run_forever`` on the reader thread only
        # returns once the socket is closed, so joining the thread beforehand
        # would always block for the full timeout.
        self._ws.close()
        self._ws_thread.join(timeout=self.response_wait_timeout)
        self._started = False
        self._ws = None

    def execute(self, command):
        """Send one command and block until its response arrives.

        :param command: a generator that first yields the command payload
            (a dict), then receives the response's ``result`` via ``send()``
            and returns the deserialized value.
        :returns: the value returned by ``command``'s generator.
        :raises WebDriverException: if the response contains an ``error``,
            or if no response arrives within the configured timeout.
        """
        self._id += 1
        # Capture the id once so the payload and the wait use the same value.
        current_id = self._id
        payload = self._serialize_command(command)
        payload["id"] = current_id
        if self.session_id:
            payload["sessionId"] = self.session_id

        data = json.dumps(payload)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"-> {data}"[: self._max_log_message_size])
        self._ws.send(data)

        self._wait_until(lambda: current_id in self._messages)
        response = self._messages.pop(current_id, None)
        if response is None:
            # _wait_until gave up without the response being stored.
            raise WebDriverException(f"timed out waiting for a response to command id {current_id}")

        if "error" in response:
            error = response["error"]
            if "message" in response:
                raise WebDriverException(f"{error}: {response['message']}")
            raise WebDriverException(error)
        return self._deserialize_result(response["result"], command)

    def add_callback(self, event, callback):
        """Register ``callback`` for events of type ``event``.

        :param event: object exposing ``event_class`` (the event method name)
            and ``from_json(params)`` (builds the value passed to ``callback``).
        :param callback: called with ``event.from_json(params)`` for each
            matching incoming event.
        :returns: an id usable with :meth:`remove_callback`.
        """
        event_name = event.event_class
        if event_name not in self.callbacks:
            self.callbacks[event_name] = []

        def _callback(params):
            callback(event.from_json(params))

        self.callbacks[event_name].append(_callback)
        return id(_callback)

    on = add_callback

    def remove_callback(self, event, callback_id):
        """Unregister the callback whose id was returned by :meth:`add_callback`.

        Does nothing if no matching callback is registered.
        """
        event_name = event.event_class
        if event_name in self.callbacks:
            for callback in self.callbacks[event_name]:
                if id(callback) == callback_id:
                    # Safe despite mutating during iteration: we return immediately.
                    self.callbacks[event_name].remove(callback)
                    return

    def _serialize_command(self, command):
        """Return the payload dict yielded first by the command generator."""
        return next(command)

    def _deserialize_result(self, result, command):
        """Send ``result`` into the command generator and return its return value.

        :raises WebDriverException: if the generator yields again instead of
            finishing.
        """
        try:
            _ = command.send(result)
            raise WebDriverException("The command's generator function did not exit when expected!")
        except StopIteration as stop:
            return stop.value

    def _start_ws(self):
        """Create the ``WebSocketApp`` and run its event loop on a daemon thread."""

        def on_open(ws):
            self._started = True

        def on_message(ws, message):
            self._process_message(message)

        def on_error(ws, error):
            logger.debug(f"error: {error}")
            ws.close()

        def run_socket():
            if self.url.startswith("wss://"):
                # NOTE(review): certificate verification is disabled for wss:// URLs.
                self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True)
            else:
                self._ws.run_forever(suppress_origin=True)

        self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error)
        self._ws_thread = Thread(target=run_socket, daemon=True)
        self._ws_thread.start()

    def _process_message(self, message):
        """Decode an incoming JSON frame and route it.

        Messages with an ``id`` are stored for the caller waiting in
        :meth:`execute`. Messages with a ``method`` are passed, via their
        ``params``, to every callback registered for that method, each on a
        new daemon thread.
        """
        message = json.loads(message)
        # Avoid building the repr of a possibly large message when DEBUG is off.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"<- {message}"[: self._max_log_message_size])

        if "id" in message:
            self._messages[message["id"]] = message

        if "method" in message:
            params = message["params"]
            for callback in self.callbacks.get(message["method"], []):
                Thread(target=callback, args=(params,), daemon=True).start()

    def _wait_until(self, condition):
        """Poll ``condition`` every ``interval`` seconds until it returns a truthy value.

        Elapsed time is estimated by summing the sleep intervals, so time spent
        inside ``condition()`` is not counted against the timeout.

        :returns: the first truthy value returned by ``condition``, or ``None``
            if the timeout budget runs out first.
        """
        timeout = self.response_wait_timeout
        interval = self.response_wait_interval

        while timeout > 0:
            result = condition()
            if result:
                return result
            timeout -= interval
            sleep(interval)
        return None
160
161