# Path: blob/main/py-polars/tests/unit/io/cloud/test_credential_provider.py
# 6939 views
import io
import pickle
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from unittest.mock import Mock

import pytest

import polars as pl
import polars.io.cloud.credential_provider
from polars._typing import PartitioningScheme
from polars.io.cloud._utils import NoPickleOption
from polars.io.cloud.credential_provider._builder import (
    AutoInit,
    CredentialProviderBuilder,
    _init_credential_provider_builder,
)
from polars.io.cloud.credential_provider._providers import (
    CachedCredentialProvider,
    CachingCredentialProvider,
    UserProvidedGCPToken,
)


@pytest.mark.parametrize(
    "io_func",
    [
        *[pl.scan_parquet, pl.read_parquet],
        pl.scan_csv,
        *[pl.scan_ndjson, pl.read_ndjson],
        pl.scan_ipc,
    ],
)
def test_credential_provider_scan(
    io_func: Any, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Check how each IO function reacts to the `credential_provider` argument."""
    err_magic = "err_magic_3"

    def raises(*_: None, **__: None) -> None:
        raise AssertionError(err_magic)

    # Any attempt to construct a builder now surfaces as an AssertionError.
    monkeypatch.setattr(CredentialProviderBuilder, "__init__", raises)

    with pytest.raises(AssertionError, match=err_magic):
        io_func("s3://bucket/path", credential_provider="auto")

    with pytest.raises(AssertionError, match=err_magic):
        io_func(
            "s3://bucket/path",
            credential_provider="auto",
            storage_options={"aws_region": "eu-west-1"},
        )

    # We can't test these with the `read_` functions as they end up executing
    # the query
    if io_func.__name__.startswith("scan_"):
        # Passing `None` should disable the automatic instantiation of
        # `CredentialProviderAWS`
        io_func("s3://bucket/path", credential_provider=None)

    err_magic = "err_magic_7"

    def raises_2() -> pl.CredentialProviderFunctionReturn:
        raise AssertionError(err_magic)

    with pytest.raises(AssertionError, match=err_magic):
        io_func("s3://bucket/path", credential_provider=raises_2).collect()


@pytest.mark.parametrize(
    ("provider_class", "path"),
    [
        (polars.io.cloud.credential_provider.CredentialProviderAWS, "s3://.../..."),
        (polars.io.cloud.credential_provider.CredentialProviderGCP, "gs://.../..."),
        (polars.io.cloud.credential_provider.CredentialProviderAzure, "az://.../..."),
    ],
)
def test_credential_provider_serialization_auto_init(
    provider_class: type[pl.CredentialProvider],
    path: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that a pickled query still auto-initializes its provider on collect."""

    def raises_1(*a: Any, **kw: Any) -> None:
        msg = "err_magic_1"
        raise AssertionError(msg)

    monkeypatch.setattr(provider_class, "__init__", raises_1)

    # If this is not set we will get an error before hitting the credential
    # provider logic when polars attempts to retrieve the region from AWS.
    monkeypatch.setenv("AWS_REGION", "eu-west-1")

    # Credential provider should not be initialized during query plan construction.
    q = pl.scan_parquet(path)

    # Check baseline - query plan is configured to auto-initialize the credential
    # provider.
    with pytest.raises(AssertionError, match="err_magic_1"):
        q.collect()

    q = pickle.loads(pickle.dumps(q))

    def raises_2(*a: Any, **kw: Any) -> None:
        msg = "err_magic_2"
        raise AssertionError(msg)

    monkeypatch.setattr(provider_class, "__init__", raises_2)

    # Check that auto-initialization happens upon executing the deserialized
    # query.
    with pytest.raises(AssertionError, match="err_magic_2"):
        q.collect()


def test_credential_provider_serialization_custom_provider() -> None:
    """Check that a user-defined provider is still called after a serde round trip."""
    err_magic = "err_magic_3"

    class ErrCredentialProvider(pl.CredentialProvider):
        def __call__(self) -> pl.CredentialProviderFunctionReturn:
            raise AssertionError(err_magic)

    lf = pl.scan_parquet(
        "s3://bucket/path", credential_provider=ErrCredentialProvider()
    )

    serialized = lf.serialize()

    lf = pl.LazyFrame.deserialize(io.BytesIO(serialized))

    with pytest.raises(AssertionError, match=err_magic):
        lf.collect()


def test_credential_provider_gcp_skips_config_autoload(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that a provider callable is invoked instead of the env-configured path."""
    monkeypatch.setenv("GOOGLE_SERVICE_ACCOUNT_PATH", "__non_existent")

    with pytest.raises(OSError, match="__non_existent"):
        pl.scan_parquet("gs://.../...", credential_provider=None).collect()

    err_magic = "err_magic_3"

    def raises() -> pl.CredentialProviderFunctionReturn:
        raise AssertionError(err_magic)

    with pytest.raises(AssertionError, match=err_magic):
        pl.scan_parquet("gs://.../...", credential_provider=raises).collect()


def test_credential_provider_aws_import_error_with_requested_profile(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that an ImportError in `_session` is reported for a requested profile."""

    def _session(self: Any) -> None:
        msg = "err_magic_3"
        raise ImportError(msg)

    monkeypatch.setattr(pl.CredentialProviderAWS, "_session", _session)
    monkeypatch.setenv("AWS_REGION", "eu-west-1")

    # Profile requested through an explicitly constructed provider.
    q = pl.scan_parquet(
        "s3://.../...",
        credential_provider=pl.CredentialProviderAWS(profile_name="test_profile"),
    )

    with pytest.raises(
        pl.exceptions.ComputeError,
        match=(
            "cannot load requested aws_profile 'test_profile': ImportError: err_magic_3"
        ),
    ):
        q.collect()

    # Profile requested through `storage_options`.
    q = pl.scan_parquet(
        "s3://.../...",
        storage_options={"aws_profile": "test_profile"},
    )

    with pytest.raises(
        pl.exceptions.ComputeError,
        match=(
            "cannot load requested aws_profile 'test_profile': ImportError: err_magic_3"
        ),
    ):
        q.collect()


@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_scan_no_parameters(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Check that a bare scan picks up `endpoint_url` from the AWS config file."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, monkeypatch)
    cfg_file_path = tmp_path / "config"

    monkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))
    monkeypatch.setenv("POLARS_VERBOSE", "1")

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    # Scan with no parameters should load via CredentialProviderAWS
    q = pl.scan_parquet("s3://.../...")

    # Drop anything captured so far so only output from `collect` is inspected.
    capfd.readouterr()

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:333"):
        q.collect()

    capture = capfd.readouterr().err
    lines = capture.splitlines()

    assert "[CredentialProviderAWS]: Loaded endpoint_url: http://localhost:333" in lines


@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_serde(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Check that `endpoint_url` is read at collect time, after unpickling."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, monkeypatch)
    cfg_file_path = tmp_path / "config"

    monkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))
    monkeypatch.setenv("POLARS_VERBOSE", "1")

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    q = pl.scan_parquet("s3://.../...")
    q = pickle.loads(pickle.dumps(q))

    # Rewrite the config after the round trip; the new URL is the one expected.
    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:777
""")

    capfd.readouterr()

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:777"):
        q.collect()


@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_with_storage_options(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Check that `endpoint_url` is loaded even when secrets are in `storage_options`."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, monkeypatch)
    cfg_file_path = tmp_path / "config"

    monkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))
    monkeypatch.setenv("POLARS_VERBOSE", "1")

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    # Previously we would not initialize a credential provider at all if secrets
    # were given under `storage_options`. Now we always initialize so that we
    # can load the `endpoint_url`, but we decide at the very last second whether
    # to also retrieve secrets using the credential provider.
    q = pl.scan_parquet(
        "s3://.../...",
        storage_options={
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        },
    )

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:333"):
        q.collect()

    capture = capfd.readouterr().err
    lines = capture.splitlines()

    assert (
        "[CredentialProviderAWS]: Will not be used as a provider: unhandled key "
        "in storage_options: 'aws_secret_access_key'"
    ) in lines
    assert "[CredentialProviderAWS]: Loaded endpoint_url: http://localhost:333" in lines


@pytest.mark.parametrize(
    "storage_options",
    [
        {"aws_endpoint_url": "http://localhost:777"},
        {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
            "aws_endpoint_url": "http://localhost:777",
        },
    ],
)
@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_passed_in_storage_options(
    storage_options: dict[str, str],
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that `aws_endpoint_url` in `storage_options` overrides the config file."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, monkeypatch)
    cfg_file_path = tmp_path / "config"
    monkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    q = pl.scan_parquet("s3://.../...")

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:333"):
        q.collect()

    # An endpoint_url passed in `storage_options` should take precedence.
    q = pl.scan_parquet(
        "s3://.../...",
        storage_options=storage_options,
    )

    with pytest.raises(IOError, match=r"Error performing HEAD http://localhost:777"):
        q.collect()


def _set_default_credentials(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Write a dummy credentials file and point AWS_SHARED_CREDENTIALS_FILE at it."""
    creds_file_path = tmp_path / "credentials"

    creds_file_path.write_text("""\
[default]
aws_access_key_id=Z
aws_secret_access_key=Z
""")

    monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(creds_file_path))


@pytest.mark.slow
def test_credential_provider_python_builder_cache(
    monkeypatch: pytest.MonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Count provider constructions across queries to check builder caching."""

    # Tests caching of building credential providers.
    def dummy_static_aws_credentials(*a: Any, **kw: Any) -> Any:
        return {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        }, None

    with monkeypatch.context() as cx:
        provider_init = Mock(wraps=pl.CredentialProviderAWS.__init__)

        cx.setattr(
            pl.CredentialProviderAWS,
            "__init__",
            lambda *a, **kw: provider_init(*a, **kw),
        )

        cx.setattr(
            pl.CredentialProviderAWS,
            "retrieve_credentials_impl",
            dummy_static_aws_credentials,
        )

        # Ensure we are building a new query every time.
        def get_q() -> pl.LazyFrame:
            return pl.scan_parquet(
                "s3://.../...",
                storage_options={
                    "aws_profile": "A",
                    "aws_endpoint_url": "http://localhost",
                },
                credential_provider="auto",
            )

        assert provider_init.call_count == 0

        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 1

        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 1

        with pytest.raises(OSError):
            pl.scan_parquet(
                "s3://.../...",
                storage_options={
                    "aws_profile": "B",
                    "aws_endpoint_url": "http://localhost",
                },
                credential_provider="auto",
            ).collect()

        assert provider_init.call_count == 2

        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 2

        cx.setenv("POLARS_CREDENTIAL_PROVIDER_BUILDER_CACHE_SIZE", "0")

        with pytest.raises(OSError):
            get_q().collect()

        # Note: Increments by 2 due to Rust-side object store rebuilding.

        assert provider_init.call_count == 4

        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 6

    with monkeypatch.context() as cx:
        cx.setenv("POLARS_VERBOSE", "1")
        builder = _init_credential_provider_builder(
            "auto",
            "s3://.../...",
            None,
            "test",
        )
        assert builder is not None

        capfd.readouterr()

        builder.build_credential_provider()
        builder.build_credential_provider()

        capture = capfd.readouterr().err

        # Ensure cache key is memoized on generation
        assert capture.count("AutoInit cache key") == 1

        pickle.loads(pickle.dumps(builder)).build_credential_provider()

        capture = capfd.readouterr().err

        # Ensure cache key is not serialized
        assert capture.count("AutoInit cache key") == 1


@pytest.mark.slow
def test_credential_provider_python_credentials_cache(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Count `retrieve_credentials_impl` calls to check Python-side caching."""
    credentials_func = Mock(
        wraps=lambda: (
            {
                "aws_access_key_id": "...",
                "aws_secret_access_key": "...",
            },
            None,
        )
    )

    monkeypatch.setattr(
        pl.CredentialProviderAWS,
        "retrieve_credentials_impl",
        credentials_func,
    )

    assert credentials_func.call_count == 0

    provider = pl.CredentialProviderAWS()

    provider()
    assert credentials_func.call_count == 1

    provider()
    assert credentials_func.call_count == 1

    monkeypatch.setenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING", "1")

    provider()
    assert credentials_func.call_count == 2

    provider()
    assert credentials_func.call_count == 3

    monkeypatch.delenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING")

    provider()
    assert credentials_func.call_count == 4

    provider()
    assert credentials_func.call_count == 4

    # The cached value must not survive pickling.
    assert provider._cached_credentials.get() is not None
    assert pickle.loads(pickle.dumps(provider))._cached_credentials.get() is None

    assert provider() == (
        {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        },
        None,
    )

    # Mutating a returned dict must not leak into later results.
    provider()[0]["A"] = "A"

    assert provider() == (
        {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        },
        None,
    )


def test_no_pickle_option() -> None:
    """Check that a `NoPickleOption` value reads back as `None` after pickling."""
    v = NoPickleOption(3)
    assert v.get() == 3

    out = pickle.loads(pickle.dumps(v))

    assert out.get() is None


@pytest.mark.write_disk
def test_credential_provider_aws_expiry(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Check the expiry returned for `credential_process` output with/without one."""
    credential_file_path = tmp_path / "credentials.json"

    credential_file_path.write_text(
        """\
{
    "Version": 1,
    "AccessKeyId": "123",
    "SecretAccessKey": "456",
    "SessionToken": "789",
    "Expiration": "2099-01-01T00:00:00+00:00"
}
"""
    )

    cfg_file_path = tmp_path / "config"

    # Forward slashes keep the path usable inside the quoted command below.
    credential_file_path_str = str(credential_file_path).replace("\\", "/")

    cfg_file_path.write_text(f"""\
[profile cred_process]
credential_process = "{sys.executable}" -c "from pathlib import Path; print(Path('{credential_file_path_str}').read_text())"
""")

    monkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))

    creds, expiry = pl.CredentialProviderAWS(profile_name="cred_process")()

    assert creds == {
        "aws_access_key_id": "123",
        "aws_secret_access_key": "456",
        "aws_session_token": "789",
    }

    assert expiry is not None

    assert datetime.fromtimestamp(expiry, tz=timezone.utc) == datetime.fromisoformat(
        "2099-01-01T00:00:00+00:00"
    )

    # Same process, but the payload now carries no "Expiration" field.
    credential_file_path.write_text(
        """\
{
    "Version": 1,
    "AccessKeyId": "...",
    "SecretAccessKey": "...",
    "SessionToken": "..."
}
"""
    )

    creds, expiry = pl.CredentialProviderAWS(profile_name="cred_process")()

    assert creds == {
        "aws_access_key_id": "...",
        "aws_secret_access_key": "...",
        "aws_session_token": "...",
    }

    assert expiry is None


@pytest.mark.slow
@pytest.mark.parametrize(
    (
        "credential_provider_class",
        "scan_path",
        "initial_credentials",
        "updated_credentials",
    ),
    [
        (
            pl.CredentialProviderAWS,
            "s3://.../...",
            {"aws_access_key_id": "initial", "aws_secret_access_key": "initial"},
            {"aws_access_key_id": "updated", "aws_secret_access_key": "updated"},
        ),
        (
            pl.CredentialProviderAzure,
            "abfss://container@storage_account.dfs.core.windows.net/bucket",
            {"bearer_token": "initial"},
            {"bearer_token": "updated"},
        ),
        (
            pl.CredentialProviderGCP,
            "gs://.../...",
            {"bearer_token": "initial"},
            {"bearer_token": "updated"},
        ),
    ],
)
def test_credential_provider_rebuild_clears_cache(
    credential_provider_class: type[CachingCredentialProvider],
    scan_path: str,
    initial_credentials: dict[str, str],
    updated_credentials: dict[str, str],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that the provider built for a scan returns freshly retrieved credentials."""
    assert initial_credentials != updated_credentials

    monkeypatch.setattr(
        credential_provider_class,
        "retrieve_credentials_impl",
        lambda *_: (initial_credentials, None),
    )

    storage_options = (
        {"aws_endpoint_url": "http://localhost:333"}
        if credential_provider_class is pl.CredentialProviderAWS
        else None
    )

    builder = _init_credential_provider_builder(
        "auto",
        scan_path,
        storage_options=storage_options,
        caller_name="test",
    )

    assert builder is not None

    # This is a separate one for testing local to this function.
    provider_local = credential_provider_class()

    # Set the cache
    provider_local()

    # Now update the retrieval function to return updated credentials.
    monkeypatch.setattr(
        credential_provider_class,
        "retrieve_credentials_impl",
        lambda *_: (updated_credentials, None),
    )

    # Despite "retrieve_credentials_impl" being updated, the providers should
    # still return the initial credentials, as they were cached with an expiry
    # of None.
    assert provider_local() == (initial_credentials, None)

    q = pl.scan_parquet(
        scan_path,
        storage_options=storage_options,
        credential_provider="auto",
    )

    with pytest.raises(OSError):
        q.collect()

    provider_at_scan = builder.build_credential_provider()

    assert provider_at_scan is not None
    assert provider_at_scan() == (updated_credentials, None)

    assert provider_local() == (initial_credentials, None)

    provider_local.clear_cached_credentials()

    assert provider_local() == (updated_credentials, None)


def test_user_gcp_token_provider(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that `UserProvidedGCPToken` returns its token with caching on or off."""
    provider = UserProvidedGCPToken("A")
    assert provider() == ({"bearer_token": "A"}, None)
    monkeypatch.setenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING", "1")
    assert provider() == ({"bearer_token": "A"}, None)


def test_auto_init_cache_key_memoize(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that `get_or_init_cache_key` computes the key only once."""
    get_cache_key_impl = Mock(wraps=AutoInit.get_cache_key_impl)
    monkeypatch.setattr(
        AutoInit,
        "get_cache_key_impl",
        lambda *a, **kw: get_cache_key_impl(*a, **kw),
    )

    v = AutoInit(int)

    assert get_cache_key_impl.call_count == 0

    v.get_or_init_cache_key()
    assert get_cache_key_impl.call_count == 1

    v.get_or_init_cache_key()
    assert get_cache_key_impl.call_count == 1


def test_cached_credential_provider_returns_copied_creds() -> None:
    """Check that mutating a returned dict does not alter the cached credentials."""
    provider_func = Mock(wraps=lambda: ({"A": "A"}, None))
    provider = CachedCredentialProvider(provider_func)

    assert provider_func.call_count == 0

    provider()
    assert provider() == ({"A": "A"}, None)

    assert provider_func.call_count == 1

    provider()[0]["B"] = "B"

    assert provider() == ({"A": "A"}, None)

    assert provider_func.call_count == 1


@pytest.mark.parametrize(
    "partition_target",
    [
        pl.PartitionByKey("s3://.../...", by=""),
        pl.PartitionMaxSize("s3://.../...", max_size=1),
        pl.PartitionParted("s3://.../...", by=""),
    ],
)
def test_credential_provider_init_from_partition_target(
    partition_target: PartitioningScheme,
) -> None:
    """Check that a builder is created when the target is a partitioning scheme."""
    assert isinstance(
        _init_credential_provider_builder(
            "auto",
            partition_target,
            None,
            "test",
        ),
        CredentialProviderBuilder,
    )


@pytest.mark.slow
def test_cache_user_credential_provider(monkeypatch: pytest.MonkeyPatch) -> None:
    """Count calls to a user-supplied provider across repeated identical scans."""
    user_provider = Mock(
        return_value=(
            {"aws_access_key_id": "...", "aws_secret_access_key": "..."},
            None,
        )
    )

    def get_q() -> pl.LazyFrame:
        return pl.scan_parquet(
            "s3://.../...",
            storage_options={"aws_endpoint_url": "http://localhost:333"},
            credential_provider=user_provider,
        )

    assert user_provider.call_count == 0

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert user_provider.call_count == 2

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert user_provider.call_count == 3

    monkeypatch.setenv("POLARS_CREDENTIAL_PROVIDER_BUILDER_CACHE_SIZE", "0")

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert user_provider.call_count == 5