Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
SeleniumHQ
GitHub Repository: SeleniumHQ/Selenium
Path: blob/trunk/py/selenium/webdriver/common/api_request_context.py
10193 views
1
# Licensed to the Software Freedom Conservancy (SFC) under one
2
# or more contributor license agreements. See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership. The SFC licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License. You may obtain a copy of the License at
8
#
9
# http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied. See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17
18
"""APIRequestContext for making HTTP requests with browser cookie synchronization."""
19
20
import json
21
import logging
22
import pathlib
23
import time
24
import urllib.parse
25
from email.utils import parsedate_to_datetime
26
from http.client import responses as http_status_phrases
27
from typing import TYPE_CHECKING, Any
28
29
import urllib3
30
from urllib3.util.retry import Retry
31
32
if TYPE_CHECKING:
33
from selenium.webdriver.remote.webdriver import WebDriver
34
35
# Module-level logger named after this module; used below for cookie-sync diagnostics.
logger = logging.getLogger(__name__)
36
37
38
class APIRequestFailure(Exception):
    """Error raised for a non-2xx API response when fail_on_status_code is True.

    The exception message has the form ``"<status> <status_text>: <url>"``.

    Attributes:
        response: The APIResponse that triggered the failure.
    """

    def __init__(self, response: "APIResponse") -> None:
        # Keep the full response reachable so callers can inspect body/headers.
        self.response = response
        message = f"{response.status} {response.status_text}: {response.url}"
        super().__init__(message)
48
49
50
class APIResponse:
    """Holds the outcome of an HTTP API request.

    Attributes:
        status: HTTP status code.
        status_text: HTTP status text.
        headers: Response headers as a dict.
        url: The request URL.
    """

    def __init__(self, status: int, status_text: str, headers: dict[str, str], url: str, body: bytes) -> None:
        self.status, self.status_text = status, status_text
        self.headers, self.url = headers, url
        # Raw payload; exposed through json(), text() and body().
        self._body = body

    @property
    def ok(self) -> bool:
        """True when the status code lies in the 200-299 range."""
        return not (self.status < 200 or self.status > 299)

    def json(self) -> Any:
        """Deserialize the response body as JSON.

        Returns:
            The parsed JSON object.
        """
        payload = self._body
        return json.loads(payload)

    def text(self) -> str:
        """Return the response body decoded as UTF-8.

        Returns:
            The response body as a string.
        """
        payload = self._body
        return payload.decode("utf-8")

    def body(self) -> bytes:
        """Return the response body without any decoding.

        Returns:
            The response body as bytes.
        """
        return self._body

    def dispose(self) -> None:
        """Drop the stored body by replacing it with empty bytes."""
        self._body = b""
99
100
101
def _cookie_matches(cookie: dict, url: str, default_domain: str = "") -> bool:
    """Check if a browser cookie should be sent with a request to the given URL.

    Evaluates expiry, domain, path, and secure attribute matching per RFC 6265.

    Args:
        cookie: A cookie dict from driver.get_cookies().
        url: The target request URL.
        default_domain: Fallback domain for host-only cookies (no domain attribute).
            When a cookie has no domain, it only matches if the request hostname
            equals this value. If empty and cookie has no domain, the cookie is skipped.

    Returns:
        True if the cookie matches the URL.
    """
    # Expiry check — skip expired cookies
    expiry = cookie.get("expiry")
    if expiry is not None and expiry <= int(time.time()):
        return False

    parsed = urllib.parse.urlparse(url)
    hostname = parsed.hostname or ""  # urlparse already lower-cases the hostname
    path = parsed.path or "/"
    scheme = parsed.scheme or "http"

    # Domain matching (RFC 6265 section 5.1.3). Domains compare case-insensitively;
    # a missing or None domain is treated as a host-only cookie.
    cookie_domain = (cookie.get("domain") or "").lower()
    if not cookie_domain:
        # Host-only cookie — must match the origin host exactly
        if not default_domain or hostname != default_domain.lower():
            return False
    elif cookie_domain.startswith("."):
        # .example.com matches example.com and sub.example.com
        if not (hostname == cookie_domain[1:] or hostname.endswith(cookie_domain)):
            return False
    elif hostname != cookie_domain:
        return False

    # Path matching (RFC 6265 section 5.1.4): the cookie path must equal the
    # request path, or be a prefix of it that ends at a "/" boundary.
    cookie_path = cookie.get("path") or "/"
    if cookie_path != "/" and path != cookie_path:
        # A cookie path that already ends in "/" is itself the boundary;
        # appending another "/" would wrongly require "//" in the request path.
        prefix = cookie_path if cookie_path.endswith("/") else cookie_path + "/"
        if not path.startswith(prefix):
            return False

    # Secure matching
    if cookie.get("secure", False) and scheme != "https":
        return False

    return True
152
153
154
def _parse_set_cookie(header_value: str) -> dict:
    """Parse a single Set-Cookie header value into a cookie dict.

    Uses manual parsing instead of http.cookies.SimpleCookie which is too
    strict for real-world Set-Cookie headers.

    Attribute handling:
        * ``Domain`` with an empty value is ignored, so callers can apply
          their own default host.
        * ``Path`` that is empty or does not start with "/" is ignored
          (RFC 6265 §5.2.4 falls back to the default path in that case).
        * ``SameSite`` is matched case-insensitively and stored as "Strict",
          "Lax" or "None"; any other value is dropped.
        * ``Max-Age`` takes precedence over ``Expires`` regardless of order.
        * Unparseable ``Max-Age`` / ``Expires`` values are ignored.

    Args:
        header_value: The Set-Cookie header string.

    Returns:
        A dict with cookie attributes suitable for driver.add_cookie(), or an
        empty dict when the header has no ``name=value`` pair.
    """
    name_value, *attributes = header_value.split(";")
    name, separator, value = name_value.strip().partition("=")
    if not separator:
        return {}

    cookie: dict[str, Any] = {"name": name.strip(), "value": value.strip()}
    has_max_age = False

    for part in attributes:
        part = part.strip()
        if not part:
            continue
        # Flag attributes such as "Secure" carry no "=", leaving attr_value empty.
        attr_name, _, attr_value = part.partition("=")
        attr_name = attr_name.strip().lower()
        attr_value = attr_value.strip()

        if attr_name == "domain":
            # Storing "" would stop callers' setdefault() from applying the host.
            if attr_value:
                cookie["domain"] = attr_value
        elif attr_name == "path":
            if attr_value.startswith("/"):
                cookie["path"] = attr_value
            else:
                # Last attribute wins: an invalid Path resets to the default.
                cookie.pop("path", None)
        elif attr_name == "secure":
            cookie["secure"] = True
        elif attr_name == "httponly":
            cookie["httpOnly"] = True
        elif attr_name == "samesite":
            # WebDriver only accepts the exact spellings Strict / Lax / None.
            same_site = attr_value.capitalize()
            if same_site in ("Strict", "Lax", "None"):
                cookie["sameSite"] = same_site
            else:
                cookie.pop("sameSite", None)
        elif attr_name == "max-age":
            try:
                max_age = int(attr_value)
            except ValueError:
                continue
            cookie["expiry"] = int(time.time()) + max_age
            has_max_age = True
        elif attr_name == "expires" and not has_max_age:
            # RFC 6265 §5.3: Max-Age takes precedence over Expires
            try:
                cookie["expiry"] = int(parsedate_to_datetime(attr_value).timestamp())
            except (ValueError, TypeError):
                pass

    return cookie
215
216
217
def _get_set_cookie_headers(resp: urllib3.BaseHTTPResponse) -> list[str]:
    """Collect every Set-Cookie header value carried by a urllib3 response.

    Prefers the header container's ``getlist`` method when it exists, and
    falls back to a plain ``get`` lookup when it is missing or returns
    nothing.

    Args:
        resp: The urllib3 HTTP response.

    Returns:
        A list of Set-Cookie header strings.
    """
    response_headers = resp.headers

    if hasattr(response_headers, "getlist"):
        values = response_headers.getlist("Set-Cookie")
        if values:
            return values

    single = response_headers.get("Set-Cookie")
    if not single:
        return []
    return [single]
232
233
234
def _resolve_redirect_url(resp: urllib3.BaseHTTPResponse, original_url: str) -> str:
    """Work out the URL the response finally came from.

    The last entry of ``resp.retries.history`` is inspected: when it has
    both a URL and a redirect location, the location is resolved against
    that URL and returned. In every other case (no retry state, empty
    history, or an entry lacking either field) ``original_url`` is
    returned unchanged.

    Args:
        resp: The urllib3 HTTP response.
        original_url: The URL the request was originally sent to.

    Returns:
        The resolved final URL, or ``original_url``.
    """
    retry_state = resp.retries
    hops = retry_state.history if retry_state else ()
    if not hops:
        return original_url

    final_hop = hops[-1]
    if not (final_hop.url and final_hop.redirect_location):
        return original_url

    # redirect_location may be relative, so resolve it against the hop's URL.
    return urllib.parse.urljoin(final_hop.url, final_hop.redirect_location)
248
249
250
class _BaseRequestContext:
    """Base class with shared HTTP request logic for API request contexts.

    Subclasses provide cookie behaviour by overriding
    ``_get_cookies_for_request`` and ``_handle_response_cookies``; the base
    implementations send no cookies and ignore Set-Cookie headers.

    Header names are treated case-insensitively when merging per-request
    headers, applying a default Content-Type, and appending cookies, so a
    caller-supplied ``content-type`` or ``cookie`` header is never duplicated.

    Args:
        base_url: Prefix applied to URLs that are not absolute http(s) URLs.
        extra_headers: Optional headers included in every request.
        timeout: Default request timeout in seconds.
        max_redirects: Default maximum number of redirects to follow;
            0 (or less) disables redirect following.
        fail_on_status_code: If True, raise APIRequestFailure for non-2xx responses.
    """

    def __init__(
        self,
        base_url: str = "",
        extra_headers: dict[str, str] | None = None,
        timeout: float = 30.0,
        max_redirects: int = 10,
        fail_on_status_code: bool = False,
    ) -> None:
        self._base_url = base_url
        self._extra_headers = extra_headers or {}
        self._timeout = timeout
        self._max_redirects = max_redirects
        self._fail_on_status_code = fail_on_status_code
        self._pool = urllib3.PoolManager()

    def get(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a GET request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "GET", **kwargs)

    def post(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a POST request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, data, form,
                json_data, timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "POST", **kwargs)

    def put(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a PUT request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, data, form,
                json_data, timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "PUT", **kwargs)

    def patch(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a PATCH request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, data, form,
                json_data, timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "PATCH", **kwargs)

    def delete(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a DELETE request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, data, form,
                json_data, timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "DELETE", **kwargs)

    def head(self, url: str, **kwargs: Any) -> APIResponse:
        """Send a HEAD request.

        Args:
            url: The request URL (absolute or relative to base_url).
            **kwargs: Optional arguments: headers, params, timeout,
                max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, "HEAD", **kwargs)

    def fetch(self, url: str, method: str = "GET", **kwargs: Any) -> APIResponse:
        """Send an HTTP request with a custom method.

        Args:
            url: The request URL (absolute or relative to base_url).
            method: The HTTP method to use.
            **kwargs: Optional arguments: headers, params, data, form,
                json_data, timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.
        """
        return self._fetch(url, method, **kwargs)

    def dispose(self) -> None:
        """Close the underlying connection pool."""
        self._pool.clear()

    @staticmethod
    def _header_key(headers: dict[str, str], name: str) -> str | None:
        """Return the key in ``headers`` that equals ``name`` ignoring case, or None."""
        wanted = name.lower()
        for key in headers:
            if key.lower() == wanted:
                return key
        return None

    def _set_default_content_type(self, headers: dict[str, str], content_type: str) -> None:
        """Set Content-Type unless the caller already supplied one in any casing."""
        if self._header_key(headers, "Content-Type") is None:
            headers["Content-Type"] = content_type

    def _resolve_url(self, url: str) -> str:
        """Resolve a URL, prepending base_url for relative paths.

        A URL is considered absolute when it starts with ``http://`` or
        ``https://`` (scheme compared case-insensitively).
        """
        if not url.lower().startswith(("http://", "https://")):
            return self._base_url.rstrip("/") + "/" + url.lstrip("/")
        return url

    def _build_headers(self, kwargs: dict[str, Any]) -> dict[str, str]:
        """Merge extra_headers with per-request headers.

        Per-request headers win. A per-request header replaces a default
        header whose name differs only in case, so the same header is not
        sent twice.
        """
        headers = dict(self._extra_headers)
        for name, value in (kwargs.get("headers") or {}).items():
            existing = self._header_key(headers, name)
            if existing is not None and existing != name:
                del headers[existing]
            headers[name] = value
        return headers

    def _prepare_body(self, headers: dict[str, str], kwargs: dict[str, Any]) -> bytes | None:
        """Prepare the request body from json_data, form, or data kwargs.

        Priority: json_data > form > data. Only one should be provided.
        A default Content-Type is added to ``headers`` (mutated in place) for
        JSON and form bodies when the caller has not set one.

        Returns:
            The encoded body, or None when no body was supplied or ``data``
            is of an unsupported type.
        """
        json_data = kwargs.get("json_data")
        form = kwargs.get("form")
        data = kwargs.get("data")

        if json_data is not None:
            self._set_default_content_type(headers, "application/json")
            return json.dumps(json_data).encode("utf-8")
        if form is not None:
            self._set_default_content_type(headers, "application/x-www-form-urlencoded")
            return urllib.parse.urlencode(form).encode("utf-8")
        if data is not None:
            if isinstance(data, dict):
                self._set_default_content_type(headers, "application/x-www-form-urlencoded")
                return urllib.parse.urlencode(data).encode("utf-8")
            if isinstance(data, str):
                return data.encode("utf-8")
            if isinstance(data, (bytes, bytearray, memoryview)):
                # bytes() leaves a bytes object unchanged and copies other buffers.
                return bytes(data)
        return None

    def _append_params(self, url: str, kwargs: dict[str, Any]) -> str:
        """Append query parameters to the URL.

        The query string is inserted before any ``#fragment`` so the
        parameters are part of the request rather than the fragment.
        """
        params = kwargs.get("params")
        if not params:
            return url
        base, hash_mark, fragment = url.partition("#")
        separator = "&" if "?" in base else "?"
        return base + separator + urllib.parse.urlencode(params) + hash_mark + fragment

    def _execute_request(
        self, method: str, url: str, headers: dict[str, str], body: bytes | None, kwargs: dict[str, Any]
    ) -> urllib3.BaseHTTPResponse:
        """Execute the HTTP request via urllib3.

        Per-request ``timeout`` and ``max_redirects`` in ``kwargs`` override
        the context defaults. Only redirects are retried; connect, read,
        status and other retries are disabled.
        """
        timeout = kwargs.get("timeout", self._timeout)
        max_redirects = kwargs.get("max_redirects", self._max_redirects)

        follow = max_redirects > 0
        retries = Retry(
            connect=0,
            read=0,
            status=0,
            other=0,
            redirect=max_redirects if follow else 0,
            # Return the last 3xx response instead of raising when the limit is hit.
            raise_on_redirect=False,
        )

        return self._pool.request(
            method,
            url,
            headers=headers,
            body=body,
            timeout=timeout,
            redirect=follow,
            retries=retries,
            preload_content=True,
        )

    def _build_response(self, resp: urllib3.BaseHTTPResponse, url: str) -> APIResponse:
        """Build an APIResponse from a urllib3 response.

        Header names are lower-cased in the resulting dict.
        """
        # Merge duplicate headers per RFC 7230 §3.2.2 (combine with ", ")
        resp_headers: dict[str, str] = {}
        for name, value in resp.headers.items():
            key = name.lower()
            resp_headers[key] = f"{resp_headers[key]}, {value}" if key in resp_headers else value
        # urllib3 2.x removed resp.reason; fall back to stdlib phrase lookup
        reason = getattr(resp, "reason", None)
        status_text = reason or http_status_phrases.get(resp.status, "")
        return APIResponse(
            status=resp.status,
            status_text=status_text,
            headers=resp_headers,
            url=url,
            body=resp.data,
        )

    def _get_cookies_for_request(self, url: str) -> list[dict]:
        """Get cookies that should be sent with the request. Overridden by subclasses."""
        return []

    def _handle_response_cookies(self, set_cookie_headers: list[str], url: str) -> None:
        """Process Set-Cookie headers from the response. Overridden by subclasses."""

    def _fetch(self, url: str, method: str, **kwargs: Any) -> APIResponse:
        """Execute an HTTP request with cookie handling.

        Args:
            url: The request URL.
            method: The HTTP method.
            **kwargs: Optional arguments: headers, params, data, form,
                json_data, timeout, max_redirects, fail_on_status_code.

        Returns:
            An APIResponse object.

        Raises:
            APIRequestFailure: If the response is not 2xx and
                fail_on_status_code is in effect for this request.
        """
        url = self._resolve_url(url)
        headers = self._build_headers(kwargs)

        # Apply cookies, appending to any Cookie header the caller supplied
        # (looked up case-insensitively so "cookie" is not duplicated).
        matching_cookies = self._get_cookies_for_request(url)
        if matching_cookies:
            cookie_header = "; ".join(f"{c['name']}={c['value']}" for c in matching_cookies)
            existing_key = self._header_key(headers, "Cookie")
            if existing_key is not None:
                headers[existing_key] = headers[existing_key] + "; " + cookie_header
            else:
                headers["Cookie"] = cookie_header

        body = self._prepare_body(headers, kwargs)
        url = self._append_params(url, kwargs)
        resp = self._execute_request(method, url, headers, body, kwargs)

        # After redirects, associate cookies with the final destination's
        # origin, not the initial request URL.
        final_url = _resolve_redirect_url(resp, url)

        # Process response cookies
        set_cookie_headers = _get_set_cookie_headers(resp)
        if set_cookie_headers:
            self._handle_response_cookies(set_cookie_headers, final_url)

        response = self._build_response(resp, final_url)

        fail = kwargs.get("fail_on_status_code", self._fail_on_status_code)
        if fail and not response.ok:
            raise APIRequestFailure(response)

        return response
508
509
510
class APIRequestContext(_BaseRequestContext):
    """Makes HTTP requests with automatic browser cookie synchronization.

    Cookies from the browser session are sent with API requests, and cookies
    from API responses are synced back to the browser.

    Args:
        driver: The WebDriver instance to sync cookies with.
        base_url: Optional base URL prepended to relative request paths.
        extra_headers: Optional headers included in every request.
        timeout: Default request timeout in seconds.
        max_redirects: Maximum number of redirects to follow.
        fail_on_status_code: If True, raise APIRequestFailure for non-2xx responses.
    """

    def __init__(
        self,
        driver: "WebDriver",
        base_url: str = "",
        extra_headers: dict[str, str] | None = None,
        timeout: float = 30.0,
        max_redirects: int = 10,
        fail_on_status_code: bool = False,
    ) -> None:
        super().__init__(
            base_url=base_url,
            extra_headers=extra_headers,
            timeout=timeout,
            max_redirects=max_redirects,
            fail_on_status_code=fail_on_status_code,
        )
        self._driver = driver

    def new_context(
        self,
        base_url: str = "",
        extra_headers: dict[str, str] | None = None,
        storage_state: dict | str | pathlib.Path | None = None,
        fail_on_status_code: bool = False,
    ) -> "_IsolatedAPIRequestContext":
        """Create an isolated API request context that does not sync with the browser.

        The new context inherits this context's timeout and max_redirects.

        Args:
            base_url: Optional base URL for this context.
            extra_headers: Optional headers for this context.
            storage_state: Optional cookies to pre-load, as a dict, JSON file path, or Path.
                A str is always interpreted as a file path. Files are read as UTF-8.
            fail_on_status_code: If True, raise APIRequestFailure for non-2xx responses.

        Returns:
            An _IsolatedAPIRequestContext instance.

        Raises:
            FileNotFoundError: If the storage state file does not exist.
            ValueError: If the file is not valid JSON, or the storage state
                is not a dict / JSON object.
            OSError: If the file cannot be read.
        """
        cookies: list[dict] = []
        if storage_state is not None:
            if isinstance(storage_state, (str, pathlib.Path)):
                file_path = pathlib.Path(storage_state)
                # Checked up front so callers get FileNotFoundError itself rather
                # than the generic OSError raised for other read failures below.
                if not file_path.exists():
                    raise FileNotFoundError(f"Storage state file not found: {file_path}")
                try:
                    # Explicit encoding: JSON files are UTF-8, not locale-dependent.
                    with open(file_path, encoding="utf-8") as f:
                        state = json.load(f)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Invalid JSON in storage state file {file_path}: {e}") from e
                except OSError as e:
                    raise OSError(f"Cannot read storage state file {file_path}: {e}") from e
            else:
                state = storage_state
            if not isinstance(state, dict):
                raise ValueError(f"Storage state must be a dict with a 'cookies' list, got {type(state).__name__}")
            # "cookies" may be absent or null; both mean "no cookies".
            cookies = list(state.get("cookies") or [])

        return _IsolatedAPIRequestContext(
            base_url=base_url,
            extra_headers=extra_headers,
            cookies=cookies,
            timeout=self._timeout,
            max_redirects=self._max_redirects,
            fail_on_status_code=fail_on_status_code,
        )

    def get_storage_state(self, path: str | pathlib.Path | None = None) -> dict[str, Any]:
        """Export the current browser cookies as a storage state dict.

        Args:
            path: Optional file path to save the storage state as JSON (written as UTF-8).

        Returns:
            A dict with a "cookies" key containing the browser cookies.

        Raises:
            OSError: If the storage state cannot be written to ``path``.
        """
        cookies = self._driver.get_cookies()
        state: dict[str, Any] = {"cookies": cookies}
        if path is not None:
            file_path = pathlib.Path(path)
            try:
                with open(file_path, "w", encoding="utf-8") as f:
                    json.dump(state, f, indent=2)
            except OSError as e:
                raise OSError(f"Cannot write storage state to {file_path}: {e}") from e
        return state

    def _get_cookies_for_request(self, url: str) -> list[dict]:
        """Get matching browser cookies for the request URL.

        Returns an empty list when the browser cookies cannot be retrieved,
        so a driver failure never blocks the request itself.
        """
        try:
            browser_cookies = self._driver.get_cookies()
        except Exception:
            logger.debug("Could not retrieve browser cookies", exc_info=True)
            return []
        # Derive default domain from the browser's current page for host-only cookies
        default_domain = ""
        try:
            current = self._driver.current_url
            if current:
                default_domain = urllib.parse.urlparse(current).hostname or ""
        except Exception:
            logger.debug("Could not get current URL for host-only cookie matching", exc_info=True)
        return [c for c in browser_cookies if _cookie_matches(c, url, default_domain)]

    def _handle_response_cookies(self, set_cookie_headers: list[str], url: str) -> None:
        """Sync Set-Cookie headers back to the browser.

        Cookies without a domain or path default to the host of ``url`` and
        "/". An already-expired cookie is treated as a deletion request.
        Driver failures are logged and never raised (best effort).
        """
        default_host = urllib.parse.urlparse(url).hostname or ""
        now = int(time.time())
        for sc_header in set_cookie_headers:
            cookie = _parse_set_cookie(sc_header)
            if not cookie.get("name"):
                continue
            cookie.setdefault("domain", default_host)
            cookie.setdefault("path", "/")
            expiry = cookie.get("expiry")
            if expiry is not None and expiry <= now:
                try:
                    self._driver.delete_cookie(cookie["name"])
                except Exception:
                    # Best effort, but leave a trace instead of failing silently.
                    logger.debug(
                        "Could not delete expired cookie '%s' from browser",
                        cookie.get("name"),
                        exc_info=True,
                    )
                continue
            try:
                self._driver.add_cookie(cookie)
            except Exception:
                logger.warning(
                    "Could not sync cookie '%s' to browser (domain mismatch with current page)",
                    cookie.get("name"),
                    exc_info=True,
                )
648
649
650
class _IsolatedAPIRequestContext(_BaseRequestContext):
    """API request context backed by a private, in-memory cookie jar.

    Does not synchronize cookies with any browser session.
    """

    def __init__(
        self,
        base_url: str = "",
        extra_headers: dict[str, str] | None = None,
        cookies: list[dict] | None = None,
        timeout: float = 30.0,
        max_redirects: int = 10,
        fail_on_status_code: bool = False,
    ) -> None:
        super().__init__(
            base_url=base_url,
            extra_headers=extra_headers,
            timeout=timeout,
            max_redirects=max_redirects,
            fail_on_status_code=fail_on_status_code,
        )
        # The jar: a flat list of cookie dicts, pre-loaded when cookies are given.
        self._cookies: list[dict] = cookies or []

    def get_storage_state(self) -> dict[str, Any]:
        """Return a storage state dict holding a copy of the jar's cookie list."""
        snapshot = list(self._cookies)
        return {"cookies": snapshot}

    def _get_cookies_for_request(self, url: str) -> list[dict]:
        """Select the jar cookies that apply to the request URL."""
        # For isolated contexts, use the request hostname as default domain
        request_host = urllib.parse.urlparse(url).hostname or ""
        selected: list[dict] = []
        for jar_cookie in self._cookies:
            if _cookie_matches(jar_cookie, url, request_host):
                selected.append(jar_cookie)
        return selected

    def _handle_response_cookies(self, set_cookie_headers: list[str], url: str) -> None:
        """Apply Set-Cookie headers to the internal jar.

        Each parsed cookie replaces any stored cookie with the same
        (name, domain, path); a cookie that is already expired only removes
        the stored one.
        """
        fallback_domain = urllib.parse.urlparse(url).hostname or ""
        now = int(time.time())

        for raw_header in set_cookie_headers:
            incoming = _parse_set_cookie(raw_header)
            if not incoming.get("name"):
                continue

            if "domain" not in incoming:
                incoming["domain"] = fallback_domain
            if "path" not in incoming:
                incoming["path"] = "/"

            # Cookies are unique by (name, domain, path)
            identity = (incoming["name"], incoming.get("domain", ""), incoming.get("path", "/"))

            # Drop any stored cookie with the same identity
            kept: list[dict] = []
            for stored in self._cookies:
                stored_identity = (stored.get("name"), stored.get("domain", ""), stored.get("path", "/"))
                if stored_identity != identity:
                    kept.append(stored)
            self._cookies = kept

            # Store only unexpired cookies (Max-Age=0 or negative means delete)
            expiry = incoming.get("expiry")
            if expiry is None or expiry > now:
                self._cookies.append(incoming)
705
706