# Path: blob/main/py-polars/tests/unit/io/cloud/test_credential_provider.py
# 8327 views
import io
import pickle
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from unittest.mock import Mock

import pytest

import polars as pl
import polars.io.cloud.credential_provider
from polars.io.cloud._utils import NoPickleOption
from polars.io.cloud.credential_provider._builder import (
    AutoInit,
    CredentialProviderBuilder,
    _init_credential_provider_builder,
)
from polars.io.cloud.credential_provider._providers import (
    CachedCredentialProvider,
    CachingCredentialProvider,
    UserProvidedGCPToken,
)


@pytest.mark.parametrize(
    "io_func",
    [
        *[pl.scan_parquet, pl.read_parquet],
        pl.scan_csv,
        *[pl.scan_ndjson, pl.read_ndjson],
        pl.scan_ipc,
    ],
)
def test_credential_provider_scan(
    io_func: Any, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Check that "auto" builds a `CredentialProviderBuilder` for each IO function."""
    err_magic = "err_magic_3"

    def raises(*_: None, **__: None) -> None:
        raise AssertionError(err_magic)

    # Any attempt to construct a builder now surfaces as an AssertionError.
    monkeypatch.setattr(CredentialProviderBuilder, "__init__", raises)

    with pytest.raises(AssertionError, match=err_magic):
        io_func("s3://bucket/path", credential_provider="auto")

    with pytest.raises(AssertionError, match=err_magic):
        io_func(
            "s3://bucket/path",
            credential_provider="auto",
            storage_options={"aws_region": "eu-west-1"},
        )

    # We can't test these with the `read_` functions as they end up executing
    # the query
    if io_func.__name__.startswith("scan_"):
        # Passing `None` should disable the automatic instantiation of
        # `CredentialProviderAWS`
        io_func("s3://bucket/path", credential_provider=None)

    # NOTE(review): assumed to run for every `io_func` (function level rather
    # than inside the `scan_` branch above) - confirm against upstream.
    err_magic = "err_magic_7"

    def raises_2() -> pl.CredentialProviderFunctionReturn:
        raise AssertionError(err_magic)

    with pytest.raises(AssertionError, match=err_magic):
        io_func("s3://bucket/path", credential_provider=raises_2).collect()


@pytest.mark.parametrize(
    ("provider_class", "path"),
    [
        (polars.io.cloud.credential_provider.CredentialProviderAWS, "s3://.../..."),
        (polars.io.cloud.credential_provider.CredentialProviderGCP, "gs://.../..."),
        (polars.io.cloud.credential_provider.CredentialProviderAzure, "az://.../..."),
    ],
)
def test_credential_provider_serialization_auto_init(
    provider_class: polars.io.cloud.credential_provider.CredentialProvider,
    path: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that provider auto-initialization happens on collect, also after pickling."""

    def raises_1(*a: Any, **kw: Any) -> None:
        msg = "err_magic_1"
        raise AssertionError(msg)

    monkeypatch.setattr(provider_class, "__init__", raises_1)

    # If this is not set we will get an error before hitting the credential
    # provider logic when polars attempts to retrieve the region from AWS.
    monkeypatch.setenv("AWS_REGION", "eu-west-1")

    # Credential provider should not be initialized during query plan construction.
    q = pl.scan_parquet(path)

    # Check baseline - query plan is configured to auto-initialize the credential
    # provider.
    with pytest.raises(AssertionError, match="err_magic_1"):
        q.collect()

    q = pickle.loads(pickle.dumps(q))

    def raises_2(*a: Any, **kw: Any) -> None:
        msg = "err_magic_2"
        raise AssertionError(msg)

    monkeypatch.setattr(provider_class, "__init__", raises_2)

    # Check that auto-initialization happens upon executing the deserialized
    # query.
    with pytest.raises(AssertionError, match="err_magic_2"):
        q.collect()


@pytest.mark.slow
def test_credential_provider_serialization_custom_provider() -> None:
    """Check that a custom provider is still called after serialize/deserialize."""
    err_magic = "err_magic_3"

    class ErrCredentialProvider(pl.CredentialProvider):
        def __call__(self) -> pl.CredentialProviderFunctionReturn:
            raise AssertionError(err_magic)

    lf = pl.scan_parquet(
        "s3://bucket/path", credential_provider=ErrCredentialProvider()
    )

    serialized = lf.serialize()

    lf = pl.LazyFrame.deserialize(io.BytesIO(serialized))

    with pytest.raises(AssertionError, match=err_magic):
        lf.collect()


def test_credential_provider_gcp_skips_config_autoload(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check GCP scans with `credential_provider=None` versus a custom provider."""
    monkeypatch.setenv("GOOGLE_SERVICE_ACCOUNT_PATH", "__non_existent")

    with pytest.raises(OSError, match="__non_existent"):
        pl.scan_parquet("gs://.../...", credential_provider=None).collect()

    err_magic = "err_magic_3"

    def raises() -> pl.CredentialProviderFunctionReturn:
        raise AssertionError(err_magic)

    with pytest.raises(AssertionError, match=err_magic):
        pl.scan_parquet("gs://.../...", credential_provider=raises).collect()


def test_credential_provider_aws_import_error_with_requested_profile(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that an ImportError while loading a requested AWS profile is reported."""

    def _session(self: Any) -> None:
        msg = "err_magic_3"
        raise ImportError(msg)

    monkeypatch.setattr(pl.CredentialProviderAWS, "_session", _session)
    monkeypatch.setenv("AWS_REGION", "eu-west-1")

    # Profile requested through an explicit provider instance.
    q = pl.scan_parquet(
        "s3://.../...",
        credential_provider=pl.CredentialProviderAWS(profile_name="test_profile"),
    )

    with pytest.raises(
        pl.exceptions.ComputeError,
        match=(
            "cannot load requested aws_profile 'test_profile': ImportError: err_magic_3"
        ),
    ):
        q.collect()

    # Profile requested through `storage_options`.
    q = pl.scan_parquet(
        "s3://.../...",
        storage_options={"aws_profile": "test_profile"},
    )

    with pytest.raises(
        pl.exceptions.ComputeError,
        match=(
            "cannot load requested aws_profile 'test_profile': ImportError: err_magic_3"
        ),
    ):
        q.collect()


@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_scan_no_parameters(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Check that a parameterless scan picks up `endpoint_url` from the AWS config."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, monkeypatch)
    cfg_file_path = tmp_path / "config"

    monkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))
    monkeypatch.setenv("POLARS_VERBOSE", "1")

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    # Scan with no parameters should load via CredentialProviderAWS
    q = pl.scan_parquet("s3://.../...")

    # Discard anything captured so far.
    capfd.readouterr()

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:333"):
        q.collect()

    capture = capfd.readouterr().err
    lines = capture.splitlines()

    assert "[CredentialProviderAWS]: Loaded endpoint_url: http://localhost:333" in lines


@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_serde(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Check that a pickled query uses the `endpoint_url` configured at collect time."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, monkeypatch)
    cfg_file_path = tmp_path / "config"

    monkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))
    monkeypatch.setenv("POLARS_VERBOSE", "1")

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    q = pl.scan_parquet("s3://.../...")
    q = pickle.loads(pickle.dumps(q))

    # Rewrite the config after the query has been built and round-tripped.
    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:777
""")

    capfd.readouterr()

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:777"):
        q.collect()


@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_with_storage_options(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Check `endpoint_url` is loaded even when secrets are in `storage_options`."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, monkeypatch)
    cfg_file_path = tmp_path / "config"

    monkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))
    monkeypatch.setenv("POLARS_VERBOSE", "1")

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    # Previously we would not initialize a credential provider at all if secrets
    # were given under `storage_options`. Now we always initialize so that we
    # can load the `endpoint_url`, but we decide at the very last second whether
    # to also retrieve secrets using the credential provider.
    q = pl.scan_parquet(
        "s3://.../...",
        storage_options={
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        },
    )

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:333"):
        q.collect()

    capture = capfd.readouterr().err
    lines = capture.splitlines()

    assert (
        "[CredentialProviderAWS]: Will not be used as a provider: unhandled key "
        "in storage_options: 'aws_secret_access_key'"
    ) in lines
    assert "[CredentialProviderAWS]: Loaded endpoint_url: http://localhost:333" in lines


@pytest.mark.parametrize(
    "storage_options",
    [
        {"aws_endpoint_url": "http://localhost:777"},
        {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
            "aws_endpoint_url": "http://localhost:777",
        },
    ],
)
@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_passed_in_storage_options(
    storage_options: dict[str, str],
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that `aws_endpoint_url` in `storage_options` overrides the config file."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, monkeypatch)
    cfg_file_path = tmp_path / "config"
    monkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    q = pl.scan_parquet("s3://.../...")

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:333"):
        q.collect()

    # An endpoint_url passed in `storage_options` should take precedence.
    q = pl.scan_parquet(
        "s3://.../...",
        storage_options=storage_options,
    )

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:777"):
        q.collect()


def _set_default_credentials(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Write a dummy `[default]` credentials file and point AWS at it."""
    creds_file_path = tmp_path / "credentials"

    creds_file_path.write_text("""\
[default]
aws_access_key_id=Z
aws_secret_access_key=Z
""")

    monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(creds_file_path))


@pytest.mark.slow
def test_credential_provider_python_builder_cache(
    monkeypatch: pytest.MonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Check how often providers are built, and that the cache key is memoized."""

    # Tests caching of building credential providers.
    def dummy_static_aws_credentials(*a: Any, **kw: Any) -> Any:
        return {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        }, None

    with monkeypatch.context() as cx:
        # Count constructor calls while still running the real `__init__`.
        provider_init = Mock(wraps=pl.CredentialProviderAWS.__init__)

        cx.setattr(
            pl.CredentialProviderAWS,
            "__init__",
            lambda *a, **kw: provider_init(*a, **kw),
        )

        cx.setattr(
            pl.CredentialProviderAWS,
            "retrieve_credentials_impl",
            dummy_static_aws_credentials,
        )

        # Ensure we are building a new query every time.
        def get_q() -> pl.LazyFrame:
            return pl.scan_parquet(
                "s3://.../...",
                storage_options={
                    "aws_profile": "A",
                    "aws_endpoint_url": "http://localhost",
                },
                credential_provider="auto",
            )

        assert provider_init.call_count == 0

        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 1

        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 1

        # Different `aws_profile`, so one more construction is expected.
        with pytest.raises(OSError):
            pl.scan_parquet(
                "s3://.../...",
                storage_options={
                    "aws_profile": "B",
                    "aws_endpoint_url": "http://localhost",
                },
                credential_provider="auto",
            ).collect()

        assert provider_init.call_count == 2

        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 2

        cx.setenv("POLARS_CREDENTIAL_PROVIDER_BUILDER_CACHE_SIZE", "0")

        with pytest.raises(OSError):
            get_q().collect()

        # Note: Increments by 2 due to Rust-side object store rebuilding.

        assert provider_init.call_count == 4

        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 6

    with monkeypatch.context() as cx:
        cx.setenv("POLARS_VERBOSE", "1")
        builder = _init_credential_provider_builder(
            "auto",
            "s3://.../...",
            None,
            "test",
        )
        assert builder is not None

        capfd.readouterr()

        builder.build_credential_provider()
        builder.build_credential_provider()

        capture = capfd.readouterr().err

        # Ensure cache key is memoized on generation
        assert capture.count("AutoInit cache key") == 1

        pickle.loads(pickle.dumps(builder)).build_credential_provider()

        capture = capfd.readouterr().err

        # Ensure cache key is not serialized
        assert capture.count("AutoInit cache key") == 1


@pytest.mark.slow
def test_credential_provider_python_credentials_cache(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check Python-side credential caching, its env-var switch, pickling and copies."""
    credentials_func = Mock(
        wraps=lambda: (
            {
                "aws_access_key_id": "...",
                "aws_secret_access_key": "...",
            },
            None,
        )
    )

    monkeypatch.setattr(
        pl.CredentialProviderAWS,
        "retrieve_credentials_impl",
        credentials_func,
    )

    assert credentials_func.call_count == 0

    provider = pl.CredentialProviderAWS()

    provider()
    assert credentials_func.call_count == 1

    provider()
    assert credentials_func.call_count == 1

    monkeypatch.setenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING", "1")

    provider()
    assert credentials_func.call_count == 2

    provider()
    assert credentials_func.call_count == 3

    monkeypatch.delenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING")

    provider()
    assert credentials_func.call_count == 4

    provider()
    assert credentials_func.call_count == 4

    assert provider._cached_credentials.get() is not None
    assert pickle.loads(pickle.dumps(provider))._cached_credentials.get() is None

    assert provider() == (
        {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        },
        None,
    )

    # Mutating the returned dict must not affect what later calls return.
    provider()[0]["A"] = "A"

    assert provider() == (
        {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        },
        None,
    )


def test_no_pickle_option() -> None:
    """Check that a `NoPickleOption` value is `None` after a pickle round-trip."""
    v = NoPickleOption(3)
    assert v.get() == 3

    out = pickle.loads(pickle.dumps(v))

    assert out.get() is None


@pytest.mark.write_disk
def test_credential_provider_aws_expiry(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Check credentials and expiry returned for a `credential_process` profile."""
    credential_file_path = tmp_path / "credentials.json"

    credential_file_path.write_text(
        """\
{
"Version": 1,
"AccessKeyId": "123",
"SecretAccessKey": "456",
"SessionToken": "789",
"Expiration": "2099-01-01T00:00:00+00:00"
}
"""
    )

    cfg_file_path = tmp_path / "config"

    # Use forward slashes so the path can be embedded in the command below.
    credential_file_path_str = str(credential_file_path).replace("\\", "/")

    cfg_file_path.write_text(f"""\
[profile cred_process]
credential_process = "{sys.executable}" -c "from pathlib import Path; print(Path('{credential_file_path_str}').read_text())"
""")

    monkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))

    creds, expiry = pl.CredentialProviderAWS(profile_name="cred_process")()

    assert creds == {
        "aws_access_key_id": "123",
        "aws_secret_access_key": "456",
        "aws_session_token": "789",
    }

    assert expiry is not None

    assert datetime.fromtimestamp(expiry, tz=timezone.utc) == datetime.fromisoformat(
        "2099-01-01T00:00:00+00:00"
    )

    # Same process output, but without an "Expiration" entry.
    credential_file_path.write_text(
        """\
{
"Version": 1,
"AccessKeyId": "...",
"SecretAccessKey": "...",
"SessionToken": "..."
}
"""
    )

    creds, expiry = pl.CredentialProviderAWS(profile_name="cred_process")()

    assert creds == {
        "aws_access_key_id": "...",
        "aws_secret_access_key": "...",
        "aws_session_token": "...",
    }

    assert expiry is None


@pytest.mark.slow
@pytest.mark.parametrize(
    (
        "credential_provider_class",
        "scan_path",
        "initial_credentials",
        "updated_credentials",
    ),
    [
        (
            pl.CredentialProviderAWS,
            "s3://.../...",
            {"aws_access_key_id": "initial", "aws_secret_access_key": "initial"},
            {"aws_access_key_id": "updated", "aws_secret_access_key": "updated"},
        ),
        (
            pl.CredentialProviderAzure,
            "abfss://container@storage_account.dfs.core.windows.net/bucket",
            {"bearer_token": "initial"},
            {"bearer_token": "updated"},
        ),
        (
            pl.CredentialProviderGCP,
            "gs://.../...",
            {"bearer_token": "initial"},
            {"bearer_token": "updated"},
        ),
    ],
)
def test_credential_provider_rebuild_clears_cache(
    credential_provider_class: type[CachingCredentialProvider],
    scan_path: str,
    initial_credentials: dict[str, str],
    updated_credentials: dict[str, str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that the provider built after a scan returns the updated credentials."""
    assert initial_credentials != updated_credentials

    monkeypatch.setattr(
        credential_provider_class,
        "retrieve_credentials_impl",
        lambda *_: (initial_credentials, None),
    )

    storage_options = (
        {"aws_endpoint_url": "http://localhost:333"}
        if credential_provider_class == pl.CredentialProviderAWS
        else None
    )

    builder = _init_credential_provider_builder(
        "auto",
        scan_path,
        storage_options=storage_options,
        caller_name="test",
    )

    assert builder is not None

    # This is a separate one for testing local to this function.
    provider_local = credential_provider_class()

    # Set the cache
    provider_local()

    # Now update the retrieval function to return updated credentials.
    monkeypatch.setattr(
        credential_provider_class,
        "retrieve_credentials_impl",
        lambda *_: (updated_credentials, None),
    )

    # Despite "retrieve_credentials_impl" being updated, the providers should
    # still return the initial credentials, as they were cached with an expiry
    # of None.
    assert provider_local() == (initial_credentials, None)

    q = pl.scan_parquet(
        scan_path,
        storage_options=storage_options,
        credential_provider="auto",
    )

    with pytest.raises(OSError):
        q.collect()

    provider_at_scan = builder.build_credential_provider()

    assert provider_at_scan is not None
    assert provider_at_scan() == (updated_credentials, None)

    assert provider_local() == (initial_credentials, None)

    provider_local.clear_cached_credentials()

    assert provider_local() == (updated_credentials, None)


def test_user_gcp_token_provider(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check `UserProvidedGCPToken` output with and without Python-side caching."""
    provider = UserProvidedGCPToken("A")
    assert provider() == ({"bearer_token": "A"}, None)
    monkeypatch.setenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING", "1")
    assert provider() == ({"bearer_token": "A"}, None)


def test_auto_init_cache_key_memoize(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that `AutoInit.get_or_init_cache_key` computes the key only once."""
    get_cache_key_impl = Mock(wraps=AutoInit.get_cache_key_impl)
    monkeypatch.setattr(
        AutoInit,
        "get_cache_key_impl",
        lambda *a, **kw: get_cache_key_impl(*a, **kw),
    )

    v = AutoInit(int)

    assert get_cache_key_impl.call_count == 0

    v.get_or_init_cache_key()
    assert get_cache_key_impl.call_count == 1

    v.get_or_init_cache_key()
    assert get_cache_key_impl.call_count == 1


def test_cached_credential_provider_returns_copied_creds() -> None:
    """Check that mutating returned credentials does not alter the cached value."""
    provider_func = Mock(wraps=lambda: ({"A": "A"}, None))
    provider = CachedCredentialProvider(provider_func)

    assert provider_func.call_count == 0

    provider()
    assert provider() == ({"A": "A"}, None)

    assert provider_func.call_count == 1

    provider()[0]["B"] = "B"

    assert provider() == ({"A": "A"}, None)

    assert provider_func.call_count == 1


@pytest.mark.parametrize(
    "partition_target",
    [
        pl.PartitionBy("s3://.../...", key=""),
    ],
)
def test_credential_provider_init_from_partition_target(
    partition_target: pl.PartitionBy,
) -> None:
    """Check that a builder is created when the target is a partition object."""
    assert isinstance(
        _init_credential_provider_builder(
            "auto",
            partition_target,
            None,
            "test",
        ),
        CredentialProviderBuilder,
    )


@pytest.mark.slow
def test_cache_user_credential_provider(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check how often a user-supplied provider is called across repeated scans."""
    user_provider = Mock(
        return_value=(
            {"aws_access_key_id": "...", "aws_secret_access_key": "..."},
            None,
        )
    )

    def get_q() -> pl.LazyFrame:
        return pl.scan_parquet(
            "s3://.../...",
            storage_options={"aws_endpoint_url": "http://localhost:333"},
            credential_provider=user_provider,
        )

    assert user_provider.call_count == 0

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert user_provider.call_count == 2

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert user_provider.call_count == 3

    monkeypatch.setenv("POLARS_CREDENTIAL_PROVIDER_BUILDER_CACHE_SIZE", "0")

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert user_provider.call_count == 5


@pytest.mark.slow
def test_credential_provider_global_config(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check `pl.Config.set_default_credential_provider` with a callable, "auto" and None."""
    import polars.io.cloud.credential_provider._builder

    monkeypatch.setattr(
        polars.io.cloud.credential_provider._builder,
        "DEFAULT_CREDENTIAL_PROVIDER",
        None,
    )

    provider = Mock(
        return_value=(
            {"aws_access_key_id": "...", "aws_secret_access_key": "..."},
            None,
        )
    )

    pl.Config.set_default_credential_provider(provider)

    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "...")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "...")

    def get_q() -> pl.LazyFrame:
        return pl.scan_parquet(
            "s3://.../...",
            storage_options={"aws_endpoint_url": "http://localhost:333"},
        )

    def get_q_disable_cred_provider() -> pl.LazyFrame:
        return pl.scan_parquet(
            "s3://.../...",
            credential_provider=None,
            storage_options={"aws_endpoint_url": "http://localhost:333"},
        )

    assert provider.call_count == 0

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert provider.call_count == 2

    # An explicit `credential_provider=None` must not call the default provider.
    with pytest.raises(OSError, match="http://localhost:333"):
        get_q_disable_cred_provider().collect()

    assert provider.call_count == 2

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert provider.call_count == 3

    pl.Config.set_default_credential_provider("auto")

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert provider.call_count == 3

    err_magic = "err_magic_3"

    def raises(*_: None, **__: None) -> None:
        raise AssertionError(err_magic)

    monkeypatch.setattr(CredentialProviderBuilder, "__init__", raises)

    with pytest.raises(AssertionError, match=err_magic):
        get_q().collect()

    pl.Config.set_default_credential_provider(None)

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()