# Source: py-polars/tests/unit/io/cloud/test_credential_provider.py
"""Tests for Polars cloud credential providers, their builder and their caches."""

import io
import pickle
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from unittest.mock import Mock

import pytest

import polars as pl
import polars.io.cloud.credential_provider
from polars.io.cloud._utils import NoPickleOption
from polars.io.cloud.credential_provider._builder import (
    AutoInit,
    CredentialProviderBuilder,
    _init_credential_provider_builder,
)
from polars.io.cloud.credential_provider._providers import (
    CachedCredentialProvider,
    CachingCredentialProvider,
    UserProvidedGCPToken,
)
from tests.conftest import PlMonkeyPatch


@pytest.mark.parametrize(
    "io_func",
    [
        *[pl.scan_parquet, pl.read_parquet],
        pl.scan_csv,
        *[pl.scan_ndjson, pl.read_ndjson],
        pl.scan_ipc,
    ],
)
def test_credential_provider_scan(io_func: Any, plmonkeypatch: PlMonkeyPatch) -> None:
    """Check that IO functions build a provider for "auto" and skip it for ``None``."""
    err_magic = "err_magic_3"

    def raises(*_: None, **__: None) -> None:
        raise AssertionError(err_magic)

    # `CredentialProviderBuilder` is imported at module level; patching its
    # `__init__` makes any attempt to build a provider raise `err_magic`.
    plmonkeypatch.setattr(CredentialProviderBuilder, "__init__", raises)

    with pytest.raises(AssertionError, match=err_magic):
        io_func("s3://bucket/path", credential_provider="auto")

    with pytest.raises(AssertionError, match=err_magic):
        io_func(
            "s3://bucket/path",
            credential_provider="auto",
            storage_options={"aws_region": "eu-west-1"},
        )

    # We can't test these with the `read_` functions as they end up executing
    # the query
    if io_func.__name__.startswith("scan_"):
        # Passing `None` should disable the automatic instantiation of
        # `CredentialProviderAWS`
        io_func("s3://bucket/path", credential_provider=None)

        err_magic = "err_magic_7"

        def raises_2() -> pl.CredentialProviderFunctionReturn:
            raise AssertionError(err_magic)

        with pytest.raises(AssertionError, match=err_magic):
            io_func("s3://bucket/path", credential_provider=raises_2).collect()


@pytest.mark.parametrize(
    ("provider_class", "path"),
    [
        (polars.io.cloud.credential_provider.CredentialProviderAWS, "s3://.../..."),
        (polars.io.cloud.credential_provider.CredentialProviderGCP, "gs://.../..."),
        (polars.io.cloud.credential_provider.CredentialProviderAzure, "az://.../..."),
    ],
)
def test_credential_provider_serialization_auto_init(
    # The parametrized values are provider *classes*, not instances.
    provider_class: type[polars.io.cloud.credential_provider.CredentialProvider],
    path: str,
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Check that "auto" providers are initialized at collect time, also after pickling."""

    def raises_1(*a: Any, **kw: Any) -> None:
        msg = "err_magic_1"
        raise AssertionError(msg)

    plmonkeypatch.setattr(provider_class, "__init__", raises_1)

    # If this is not set we will get an error before hitting the credential
    # provider logic when polars attempts to retrieve the region from AWS.
    plmonkeypatch.setenv("AWS_REGION", "eu-west-1")

    # Credential provider should not be initialized during query plan construction.
    q = pl.scan_parquet(path)

    # Check baseline - query plan is configured to auto-initialize the credential
    # provider.
    with pytest.raises(AssertionError, match="err_magic_1"):
        q.collect()

    q = pickle.loads(pickle.dumps(q))

    def raises_2(*a: Any, **kw: Any) -> None:
        msg = "err_magic_2"
        raise AssertionError(msg)

    plmonkeypatch.setattr(provider_class, "__init__", raises_2)

    # Check that auto-initialization happens upon executing the deserialized
    # query.
    with pytest.raises(AssertionError, match="err_magic_2"):
        q.collect()


@pytest.mark.slow
def test_credential_provider_serialization_custom_provider() -> None:
    """Check that a user-defined provider survives ``serialize``/``deserialize``."""
    err_magic = "err_magic_3"

    class ErrCredentialProvider(pl.CredentialProvider):
        def __call__(self) -> pl.CredentialProviderFunctionReturn:
            raise AssertionError(err_magic)

    lf = pl.scan_parquet(
        "s3://bucket/path", credential_provider=ErrCredentialProvider()
    )

    serialized = lf.serialize()

    lf = pl.LazyFrame.deserialize(io.BytesIO(serialized))

    with pytest.raises(AssertionError, match=err_magic):
        lf.collect()


def test_credential_provider_gcp_skips_config_autoload(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Check GCP env config is used with ``None`` but a custom provider is called instead."""
    plmonkeypatch.setenv("GOOGLE_SERVICE_ACCOUNT_PATH", "__non_existent")

    # With no provider, the env-configured (non-existent) path surfaces in the error.
    with pytest.raises(OSError, match="__non_existent"):
        pl.scan_parquet("gs://.../...", credential_provider=None).collect()

    err_magic = "err_magic_3"

    def raises() -> pl.CredentialProviderFunctionReturn:
        raise AssertionError(err_magic)

    # With a custom provider, the provider's error is raised instead.
    with pytest.raises(AssertionError, match=err_magic):
        pl.scan_parquet("gs://.../...", credential_provider=raises).collect()


def test_credential_provider_aws_import_error_with_requested_profile(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Check the error when ``_session`` raises ImportError for a requested AWS profile."""

    def _session(self: Any) -> None:
        msg = "err_magic_3"
        raise ImportError(msg)

    plmonkeypatch.setattr(pl.CredentialProviderAWS, "_session", _session)
    plmonkeypatch.setenv("AWS_REGION", "eu-west-1")

    # Profile requested via an explicit provider instance.
    q = pl.scan_parquet(
        "s3://.../...",
        credential_provider=pl.CredentialProviderAWS(profile_name="test_profile"),
    )

    with pytest.raises(
        pl.exceptions.ComputeError,
        match=(
            "cannot load requested aws_profile 'test_profile': ImportError: err_magic_3"
        ),
    ):
        q.collect()

    # Profile requested via `storage_options`.
    q = pl.scan_parquet(
        "s3://.../...",
        storage_options={"aws_profile": "test_profile"},
    )

    with pytest.raises(
        pl.exceptions.ComputeError,
        match=(
            "cannot load requested aws_profile 'test_profile': ImportError: err_magic_3"
        ),
    ):
        q.collect()


@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_scan_no_parameters(
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Check that a plain scan loads ``endpoint_url`` from the AWS config file."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, plmonkeypatch)
    cfg_file_path = tmp_path / "config"

    plmonkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    # Scan with no parameters should load via CredentialProviderAWS
    q = pl.scan_parquet("s3://.../...")

    # Discard any output produced before the collect.
    capfd.readouterr()

    with pytest.raises(OSError, match=r"Error performing HEAD http://localhost:333"):
        q.collect()

    capture = capfd.readouterr().err
    lines = capture.splitlines()

    assert "[CredentialProviderAWS]: Loaded endpoint_url: http://localhost:333" in lines


@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_serde(
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Check that ``endpoint_url`` is re-read from the config file after pickling."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, plmonkeypatch)
    cfg_file_path = tmp_path / "config"

    plmonkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    q = pl.scan_parquet("s3://.../...")
    q = pickle.loads(pickle.dumps(q))

    # Change the config after (de)serialization; the new value must be used.
    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:777
""")

    capfd.readouterr()

    with pytest.raises(OSError, match=r"Error performing HEAD http://localhost:777"):
        q.collect()


@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_with_storage_options(
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Check that ``endpoint_url`` is loaded from config even with secrets in ``storage_options``."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, plmonkeypatch)
    cfg_file_path = tmp_path / "config"

    plmonkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))
    plmonkeypatch.setenv("POLARS_VERBOSE", "1")

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    # Previously we would not initialize a credential provider at all if secrets
    # were given under `storage_options`. Now we always initialize so that we
    # can load the `endpoint_url`, but we decide at the very last second whether
    # to also retrieve secrets using the credential provider.
    q = pl.scan_parquet(
        "s3://.../...",
        storage_options={
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        },
    )

    with pytest.raises(OSError, match=r"Error performing HEAD http://localhost:333"):
        q.collect()

    capture = capfd.readouterr().err
    lines = capture.splitlines()

    assert (
        "[CredentialProviderAWS]: Will not be used as a provider: unhandled key "
        "in storage_options: 'aws_secret_access_key'"
    ) in lines
    assert "[CredentialProviderAWS]: Loaded endpoint_url: http://localhost:333" in lines


@pytest.mark.parametrize(
    "storage_options",
    [
        {"aws_endpoint_url": "http://localhost:777"},
        {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
            "aws_endpoint_url": "http://localhost:777",
        },
    ],
)
@pytest.mark.slow
@pytest.mark.write_disk
def test_credential_provider_aws_endpoint_url_passed_in_storage_options(
    storage_options: dict[str, str],
    tmp_path: Path,
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Check that ``aws_endpoint_url`` in ``storage_options`` overrides the config file."""
    tmp_path.mkdir(exist_ok=True)

    _set_default_credentials(tmp_path, plmonkeypatch)
    cfg_file_path = tmp_path / "config"
    plmonkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))

    cfg_file_path.write_text("""\
[default]
endpoint_url = http://localhost:333
""")

    q = pl.scan_parquet("s3://.../...")

    with pytest.raises(OSError, match=r"Error performing HEAD http://localhost:333"):
        q.collect()

    # An endpoint_url passed in `storage_options` should take precedence.
    q = pl.scan_parquet(
        "s3://.../...",
        storage_options=storage_options,
    )

    with pytest.raises(OSError, match=r"Error performing HEAD http://localhost:777"):
        q.collect()


def _set_default_credentials(tmp_path: Path, plmonkeypatch: PlMonkeyPatch) -> None:
    """Write a dummy shared-credentials file and point ``AWS_SHARED_CREDENTIALS_FILE`` at it."""
    creds_file_path = tmp_path / "credentials"

    creds_file_path.write_text("""\
[default]
aws_access_key_id=Z
aws_secret_access_key=Z
""")

    plmonkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(creds_file_path))


@pytest.mark.slow
def test_credential_provider_python_builder_cache(
    plmonkeypatch: PlMonkeyPatch,
    capfd: pytest.CaptureFixture[str],
) -> None:
    """Check caching of built credential providers and memoization of the cache key."""

    def dummy_static_aws_credentials(*a: Any, **kw: Any) -> Any:
        return {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        }, None

    with plmonkeypatch.context() as cx:
        # Wrap `__init__` so that provider constructions can be counted.
        provider_init = Mock(wraps=pl.CredentialProviderAWS.__init__)

        cx.setattr(
            pl.CredentialProviderAWS,
            "__init__",
            lambda *a, **kw: provider_init(*a, **kw),
        )

        cx.setattr(
            pl.CredentialProviderAWS,
            "retrieve_credentials_impl",
            dummy_static_aws_credentials,
        )

        # Ensure we are building a new query every time.
        def get_q() -> pl.LazyFrame:
            return pl.scan_parquet(
                "s3://.../...",
                storage_options={
                    "aws_profile": "A",
                    "aws_endpoint_url": "http://localhost",
                },
                credential_provider="auto",
            )

        assert provider_init.call_count == 0

        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 1

        # Same storage options: the cached provider is reused.
        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 1

        # Different profile: a new provider is built.
        with pytest.raises(OSError):
            pl.scan_parquet(
                "s3://.../...",
                storage_options={
                    "aws_profile": "B",
                    "aws_endpoint_url": "http://localhost",
                },
                credential_provider="auto",
            ).collect()

        assert provider_init.call_count == 2

        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 2

        # Disable the builder cache.
        cx.setenv("POLARS_CREDENTIAL_PROVIDER_BUILDER_CACHE_SIZE", "0")

        with pytest.raises(OSError):
            get_q().collect()

        # Note: Increments by 2 due to Rust-side object store rebuilding.

        assert provider_init.call_count == 4

        with pytest.raises(OSError):
            get_q().collect()

        assert provider_init.call_count == 6

    with plmonkeypatch.context() as cx:
        cx.setenv("POLARS_VERBOSE", "1")
        builder = _init_credential_provider_builder(
            "auto",
            "s3://.../...",
            None,
            "test",
        )
        assert builder is not None

        capfd.readouterr()

        builder.build_credential_provider()
        builder.build_credential_provider()

        capture = capfd.readouterr().err

        # Ensure cache key is memoized on generation
        assert capture.count("AutoInit cache key") == 1

        pickle.loads(pickle.dumps(builder)).build_credential_provider()

        capture = capfd.readouterr().err

        # Ensure cache key is not serialized
        assert capture.count("AutoInit cache key") == 1


@pytest.mark.slow
def test_credential_provider_python_credentials_cache(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Check that ``CredentialProviderAWS`` caches retrieved credentials and returns copies."""
    credentials_func = Mock(
        wraps=lambda: (
            {
                "aws_access_key_id": "...",
                "aws_secret_access_key": "...",
            },
            None,
        )
    )

    plmonkeypatch.setattr(
        pl.CredentialProviderAWS,
        "retrieve_credentials_impl",
        credentials_func,
    )

    assert credentials_func.call_count == 0

    provider = pl.CredentialProviderAWS()

    provider()
    assert credentials_func.call_count == 1

    provider()
    assert credentials_func.call_count == 1

    # With caching disabled, every call retrieves credentials again.
    plmonkeypatch.setenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING", "1")

    provider()
    assert credentials_func.call_count == 2

    provider()
    assert credentials_func.call_count == 3

    plmonkeypatch.delenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING")

    provider()
    assert credentials_func.call_count == 4

    provider()
    assert credentials_func.call_count == 4

    # Cached credentials must not be carried through pickling.
    assert provider._cached_credentials.get() is not None
    assert pickle.loads(pickle.dumps(provider))._cached_credentials.get() is None

    assert provider() == (
        {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        },
        None,
    )

    # Mutating a returned dict must not affect the cached value.
    provider()[0]["A"] = "A"

    assert provider() == (
        {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
        },
        None,
    )


def test_no_pickle_option() -> None:
    """Check that ``NoPickleOption`` drops its value when pickled."""
    v = NoPickleOption(3)
    assert v.get() == 3

    out = pickle.loads(pickle.dumps(v))

    assert out.get() is None


@pytest.mark.write_disk
def test_credential_provider_aws_expiry(
    tmp_path: Path, plmonkeypatch: PlMonkeyPatch
) -> None:
    """Check credentials and expiry read through an AWS ``credential_process`` profile."""
    credential_file_path = tmp_path / "credentials.json"

    credential_file_path.write_text(
        """\
{
    "Version": 1,
    "AccessKeyId": "123",
    "SecretAccessKey": "456",
    "SessionToken": "789",
    "Expiration": "2099-01-01T00:00:00+00:00"
}
"""
    )

    cfg_file_path = tmp_path / "config"

    # Forward slashes so the path can be embedded in the Python one-liner below.
    credential_file_path_str = str(credential_file_path).replace("\\", "/")

    cfg_file_path.write_text(f"""\
[profile cred_process]
credential_process = "{sys.executable}" -c "from pathlib import Path; print(Path('{credential_file_path_str}').read_text())"
""")

    plmonkeypatch.setenv("AWS_CONFIG_FILE", str(cfg_file_path))

    creds, expiry = pl.CredentialProviderAWS(profile_name="cred_process")()

    assert creds == {
        "aws_access_key_id": "123",
        "aws_secret_access_key": "456",
        "aws_session_token": "789",
    }

    assert expiry is not None

    assert datetime.fromtimestamp(expiry, tz=timezone.utc) == datetime.fromisoformat(
        "2099-01-01T00:00:00+00:00"
    )

    # Without an "Expiration" field the returned expiry is None.
    credential_file_path.write_text(
        """\
{
    "Version": 1,
    "AccessKeyId": "...",
    "SecretAccessKey": "...",
    "SessionToken": "..."
}
"""
    )

    creds, expiry = pl.CredentialProviderAWS(profile_name="cred_process")()

    assert creds == {
        "aws_access_key_id": "...",
        "aws_secret_access_key": "...",
        "aws_session_token": "...",
    }

    assert expiry is None


@pytest.mark.slow
@pytest.mark.parametrize(
    (
        "credential_provider_class",
        "scan_path",
        "initial_credentials",
        "updated_credentials",
    ),
    [
        (
            pl.CredentialProviderAWS,
            "s3://.../...",
            {"aws_access_key_id": "initial", "aws_secret_access_key": "initial"},
            {"aws_access_key_id": "updated", "aws_secret_access_key": "updated"},
        ),
        (
            pl.CredentialProviderAzure,
            "abfss://container@storage_account.dfs.core.windows.net/bucket",
            {"bearer_token": "initial"},
            {"bearer_token": "updated"},
        ),
        (
            pl.CredentialProviderGCP,
            "gs://.../...",
            {"bearer_token": "initial"},
            {"bearer_token": "updated"},
        ),
    ],
)
def test_credential_provider_rebuild_clears_cache(
    credential_provider_class: type[CachingCredentialProvider],
    scan_path: str,
    initial_credentials: dict[str, str],
    updated_credentials: dict[str, str],
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Check that a newly built provider returns fresh credentials while an old instance keeps its cache until cleared."""
    assert initial_credentials != updated_credentials

    plmonkeypatch.setattr(
        credential_provider_class,
        "retrieve_credentials_impl",
        lambda *_: (initial_credentials, None),
    )

    # Class identity check: only AWS gets a local endpoint override.
    storage_options = (
        {"aws_endpoint_url": "http://localhost:333"}
        if credential_provider_class is pl.CredentialProviderAWS
        else None
    )

    builder = _init_credential_provider_builder(
        "auto",
        scan_path,
        storage_options=storage_options,
        caller_name="test",
    )

    assert builder is not None

    # This is a separate one for testing local to this function.
    provider_local = credential_provider_class()

    # Set the cache
    provider_local()

    # Now update the retrieval function to return updated credentials.
    plmonkeypatch.setattr(
        credential_provider_class,
        "retrieve_credentials_impl",
        lambda *_: (updated_credentials, None),
    )

    # Despite "retrieve_credentials_impl" being updated, the providers should
    # still return the initial credentials, as they were cached with an expiry
    # of None.
    assert provider_local() == (initial_credentials, None)

    q = pl.scan_parquet(
        scan_path,
        storage_options=storage_options,
        credential_provider="auto",
    )

    with pytest.raises(OSError):
        q.collect()

    provider_at_scan = builder.build_credential_provider()

    assert provider_at_scan is not None
    assert provider_at_scan() == (updated_credentials, None)

    assert provider_local() == (initial_credentials, None)

    provider_local.clear_cached_credentials()

    assert provider_local() == (updated_credentials, None)


def test_user_gcp_token_provider(
    plmonkeypatch: PlMonkeyPatch,
) -> None:
    """Check that ``UserProvidedGCPToken`` returns the token with and without caching."""
    provider = UserProvidedGCPToken("A")
    assert provider() == ({"bearer_token": "A"}, None)
    plmonkeypatch.setenv("POLARS_DISABLE_PYTHON_CREDENTIAL_CACHING", "1")
    assert provider() == ({"bearer_token": "A"}, None)


def test_auto_init_cache_key_memoize(plmonkeypatch: PlMonkeyPatch) -> None:
    """Check that ``AutoInit.get_or_init_cache_key`` computes the key only once."""
    get_cache_key_impl = Mock(wraps=AutoInit.get_cache_key_impl)
    plmonkeypatch.setattr(
        AutoInit,
        "get_cache_key_impl",
        lambda *a, **kw: get_cache_key_impl(*a, **kw),
    )

    v = AutoInit(int)

    assert get_cache_key_impl.call_count == 0

    v.get_or_init_cache_key()
    assert get_cache_key_impl.call_count == 1

    v.get_or_init_cache_key()
    assert get_cache_key_impl.call_count == 1


def test_cached_credential_provider_returns_copied_creds() -> None:
    """Check that ``CachedCredentialProvider`` calls its function once and returns copies."""
    provider_func = Mock(wraps=lambda: ({"A": "A"}, None))
    provider = CachedCredentialProvider(provider_func)

    assert provider_func.call_count == 0

    provider()
    assert provider() == ({"A": "A"}, None)

    assert provider_func.call_count == 1

    # Mutating a returned dict must not affect the cached value.
    provider()[0]["B"] = "B"

    assert provider() == ({"A": "A"}, None)

    assert provider_func.call_count == 1


@pytest.mark.parametrize(
    "partition_target",
    [
        pl.PartitionBy("s3://.../...", key=""),
    ],
)
def test_credential_provider_init_from_partition_target(
    partition_target: pl.PartitionBy,
) -> None:
    """Check that a builder can be created from a ``PartitionBy`` target."""
    assert isinstance(
        _init_credential_provider_builder(
            "auto",
            partition_target,
            None,
            "test",
        ),
        CredentialProviderBuilder,
    )


@pytest.mark.slow
def test_cache_user_credential_provider(plmonkeypatch: PlMonkeyPatch) -> None:
    """Check how often a user-supplied provider is called with and without the builder cache."""
    user_provider = Mock(
        return_value=(
            {"aws_access_key_id": "...", "aws_secret_access_key": "..."},
            None,
        )
    )

    def get_q() -> pl.LazyFrame:
        return pl.scan_parquet(
            "s3://.../...",
            storage_options={"aws_endpoint_url": "http://localhost:333"},
            credential_provider=user_provider,
        )

    assert user_provider.call_count == 0

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert user_provider.call_count == 2

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert user_provider.call_count == 3

    plmonkeypatch.setenv("POLARS_CREDENTIAL_PROVIDER_BUILDER_CACHE_SIZE", "0")

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert user_provider.call_count == 5


@pytest.mark.slow
def test_credential_provider_global_config(plmonkeypatch: PlMonkeyPatch) -> None:
    """Check ``pl.Config.set_default_credential_provider`` with a custom provider, "auto" and ``None``."""
    import polars.io.cloud.credential_provider._builder

    # Reset the global default; monkeypatch restores it after the test.
    plmonkeypatch.setattr(
        polars.io.cloud.credential_provider._builder,
        "DEFAULT_CREDENTIAL_PROVIDER",
        None,
    )

    provider = Mock(
        return_value=(
            {"aws_access_key_id": "...", "aws_secret_access_key": "..."},
            None,
        )
    )

    pl.Config.set_default_credential_provider(provider)

    plmonkeypatch.setenv("AWS_ACCESS_KEY_ID", "...")
    plmonkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "...")

    def get_q() -> pl.LazyFrame:
        return pl.scan_parquet(
            "s3://.../...",
            storage_options={"aws_endpoint_url": "http://localhost:333"},
        )

    def get_q_disable_cred_provider() -> pl.LazyFrame:
        return pl.scan_parquet(
            "s3://.../...",
            credential_provider=None,
            storage_options={"aws_endpoint_url": "http://localhost:333"},
        )

    assert provider.call_count == 0

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert provider.call_count == 2

    # An explicit `credential_provider=None` bypasses the global default.
    with pytest.raises(OSError, match="http://localhost:333"):
        get_q_disable_cred_provider().collect()

    assert provider.call_count == 2

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert provider.call_count == 3

    pl.Config.set_default_credential_provider("auto")

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()

    assert provider.call_count == 3

    err_magic = "err_magic_3"

    def raises(*_: None, **__: None) -> None:
        raise AssertionError(err_magic)

    # With the "auto" default, building a provider goes through the builder.
    plmonkeypatch.setattr(CredentialProviderBuilder, "__init__", raises)

    with pytest.raises(AssertionError, match=err_magic):
        get_q().collect()

    pl.Config.set_default_credential_provider(None)

    with pytest.raises(OSError, match="http://localhost:333"):
        get_q().collect()