Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/cloud/test_credential_provider.py
8424 views
1
import io
2
import pickle
3
import sys
4
from datetime import datetime, timezone
5
from pathlib import Path
6
from typing import Any
7
from unittest.mock import Mock
8
9
import pytest
10
11
import polars as pl
12
import polars.io.cloud.credential_provider
13
from polars.io.cloud._utils import NoPickleOption
14
from polars.io.cloud.credential_provider._builder import (
15
AutoInit,
16
CredentialProviderBuilder,
17
_init_credential_provider_builder,
18
)
19
from polars.io.cloud.credential_provider._providers import (
20
CachedCredentialProvider,
21
CachingCredentialProvider,
22
UserProvidedGCPToken,
23
)
24
from tests.conftest import PlMonkeyPatch
25
26
27
@pytest.mark.parametrize(
    "io_func",
    [
        *[pl.scan_parquet, pl.read_parquet],
        pl.scan_csv,
        *[pl.scan_ndjson, pl.read_ndjson],
        pl.scan_ipc,
    ],
)
def test_credential_provider_scan(io_func: Any, plmonkeypatch: PlMonkeyPatch) -> None:
    """Check how scan/read functions react to the `credential_provider` argument.

    `CredentialProviderBuilder.__init__` is patched to raise, so any code path
    that constructs a builder surfaces as an `AssertionError`.
    """
    err_magic = "err_magic_3"

    # NOTE: `raises` reads `err_magic` from the enclosing scope at call time,
    # so after `err_magic` is rebound to "err_magic_7" further down, a call to
    # `raises` produces that new value.
    def raises(*_: None, **__: None) -> None:
        raise AssertionError(err_magic)

    # `CredentialProviderBuilder` is already imported at module level, so no
    # function-local re-import is needed.
    plmonkeypatch.setattr(CredentialProviderBuilder, "__init__", raises)

    # "auto" must construct a `CredentialProviderBuilder`, with or without
    # extra storage_options.
    with pytest.raises(AssertionError, match=err_magic):
        io_func("s3://bucket/path", credential_provider="auto")

    with pytest.raises(AssertionError, match=err_magic):
        io_func(
            "s3://bucket/path",
            credential_provider="auto",
            storage_options={"aws_region": "eu-west-1"},
        )

    # We can't test these with the `read_` functions as they end up executing
    # the query
    if io_func.__name__.startswith("scan_"):
        # Passing `None` should disable the automatic instantiation of
        # `CredentialProviderAWS`
        io_func("s3://bucket/path", credential_provider=None)

    err_magic = "err_magic_7"

    def raises_2() -> pl.CredentialProviderFunctionReturn:
        raise AssertionError(err_magic)

    # For the `read_` functions the query runs inside `io_func` itself, so the
    # error is presumably raised before `.collect()` would be reached.
    with pytest.raises(AssertionError, match=err_magic):
        io_func("s3://bucket/path", credential_provider=raises_2).collect()
70
71
72
@pytest.mark.parametrize(
    ("provider_class", "path"),
    [
        (polars.io.cloud.credential_provider.CredentialProviderAWS, "s3://.../..."),
        (polars.io.cloud.credential_provider.CredentialProviderGCP, "gs://.../..."),
        (polars.io.cloud.credential_provider.CredentialProviderAzure, "az://.../..."),
    ],
)
def test_credential_provider_serialization_auto_init(
    provider_class: type[polars.io.cloud.credential_provider.CredentialProvider],
    path: str,
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """An auto-initialized provider is only constructed when the query executes.

    The provider class's ``__init__`` is patched to raise a distinct error
    before and after a pickle round trip of the query. The error observed on
    ``collect()`` shows which patch was active when the provider was built.
    """

    def make_failing_init(msg: str) -> Any:
        # Stand-in for `provider_class.__init__` that always fails with `msg`.
        def failing_init(*a: Any, **kw: Any) -> None:
            raise AssertionError(msg)

        return failing_init

    plmonkeypatch.setattr(provider_class, "__init__", make_failing_init("err_magic_1"))

    # If this is not set we will get an error before hitting the credential
    # provider logic when polars attempts to retrieve the region from AWS.
    plmonkeypatch.setenv("AWS_REGION", "eu-west-1")

    # Credential provider should not be initialized during query plan construction.
    q = pl.scan_parquet(path)

    # Check baseline - query plan is configured to auto-initialize the credential
    # provider.
    with pytest.raises(AssertionError, match="err_magic_1"):
        q.collect()

    q = pickle.loads(pickle.dumps(q))

    plmonkeypatch.setattr(provider_class, "__init__", make_failing_init("err_magic_2"))

    # Check that auto-initialization happens upon executing the deserialized
    # query.
    with pytest.raises(AssertionError, match="err_magic_2"):
        q.collect()
115
116
117
@pytest.mark.slow
def test_credential_provider_serialization_custom_provider() -> None:
    """A user-defined provider survives `LazyFrame.serialize`/`deserialize`.

    Executing the round-tripped query must invoke the provider and surface
    its error.
    """
    err_magic = "err_magic_3"

    class ErrCredentialProvider(pl.CredentialProvider):
        def __call__(self) -> pl.CredentialProviderFunctionReturn:
            raise AssertionError(err_magic)

    original = pl.scan_parquet(
        "s3://bucket/path", credential_provider=ErrCredentialProvider()
    )
    payload = original.serialize()
    restored = pl.LazyFrame.deserialize(io.BytesIO(payload))

    with pytest.raises(AssertionError, match=err_magic):
        restored.collect()
135
136
137
def test_credential_provider_gcp_skips_config_autoload(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """GCP env config is honoured only when no credential provider is used.

    With `credential_provider=None`, `GOOGLE_SERVICE_ACCOUNT_PATH` is read and
    the missing file causes an `OSError`. With a custom provider, that env
    config does not get in the way and the provider's own error surfaces.
    """
    plmonkeypatch.setenv("GOOGLE_SERVICE_ACCOUNT_PATH", "__non_existent")

    with pytest.raises(OSError, match="__non_existent"):
        pl.scan_parquet("gs://.../...", credential_provider=None).collect()

    magic = "err_magic_3"

    def failing_provider() -> pl.CredentialProviderFunctionReturn:
        raise AssertionError(magic)

    with pytest.raises(AssertionError, match=magic):
        pl.scan_parquet(
            "gs://.../...", credential_provider=failing_provider
        ).collect()
152
153
154
def test_credential_provider_aws_import_error_with_requested_profile(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """A requested AWS profile that cannot be loaded raises a `ComputeError`.

    `CredentialProviderAWS._session` is patched to raise `ImportError`. This is
    checked both for a profile passed to an explicit provider and for one
    passed through `storage_options`.
    """

    def _session(self: Any) -> None:
        msg = "err_magic_3"
        raise ImportError(msg)

    plmonkeypatch.setattr(pl.CredentialProviderAWS, "_session", _session)
    plmonkeypatch.setenv("AWS_REGION", "eu-west-1")

    expected_msg = (
        "cannot load requested aws_profile 'test_profile': ImportError: err_magic_3"
    )

    scan_kwargs_variants: list[dict[str, Any]] = [
        # Profile requested through an explicit provider instance ...
        {
            "credential_provider": pl.CredentialProviderAWS(
                profile_name="test_profile"
            )
        },
        # ... and through storage_options.
        {"storage_options": {"aws_profile": "test_profile"}},
    ]

    for scan_kwargs in scan_kwargs_variants:
        q = pl.scan_parquet("s3://.../...", **scan_kwargs)
        with pytest.raises(pl.exceptions.ComputeError, match=expected_msg):
            q.collect()
189
190
191
@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_scan_no_parameters(
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """A bare S3 scan loads `endpoint_url` from the AWS config file.

    The loaded endpoint is checked both through the request error and through
    the verbose log line emitted by `CredentialProviderAWS`.
    """
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, plmonkeypatch)
    config_path = tmp_path / "config"

    plmonkeypatch.setenv("AWS_CONFIG_FILE", str(config_path))
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")

    config_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    # Scan with no parameters should load via CredentialProviderAWS
    query = pl.scan_parquet("s3://.../...")

    # Discard output so far so that only the collect() logs are inspected.
    capfd.readouterr()

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:333"):
        query.collect()

    stderr_lines = capfd.readouterr().err.splitlines()

    expected = "[CredentialProviderAWS]: Loaded endpoint_url: http://localhost:333"
    assert expected in stderr_lines
223
224
225
@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_serde(
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """`endpoint_url` is resolved when a deserialized query runs.

    The AWS config file is rewritten after the pickle round trip. The collected
    query must use the new endpoint (777), not the one present when the query
    was built (333).
    """
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, plmonkeypatch)
    cfg_file_path = tmp_path / "config"

    plmonkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    q = pl.scan_parquet("s3://.../...")
    q = pickle.loads(pickle.dumps(q))

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:777
""")

    # Discard output so far so that only the collect() logs are inspected.
    capfd.readouterr()

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:777"):
        q.collect()

    # With POLARS_VERBOSE set, the provider logs the endpoint it loaded. It must
    # be the endpoint from the rewritten config.
    expected_log = "[CredentialProviderAWS]: Loaded endpoint_url: http://localhost:777"
    assert expected_log in capfd.readouterr().err.splitlines()
257
258
259
@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_with_storage_options(
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Secrets in `storage_options` don't prevent loading `endpoint_url` from config.

    `CredentialProviderAWS` is still initialized, logs that it will not act as
    the credential source, and still loads `endpoint_url` from the config file.
    """
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, plmonkeypatch)
    config_file = tmp_path / "config"

    plmonkeypatch.setenv("AWS_CONFIG_FILE", str(config_file))
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")

    config_file.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    # Previously we would not initialize a credential provider at all if secrets
    # were given under `storage_options`. Now we always initialize so that we
    # can load the `endpoint_url`, but we decide at the very last second whether
    # to also retrieve secrets using the credential provider.
    query = pl.scan_parquet(
        "s3://.../...",
        storage_options={
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        },
    )

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:333"):
        query.collect()

    stderr_lines = capfd.readouterr().err.splitlines()

    expected_lines = (
        (
            "[CredentialProviderAWS]: Will not be used as a provider: unhandled key "
            "in storage_options: 'aws_secret_access_key'"
        ),
        "[CredentialProviderAWS]: Loaded endpoint_url: http://localhost:333",
    )
    for expected in expected_lines:
        assert expected in stderr_lines
302
303
304
@pytest.mark.parametrize(
    "storage_options",
    [
        {"aws_endpoint_url": "http://localhost:777"},
        {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
            "aws_endpoint_url": "http://localhost:777",
        },
    ],
)
@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_passed_in_storage_options(
    storage_options: dict[str, str],
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """An `endpoint_url` in `storage_options` overrides the AWS config file's."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, plmonkeypatch)
    config_file = tmp_path / "config"
    plmonkeypatch.setenv("AWS_CONFIG_FILE", str(config_file))

    config_file.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    # First the config endpoint applies. Then an endpoint_url passed in
    # `storage_options` should take precedence.
    cases: list[tuple[dict[str, Any], str]] = [
        ({}, "http://localhost:333"),
        ({"storage_options": storage_options}, "http://localhost:777"),
    ]
    for scan_kwargs, endpoint in cases:
        query = pl.scan_parquet("s3://.../...", **scan_kwargs)
        with pytest.raises(IOError, match=f"Error performing HEAD {endpoint}"):
            query.collect()
346
347
348
def _set_default_credentials(tmp_path: Path, plmonkeypatch: PlMonkeyPatch) -> None:
    """Write a dummy `[default]` AWS shared-credentials file and point the env at it.

    The file is created as `tmp_path / "credentials"`, and
    `AWS_SHARED_CREDENTIALS_FILE` is set to its path via `plmonkeypatch`.
    """
    contents = "[default]\naws_access_key_id=Z\naws_secret_access_key=Z\n"
    credentials_file = tmp_path / "credentials"
    credentials_file.write_text(contents)

    plmonkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(credentials_file))
358
359
360
@pytest.mark.slow
def test_credential_provider_python_builder_cache(
    plmonkeypatch: PlMonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Tests caching of building credential providers.

    `CredentialProviderAWS.__init__` is wrapped in a `Mock` so that provider
    constructions can be counted across freshly built queries.
    """

    def dummy_static_aws_credentials(*a: Any, **kw: Any) -> Any:
        return {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        }, None

    with plmonkeypatch.context() as cx:
        init_spy = Mock(wraps=pl.CredentialProviderAWS.__init__)

        cx.setattr(
            pl.CredentialProviderAWS,
            "__init__",
            lambda *a, **kw: init_spy(*a, **kw),
        )
        cx.setattr(
            pl.CredentialProviderAWS,
            "retrieve_credentials_impl",
            dummy_static_aws_credentials,
        )

        # A new query is built on every call, so any reuse of a provider must
        # come from the builder cache rather than from the query object.
        def make_query(profile: str = "A") -> pl.LazyFrame:
            return pl.scan_parquet(
                "s3://.../...",
                storage_options={
                    "aws_profile": profile,
                    "aws_endpoint_url": "http://localhost",
                },
                credential_provider="auto",
            )

        def collect_expecting_oserror(profile: str = "A") -> None:
            with pytest.raises(OSError):
                make_query(profile).collect()

        assert init_spy.call_count == 0

        collect_expecting_oserror()
        assert init_spy.call_count == 1

        # Same options as before: no new construction.
        collect_expecting_oserror()
        assert init_spy.call_count == 1

        # A different profile is not served from the cache.
        collect_expecting_oserror("B")
        assert init_spy.call_count == 2

        collect_expecting_oserror()
        assert init_spy.call_count == 2

        # With the builder cache disabled, every collect constructs again.
        cx.setenv("POLARS_CREDENTIAL_PROVIDER_BUILDER_CACHE_SIZE", "0")

        collect_expecting_oserror()
        # Note: Increments by 2 due to Rust-side object store rebuilding.
        assert init_spy.call_count == 4

        collect_expecting_oserror()
        assert init_spy.call_count == 6

    with plmonkeypatch.context() as cx:
        cx.setenv("POLARS_VERBOSE", "1")
        builder = _init_credential_provider_builder(
            "auto",
            "s3://.../...",
            None,
            "test",
        )
        assert builder is not None

        capfd.readouterr()

        for _ in range(2):
            builder.build_credential_provider()

        # Ensure cache key is memoized on generation
        assert capfd.readouterr().err.count("AutoInit cache key") == 1

        pickle.loads(pickle.dumps(builder)).build_credential_provider()

        # Ensure cache key is not serialized
        assert capfd.readouterr().err.count("AutoInit cache key") == 1
467
468
469
@pytest.mark.slow
def test_credential_provider_python_credentials_cache(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Python-side credential caching and its `POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING` switch.

    Also checks that cached credentials are dropped when pickling and that
    callers receive copies they can mutate freely.
    """
    retrieve = Mock(
        wraps=lambda: (
            {
                "aws_access_key_id": "...",
                "aws_secret_access_key": "...",
            },
            None,
        )
    )

    plmonkeypatch.setattr(
        pl.CredentialProviderAWS,
        "retrieve_credentials_impl",
        retrieve,
    )

    assert retrieve.call_count == 0

    provider = pl.CredentialProviderAWS()

    def call_and_check(expected_calls: int) -> None:
        provider()
        assert retrieve.call_count == expected_calls

    # Caching on: only the first call reaches `retrieve_credentials_impl`.
    call_and_check(1)
    call_and_check(1)

    # Caching disabled via env var: every call retrieves again.
    plmonkeypatch.setenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING", "1")
    call_and_check(2)
    call_and_check(3)

    # Caching re-enabled: one more retrieval, then served from cache.
    plmonkeypatch.delenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING")
    call_and_check(4)
    call_and_check(4)

    # Cached credentials are not carried through pickling.
    assert provider._cached_credentials.get() is not None
    assert pickle.loads(pickle.dumps(provider))._cached_credentials.get() is None

    static_creds = (
        {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        },
        None,
    )
    assert provider() == static_creds

    # Mutating a returned dict must not affect later results.
    provider()[0]["A"] = "A"
    assert provider() == static_creds
535
536
537
def test_no_pickle_option() -> None:
    """`NoPickleOption` keeps its value in memory but comes back empty after pickling."""
    option = NoPickleOption(3)
    assert option.get() == 3

    restored = pickle.loads(pickle.dumps(option))
    assert restored.get() is None
544
545
546
@pytest.mark.write_disk
def test_credential_provider_aws_expiry(
    tmp_path: Path, plmonkeypatch: PlMonkeyPatch
) -> None:
    """Credentials from an AWS `credential_process` report their `Expiration`.

    The process prints a JSON file. With an `Expiration` field the expiry is
    its UNIX timestamp. Without one the expiry is `None`.
    """
    creds_json = tmp_path / "credentials.json"

    creds_json.write_text(
        """\
{
"Version": 1,
"AccessKeyId": "123",
"SecretAccessKey": "456",
"SessionToken": "789",
"Expiration": "2099-01-01T00:00:00+00:00"
}
"""
    )

    config_path = tmp_path / "config"

    # Forward slashes, presumably so that Windows backslashes are not treated
    # as escapes inside the `-c` Python snippet below.
    creds_json_posix = str(creds_json).replace("\\", "/")

    config_path.write_text(f"""\
[profile cred_process]
credential_process = "{sys.executable}" -c "from pathlib import Path; print(Path('{creds_json_posix}').read_text())"
""")

    plmonkeypatch.setenv("AWS_CONFIG_FILE", str(config_path))

    def load() -> Any:
        return pl.CredentialProviderAWS(profile_name="cred_process")()

    creds, expiry = load()

    assert creds == {
        "aws_access_key_id": "123",
        "aws_secret_access_key": "456",
        "aws_session_token": "789",
    }

    assert expiry is not None

    expected_expiry = datetime.fromisoformat("2099-01-01T00:00:00+00:00")
    assert datetime.fromtimestamp(expiry, tz=timezone.utc) == expected_expiry

    creds_json.write_text(
        """\
{
"Version": 1,
"AccessKeyId": "...",
"SecretAccessKey": "...",
"SessionToken": "..."
}
"""
    )

    creds, expiry = load()

    assert creds == {
        "aws_access_key_id": "...",
        "aws_secret_access_key": "...",
        "aws_session_token": "...",
    }

    assert expiry is None
609
610
611
@pytest.mark.slow
@pytest.mark.parametrize(
    (
        "credential_provider_class",
        "scan_path",
        "initial_credentials",
        "updated_credentials",
    ),
    [
        (
            pl.CredentialProviderAWS,
            "s3://.../...",
            {"aws_access_key_id": "initial", "aws_secret_access_key": "initial"},
            {"aws_access_key_id": "updated", "aws_secret_access_key": "updated"},
        ),
        (
            pl.CredentialProviderAzure,
            "abfss://container@storage_account.dfs.core.windows.net/bucket",
            {"bearer_token": "initial"},
            {"bearer_token": "updated"},
        ),
        (
            pl.CredentialProviderGCP,
            "gs://.../...",
            {"bearer_token": "initial"},
            {"bearer_token": "updated"},
        ),
    ],
)
def test_credential_provider_rebuild_clears_cache(
    credential_provider_class: type[CachingCredentialProvider],
    scan_path: str,
    initial_credentials: dict[str, str],
    updated_credentials: dict[str, str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Providers built by a builder must not return stale cached credentials.

    A provider created directly keeps its cached credentials until
    `clear_cached_credentials()` is called. A provider obtained from the
    builder after a scan must already see the updated credentials.
    """
    assert initial_credentials != updated_credentials

    plmonkeypatch.setattr(
        credential_provider_class,
        "retrieve_credentials_impl",
        lambda *_: (initial_credentials, None),
    )

    # Only the AWS case gets an explicit (local) endpoint.
    storage_options = (
        {"aws_endpoint_url": "http://localhost:333"}
        if credential_provider_class is pl.CredentialProviderAWS
        else None
    )

    builder = _init_credential_provider_builder(
        "auto",
        scan_path,
        storage_options=storage_options,
        caller_name="test",
    )

    assert builder is not None

    # This is a separate one for testing local to this function.
    provider_local = credential_provider_class()

    # Set the cache
    provider_local()

    # Now update the retrieval function to return updated credentials.
    plmonkeypatch.setattr(
        credential_provider_class,
        "retrieve_credentials_impl",
        lambda *_: (updated_credentials, None),
    )

    # Despite "retrieve_credentials_impl" being updated, the local provider
    # should still return the initial credentials, as they were cached with an
    # expiry of None.
    assert provider_local() == (initial_credentials, None)

    q = pl.scan_parquet(
        scan_path,
        storage_options=storage_options,
        credential_provider="auto",
    )

    with pytest.raises(OSError):
        q.collect()

    # A provider freshly built from the builder sees the updated credentials.
    provider_at_scan = builder.build_credential_provider()

    assert provider_at_scan is not None
    assert provider_at_scan() == (updated_credentials, None)

    # The local provider's cache is independent and still holds the old value
    # until it is cleared explicitly.
    assert provider_local() == (initial_credentials, None)

    provider_local.clear_cached_credentials()

    assert provider_local() == (updated_credentials, None)
707
708
709
def test_user_gcp_token_provider(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """`UserProvidedGCPToken` returns its token with no expiry.

    The result is the same whether or not Python-side credential caching is
    disabled.
    """
    token_provider = UserProvidedGCPToken("A")
    expected = ({"bearer_token": "A"}, None)

    assert token_provider() == expected
    plmonkeypatch.setenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING", "1")
    assert token_provider() == expected
716
717
718
def test_auto_init_cache_key_memoize(plmonkeypatch: PlMonkeyPatch) -> None:
    """`AutoInit.get_or_init_cache_key` computes the key lazily and only once."""
    key_impl_spy = Mock(wraps=AutoInit.get_cache_key_impl)
    plmonkeypatch.setattr(
        AutoInit,
        "get_cache_key_impl",
        lambda *a, **kw: key_impl_spy(*a, **kw),
    )

    auto_init = AutoInit(int)

    # Construction alone does not compute the key.
    assert key_impl_spy.call_count == 0

    for _ in range(2):
        auto_init.get_or_init_cache_key()
        assert key_impl_spy.call_count == 1
735
736
737
def test_cached_credential_provider_returns_copied_creds() -> None:
    """`CachedCredentialProvider` calls the wrapped function once and returns copies.

    Mutating a returned credentials dict must not change later results.
    """
    expected = ({"A": "A"}, None)
    inner = Mock(wraps=lambda: ({"A": "A"}, None))
    cached = CachedCredentialProvider(inner)

    assert inner.call_count == 0

    cached()
    assert cached() == expected
    assert inner.call_count == 1

    mutated_creds, _ = cached()
    mutated_creds["B"] = "B"

    assert cached() == expected
    assert inner.call_count == 1
753
754
755
@pytest.mark.parametrize(
    "partition_target",
    [
        pl.PartitionBy("s3://.../...", key=""),
    ],
)
def test_credential_provider_init_from_partition_target(
    partition_target: pl.PartitionBy,
) -> None:
    """A `PartitionBy` target can initialize a `CredentialProviderBuilder`."""
    builder = _init_credential_provider_builder(
        "auto",
        partition_target,
        None,
        "test",
    )
    assert isinstance(builder, CredentialProviderBuilder)
773
774
775
@pytest.mark.slow
def test_cache_user_credential_provider(plmonkeypatch: PlMonkeyPatch) -> None:
    """A user-supplied provider callable is cached across freshly built queries.

    Setting `POLARS_CREDENTIAL_PROVIDER_BUILDER_CACHE_SIZE=0` disables that
    reuse. The exact call counts below are what the current implementation
    produces per collect.
    """
    user_provider = Mock(
        return_value=(
            {"aws_access_key_id": "...", "aws_secret_access_key": "..."},
            None,
        )
    )

    def collect_fresh_query() -> None:
        with pytest.raises(OSError, match="http://localhost:333"):
            pl.scan_parquet(
                "s3://.../...",
                storage_options={"aws_endpoint_url": "http://localhost:333"},
                credential_provider=user_provider,
            ).collect()

    assert user_provider.call_count == 0

    collect_fresh_query()
    assert user_provider.call_count == 2

    collect_fresh_query()
    assert user_provider.call_count == 3

    plmonkeypatch.setenv("POLARS_CREDENTIAL_PROVIDER_BUILDER_CACHE_SIZE", "0")

    collect_fresh_query()
    assert user_provider.call_count == 5
809
810
811
@pytest.mark.slow
def test_credential_provider_global_config(plmonkeypatch: PlMonkeyPatch) -> None:
    """`pl.Config.set_default_credential_provider` controls the implicit provider.

    The configured default is used by scans that do not pass
    `credential_provider`. It can be switched to "auto" or to None.
    """
    # `pl` and `polars` are bound at module level, and the `_builder` submodule
    # is already loaded by the module-level imports, so no local re-imports
    # are needed here.
    plmonkeypatch.setattr(
        polars.io.cloud.credential_provider._builder,
        "DEFAULT_CREDENTIAL_PROVIDER",
        None,
    )
    # NOTE(review): assumes `set_default_credential_provider` stores into
    # `_builder.DEFAULT_CREDENTIAL_PROVIDER`, so the monkeypatch above restores
    # the global default at teardown. Confirm.

    provider = Mock(
        return_value=(
            {"aws_access_key_id": "...", "aws_secret_access_key": "..."},
            None,
        )
    )

    pl.Config.set_default_credential_provider(provider)

    plmonkeypatch.setenv("AWS_ACCESS_KEY_ID", "...")
    plmonkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "...")

    def get_q() -> pl.LazyFrame:
        return pl.scan_parquet(
            "s3://.../...",
            storage_options={"aws_endpoint_url": "http://localhost:333"},
        )

    def get_q_disable_cred_provider() -> pl.LazyFrame:
        return pl.scan_parquet(
            "s3://.../...",
            credential_provider=None,
            storage_options={"aws_endpoint_url": "http://localhost:333"},
        )

    assert provider.call_count == 0

    # The configured default is used when no provider is passed.
    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert provider.call_count == 2

    # An explicit `credential_provider=None` bypasses the default.
    with pytest.raises(OSError, match="http://localhost:333"):
        get_q_disable_cred_provider().collect()

    assert provider.call_count == 2

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert provider.call_count == 3

    # With the default switched to "auto", `provider` is no longer called.
    pl.Config.set_default_credential_provider("auto")

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert provider.call_count == 3

    # Make builder construction fail to show that "auto" goes through
    # `CredentialProviderBuilder`.
    err_magic = "err_magic_3"

    def raises(*_: None, **__: None) -> None:
        raise AssertionError(err_magic)

    plmonkeypatch.setattr(CredentialProviderBuilder, "__init__", raises)

    with pytest.raises(AssertionError, match=err_magic):
        get_q().collect()

    # With the default set to None the patched builder is not reached, and
    # the request fails at the endpoint instead.
    pl.Config.set_default_credential_provider(None)

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()
885
886