Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
SeleniumHQ
GitHub Repository: SeleniumHQ/Selenium
Path: blob/trunk/py/selenium/webdriver/common/bidi/network.py
4012 views
1
# Licensed to the Software Freedom Conservancy (SFC) under one
2
# or more contributor license agreements. See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership. The SFC licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License. You may obtain a copy of the License at
8
#
9
# http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied. See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17
18
from __future__ import annotations
19
20
from collections.abc import Callable
21
from typing import Any
22
23
from selenium.webdriver.common.bidi.common import command_builder
24
from selenium.webdriver.remote.websocket_connection import WebSocketConnection
25
26
27
class NetworkEvent:
    """Represents a network event.

    Instances carry two attributes: ``event_class`` (the event name) and
    ``params`` (a dict of every other keyword argument given to the
    constructor).
    """

    def __init__(self, event_class: str, **kwargs: Any) -> None:
        """Store the event name and any additional event parameters.

        Args:
            event_class: The name of the event.
            **kwargs: Arbitrary event parameters, kept as the ``params`` dict.
        """
        self.event_class = event_class
        self.params = kwargs

    @classmethod
    def from_json(cls, json: dict[str, Any]) -> NetworkEvent:
        """Build an event from a decoded JSON mapping.

        Args:
            json: Mapping of event parameters. An optional ``"event_class"``
                entry is used as the event name; when absent the name
                defaults to an empty string.

        Returns:
            A new event whose ``params`` holds every entry of ``json``
            except ``"event_class"``.
        """
        # Work on a copy so the caller's dict is left untouched, and pull
        # "event_class" out of it: passing it both explicitly and via **json
        # would raise "got multiple values for keyword argument".
        params = dict(json)
        event_class = params.pop("event_class", "")
        return cls(event_class=event_class, **params)
37
38
39
class Network:
    """Manage network intercepts and request handlers over a connection.

    Commands are sent through ``conn.execute`` and event callbacks are
    registered through ``conn.add_callback`` / ``conn.remove_callback``.

    Bookkeeping attributes:
        intercepts: Ids of the intercepts added through this object.
        callbacks: Mixed mapping. String keys (event names) map to the list
            of callback ids registered for that event; integer keys
            (callback ids) map to the intercept id created for that handler.
        subscriptions: Event name -> callback ids registered through
            ``add_request_handler``.
    """

    # Friendly event key -> protocol event/command name.
    EVENTS = {
        "before_request": "network.beforeRequestSent",
        "response_started": "network.responseStarted",
        "response_completed": "network.responseCompleted",
        "auth_required": "network.authRequired",
        "fetch_error": "network.fetchError",
        "continue_request": "network.continueRequest",
        "continue_auth": "network.continueWithAuth",
    }

    # Friendly event key -> intercept phase. Only keys present here can be
    # used with add_request_handler, which needs both an event and a phase.
    PHASES = {
        "before_request": "beforeRequestSent",
        "response_started": "responseStarted",
        "auth_required": "authRequired",
    }

    def __init__(self, conn: WebSocketConnection) -> None:
        """Bind to a connection and start with empty bookkeeping.

        Args:
            conn: The connection used to send commands and register callbacks.
        """
        self.conn = conn
        self.intercepts: list[str] = []
        self.callbacks: dict[str | int, Any] = {}
        self.subscriptions: dict[str, list[int]] = {}

    def _add_intercept(
        self,
        phases: list[str] | None = None,
        contexts: list[str] | None = None,
        url_patterns: list[Any] | None = None,
    ) -> dict[str, Any]:
        """Add an intercept to the network.

        Args:
            phases: A list of phases to intercept. When None or empty,
                ``["beforeRequestSent"]`` is used.
            contexts: A list of contexts to intercept. Omitted from the
                command when None.
            url_patterns: A list of URL patterns to intercept. Omitted from
                the command when None.

        Returns:
            The command result dict; its ``"intercept"`` entry is the id of
            the new intercept, which is also recorded in ``self.intercepts``.
        """
        params: dict[str, Any] = {}
        if contexts is not None:
            params["contexts"] = contexts
        if url_patterns is not None:
            params["urlPatterns"] = url_patterns
        # Fall back to the default phase when none was requested.
        params["phases"] = phases if phases else ["beforeRequestSent"]
        cmd = command_builder("network.addIntercept", params)

        result: dict[str, Any] = self.conn.execute(cmd)
        self.intercepts.append(result["intercept"])
        return result

    def _remove_intercept(self, intercept: str | None = None) -> None:
        """Remove a specific intercept, or all intercepts.

        Args:
            intercept: The intercept to remove. Default is None.

        Raises:
            Exception: If removing a specific intercept fails, either because
                the command raised or because the id is not in
                ``self.intercepts``. The original error is chained as the
                cause.

        Note:
            If intercept is None, all intercepts will be removed.
        """
        if intercept is None:
            # Iterate over a copy because the list is mutated inside the loop.
            for intercept_id in self.intercepts.copy():
                self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept_id}))
                self.intercepts.remove(intercept_id)
        else:
            try:
                self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept}))
                self.intercepts.remove(intercept)
            except Exception as e:
                # Chain the cause so the original traceback is not lost.
                raise Exception(f"Exception: {e}") from e

    def _on_request(self, event_name: str, callback: Callable[[Request], Any]) -> int:
        """Set a callback function to subscribe to a network event.

        Args:
            event_name: The event to subscribe to.
            callback: The callback function to execute on event.
                Takes Request object as argument.

        Returns:
            The callback id returned by ``conn.add_callback``.
        """
        event = NetworkEvent(event_name)

        def _callback(event_data: NetworkEvent) -> None:
            # Every Request field is read from the event's "request" entry.
            request_data = event_data.params["request"]
            request = Request(
                network=self,
                request_id=request_data.get("request"),
                body_size=request_data.get("bodySize"),
                cookies=request_data.get("cookies"),
                resource_type=request_data.get("goog:resourceType"),
                headers=request_data.get("headers"),
                headers_size=request_data.get("headersSize"),
                # Previously omitted, which left Request.method always None.
                method=request_data.get("method"),
                timings=request_data.get("timings"),
                url=request_data.get("url"),
            )
            callback(request)

        callback_id: int = self.conn.add_callback(event, _callback)
        self.callbacks.setdefault(event_name, []).append(callback_id)
        return callback_id

    def add_request_handler(
        self,
        event: str,
        callback: Callable[[Request], Any],
        url_patterns: list[Any] | None = None,
        contexts: list[str] | None = None,
    ) -> int:
        """Add a request handler to the network.

        Args:
            event: The event to subscribe to. Must be a key of both
                ``EVENTS`` and ``PHASES``.
            callback: The callback function to execute on request interception.
                Takes Request object as argument.
            url_patterns: A list of URL patterns to intercept. Default is None.
            contexts: A list of contexts to intercept. Default is None.

        Returns:
            int: callback id

        Raises:
            Exception: If ``event`` is missing from ``EVENTS`` or ``PHASES``.
        """
        try:
            event_name = self.EVENTS[event]
            phase_name = self.PHASES[event]
        except KeyError:
            raise Exception(f"Event {event} not found")

        result = self._add_intercept(phases=[phase_name], url_patterns=url_patterns, contexts=contexts)
        callback_id = self._on_request(event_name, callback)

        if event_name in self.subscriptions:
            self.subscriptions[event_name].append(callback_id)
        else:
            # First handler for this event: subscribe the session to it.
            params: dict[str, Any] = {"events": [event_name]}
            self.conn.execute(command_builder("session.subscribe", params))
            self.subscriptions[event_name] = [callback_id]

        # Remember which intercept belongs to this handler so that removing
        # the handler can remove the intercept as well.
        self.callbacks[callback_id] = result["intercept"]
        return callback_id

    def remove_request_handler(self, event: str, callback_id: int) -> None:
        """Remove a request handler from the network.

        Args:
            event: The event to unsubscribe from.
            callback_id: The callback id to remove.

        Raises:
            Exception: If ``event`` is not a key of ``EVENTS``.
        """
        try:
            event_name = self.EVENTS[event]
        except KeyError:
            raise Exception(f"Event {event} not found")

        net_event = NetworkEvent(event_name)

        self.conn.remove_callback(net_event, callback_id)
        self._remove_intercept(self.callbacks[callback_id])
        del self.callbacks[callback_id]
        self._forget_event_callback(event_name, callback_id)
        self.subscriptions[event_name].remove(callback_id)
        if not self.subscriptions[event_name]:
            # Last handler for this event is gone: unsubscribe the session.
            params: dict[str, Any] = {"events": [event_name]}
            self.conn.execute(command_builder("session.unsubscribe", params))
            del self.subscriptions[event_name]

    def _forget_event_callback(self, event_name: str, callback_id: int) -> None:
        """Drop ``callback_id`` from the per-event id list in ``self.callbacks``."""
        event_callbacks = self.callbacks.get(event_name)
        if event_callbacks is None or callback_id not in event_callbacks:
            return
        event_callbacks.remove(callback_id)
        if not event_callbacks:
            del self.callbacks[event_name]

    def clear_request_handlers(self) -> None:
        """Clear all request handlers from the network."""
        for event_name, callback_ids in self.subscriptions.items():
            net_event = NetworkEvent(event_name)
            for callback_id in callback_ids:
                self.conn.remove_callback(net_event, callback_id)
                self._remove_intercept(self.callbacks[callback_id])
                del self.callbacks[callback_id]
                self._forget_event_callback(event_name, callback_id)
            params: dict[str, Any] = {"events": [event_name]}
            self.conn.execute(command_builder("session.unsubscribe", params))
        self.subscriptions = {}

    def add_auth_handler(self, username: str, password: str) -> int:
        """Add an authentication handler to the network.

        Args:
            username: The username to authenticate with.
            password: The password to authenticate with.

        Returns:
            int: callback id
        """
        event = "auth_required"

        def _callback(request: Request) -> None:
            request._continue_with_auth(username, password)

        return self.add_request_handler(event, _callback)

    def remove_auth_handler(self, callback_id: int) -> None:
        """Remove an authentication handler from the network.

        Args:
            callback_id: The callback id to remove.
        """
        event = "auth_required"
        self.remove_request_handler(event, callback_id)
256
257
258
class Request:
    """Represents an intercepted network request.

    The constructor stores its arguments unchanged as same-named attributes.
    Commands are sent through ``network.conn.execute``.
    """

    def __init__(
        self,
        network: Network,
        request_id: Any,
        body_size: int | None = None,
        cookies: Any = None,
        resource_type: str | None = None,
        headers: Any = None,
        headers_size: int | None = None,
        method: str | None = None,
        timings: Any = None,
        url: str | None = None,
    ) -> None:
        """Store the request data.

        Args:
            network: The object whose ``conn`` is used to send commands.
            request_id: Identifier sent as the ``"request"`` command parameter.
            body_size: Stored as ``self.body_size``.
            cookies: Stored as ``self.cookies``.
            resource_type: Stored as ``self.resource_type``.
            headers: Stored as ``self.headers``.
            headers_size: Stored as ``self.headers_size``.
            method: Stored as ``self.method``.
            timings: Stored as ``self.timings``.
            url: Stored as ``self.url``.
        """
        self.network = network
        self.request_id = request_id
        self.body_size = body_size
        self.cookies = cookies
        self.resource_type = resource_type
        self.headers = headers
        self.headers_size = headers_size
        self.method = method
        self.timings = timings
        self.url = url

    def fail_request(self) -> None:
        """Fail this request.

        Raises:
            ValueError: If this request has no request id.
        """
        if not self.request_id:
            raise ValueError("Request not found.")

        params: dict[str, Any] = {"request": self.request_id}
        self.network.conn.execute(command_builder("network.failRequest", params))

    def continue_request(
        self,
        body: Any = None,
        method: str | None = None,
        headers: Any = None,
        cookies: Any = None,
        url: str | None = None,
    ) -> None:
        """Continue after intercepting this request.

        Each argument is included in the command only when it is not None.

        Args:
            body: Value for the ``"body"`` parameter.
            method: Value for the ``"method"`` parameter.
            headers: Value for the ``"headers"`` parameter.
            cookies: Value for the ``"cookies"`` parameter.
            url: Value for the ``"url"`` parameter.

        Raises:
            ValueError: If this request has no request id.
        """
        if not self.request_id:
            raise ValueError("Request not found.")

        params: dict[str, Any] = {"request": self.request_id}
        overrides = {
            "body": body,
            "method": method,
            "headers": headers,
            "cookies": cookies,
            "url": url,
        }
        # Only forward the fields the caller actually supplied.
        params.update({key: value for key, value in overrides.items() if value is not None})

        self.network.conn.execute(command_builder("network.continueRequest", params))

    def _continue_with_auth(self, username: str | None = None, password: str | None = None) -> None:
        """Continue with authentication.

        Args:
            username: The username to authenticate with.
            password: The password to authenticate with.

        Raises:
            ValueError: If this request has no request id.

        Note:
            If username or password is None or empty, the ``"default"``
            action is sent and no credentials are included.
        """
        # Same guard as fail_request / continue_request: without a request id
        # there is nothing to continue.
        if not self.request_id:
            raise ValueError("Request not found.")

        params: dict[str, Any] = {"request": self.request_id}

        if username and password:
            params["action"] = "provideCredentials"
            params["credentials"] = {"type": "password", "username": username, "password": password}
        else:  # no credentials is valid option
            params["action"] = "default"

        self.network.conn.execute(command_builder("network.continueWithAuth", params))
339
340