# Book a Demo!
# CoCalc Logo Icon
# Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
# SeleniumHQ
# GitHub Repository: SeleniumHQ/Selenium
# Path: blob/trunk/py/selenium/webdriver/remote/websocket_connection.py
# 1864 views
1
# Licensed to the Software Freedom Conservancy (SFC) under one
2
# or more contributor license agreements. See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership. The SFC licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License. You may obtain a copy of the License at
8
#
9
# http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied. See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17
import json
18
import logging
19
from ssl import CERT_NONE
20
from threading import Thread
21
from time import sleep
22
23
from websocket import WebSocketApp # type: ignore
24
25
from selenium.common import WebDriverException
26
27
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
28
29
30
class WebSocketConnection:
    """Send JSON commands over a WebSocket and dispatch incoming events.

    A background thread runs the socket's receive loop. Incoming messages
    that carry an ``"id"`` are stored as responses for :meth:`execute`;
    messages that carry a ``"method"`` are passed to the callbacks registered
    for that method name.
    """

    # Seconds to wait for the socket to open, for a command response, and for
    # the reader thread to finish when closing.
    _response_wait_timeout = 30
    # Seconds to sleep between two polls of a wait condition.
    _response_wait_interval = 0.1

    # Maximum number of characters of a single line written to the debug log.
    _max_log_message_size = 9999

    def __init__(self, url):
        """Open a connection to ``url`` and wait for the socket to open.

        Args:
            url: WebSocket URL to connect to (``ws://`` or ``wss://``).
        """
        self.callbacks = {}
        self.session_id = None
        self.url = url

        self._id = 0
        self._messages = {}
        self._started = False

        self._start_ws()
        # NOTE(review): if the socket does not open within the timeout this
        # returns silently and the failure only surfaces on the first send.
        self._wait_until(lambda: self._started)

    def close(self):
        """Close the socket and wait for the reader thread to finish.

        Calling this on an already closed connection does nothing.
        """
        ws = self._ws
        if ws is None:
            return

        # Close first: the reader thread is blocked in ``run_forever`` and
        # only returns once the socket is closed, so joining before closing
        # could block for the whole timeout.
        ws.close()
        self._ws_thread.join(timeout=self._response_wait_timeout)
        self._started = False
        self._ws = None

    def execute(self, command):
        """Send a command and return its deserialized result.

        Args:
            command: Generator that yields the request payload (a dict) and
                returns the final value after being sent the raw result.

        Returns:
            The value returned by the ``command`` generator.

        Raises:
            WebDriverException: If the response contains an ``"error"`` key
                or no response arrives within the wait timeout.
        """
        self._id += 1
        # Capture the id once so that the id we wait for is the id we sent,
        # even if ``self._id`` is advanced elsewhere in the meantime.
        current_id = self._id
        payload = self._serialize_command(command)
        payload["id"] = current_id
        if self.session_id:
            payload["sessionId"] = self.session_id

        data = json.dumps(payload)
        # Skip building the (possibly very large) log line unless it is used.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"-> {data}"[: self._max_log_message_size])
        self._ws.send(data)

        if not self._wait_until(lambda: current_id in self._messages):
            raise WebDriverException(
                f"Timed out after {self._response_wait_timeout} seconds "
                f"waiting for a response to command id {current_id}"
            )
        response = self._messages.pop(current_id)

        if "error" in response:
            error = response["error"]
            if "message" in response:
                raise WebDriverException(f"{error}: {response['message']}")
            raise WebDriverException(error)
        return self._deserialize_result(response["result"], command)

    def add_callback(self, event, callback):
        """Register ``callback`` for ``event`` and return a handle for removal.

        Args:
            event: Object exposing ``event_class`` (the key callbacks are
                stored under) and ``from_json`` (converts raw params).
            callback: Callable invoked with ``event.from_json(params)``.

        Returns:
            An integer id to pass to :meth:`remove_callback`.
        """
        event_name = event.event_class

        def _callback(params):
            callback(event.from_json(params))

        self.callbacks.setdefault(event_name, []).append(_callback)
        return id(_callback)

    on = add_callback

    def remove_callback(self, event, callback_id):
        """Unregister the callback identified by ``callback_id``.

        Unknown events or ids are ignored.

        Args:
            event: Object exposing ``event_class``.
            callback_id: Value returned by :meth:`add_callback`.
        """
        registered = self.callbacks.get(event.event_class)
        if not registered:
            return
        for callback in registered:
            if id(callback) == callback_id:
                # Safe to mutate here: we stop iterating immediately after.
                registered.remove(callback)
                return

    def _serialize_command(self, command):
        """Advance the command generator to obtain the request payload."""
        return next(command)

    def _deserialize_result(self, result, command):
        """Feed ``result`` to the command generator and return its value.

        Raises:
            WebDriverException: If the generator yields again instead of
                returning.
        """
        try:
            command.send(result)
        except StopIteration as stop:
            return stop.value
        raise WebDriverException("The command's generator function did not exit when expected!")

    def _start_ws(self):
        """Create the WebSocket app and run its loop in a background thread."""

        def on_open(ws):
            self._started = True

        def on_message(ws, message):
            self._process_message(message)

        def on_error(ws, error):
            logger.debug(f"error: {error}")
            ws.close()

        def run_socket():
            if self.url.startswith("wss://"):
                # Certificate verification is disabled for secure sockets.
                self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True)
            else:
                self._ws.run_forever(suppress_origin=True)

        self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error)
        self._ws_thread = Thread(target=run_socket)
        self._ws_thread.start()

    def _process_message(self, message):
        """Decode an incoming JSON message and route it.

        Messages with an ``"id"`` are stored as command responses; messages
        with a ``"method"`` start one thread per registered callback.
        """
        message = json.loads(message)
        # Skip building the (possibly very large) log line unless it is used.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"<- {message}"[: self._max_log_message_size])

        if "id" in message:
            self._messages[message["id"]] = message

        if "method" in message:
            params = message["params"]
            for callback in self.callbacks.get(message["method"], []):
                Thread(target=callback, args=(params,)).start()

    def _wait_until(self, condition):
        """Poll ``condition`` until it returns a truthy value or time runs out.

        Args:
            condition: Zero-argument callable.

        Returns:
            The truthy value returned by ``condition``, or ``None`` if the
            wait timed out.
        """
        remaining = self._response_wait_timeout
        interval = self._response_wait_interval

        while remaining > 0:
            result = condition()
            if result:
                return result
            remaining -= interval
            sleep(interval)
        return None
155
156