Path: blob/main/py-polars/tests/unit/io/cloud/test_aws.py
6939 views
from __future__ import annotations12import multiprocessing3from typing import TYPE_CHECKING, Any, Callable45import boto36import pytest7from moto.server import ThreadedMotoServer89import polars as pl10from polars.testing import assert_frame_equal1112if TYPE_CHECKING:13from collections.abc import Iterator14from pathlib import Path1516pytestmark = [17pytest.mark.skip(18reason="Causes intermittent failures in CI. See: "19"https://github.com/pola-rs/polars/issues/16910"20),21pytest.mark.xdist_group("aws"),22pytest.mark.slow(),23]242526@pytest.fixture(scope="module")27def monkeypatch_module() -> Any:28"""Allow module-scoped monkeypatching."""29with pytest.MonkeyPatch.context() as mp:30yield mp313233@pytest.fixture(scope="module")34def s3_base(monkeypatch_module: Any) -> Iterator[str]:35monkeypatch_module.setenv("AWS_ACCESS_KEY_ID", "accesskey")36monkeypatch_module.setenv("AWS_SECRET_ACCESS_KEY", "secretkey")37monkeypatch_module.setenv("AWS_DEFAULT_REGION", "us-east-1")3839host = "127.0.0.1"40port = 500041moto_server = ThreadedMotoServer(host, port)42# Start in a separate process to avoid deadlocks43mp = multiprocessing.get_context("spawn")44p = mp.Process(target=moto_server._server_entry, daemon=True)45p.start()46print("server up")47yield f"http://{host}:{port}"48print("moto done")49p.kill()505152@pytest.fixture53def s3(s3_base: str, io_files_path: Path) -> str:54region = "us-east-1"55client = boto3.client("s3", region_name=region, endpoint_url=s3_base)56client.create_bucket(Bucket="bucket")5758files = ["foods1.csv", "foods1.ipc", "foods1.parquet", "foods2.parquet"]59for file in files:60client.upload_file(io_files_path / file, Bucket="bucket", Key=file)61return s3_base626364@pytest.mark.parametrize(65("function", "extension"),66[67(pl.read_csv, "csv"),68(pl.read_ipc, "ipc"),69],70)71def test_read_s3(s3: str, function: Callable[..., Any], extension: str) -> None:72storage_options = {"endpoint_url": s3}73df = function(74f"s3://bucket/foods1.{extension}",75storage_options=storage_options,76)77assert df.columns == ["category", "calories", "fats_g", "sugars_g"]78assert df.shape == (27, 4)7980# ensure we aren't modifying the original user dictionary (ref #15859)81assert storage_options == {"endpoint_url": s3}828384@pytest.mark.parametrize(85("function", "extension"),86[(pl.scan_ipc, "ipc"), (pl.scan_parquet, "parquet")],87)88def test_scan_s3(s3: str, function: Callable[..., Any], extension: str) -> None:89lf = function(90f"s3://bucket/foods1.{extension}",91storage_options={"endpoint_url": s3},92)93assert lf.collect_schema().names() == ["category", "calories", "fats_g", "sugars_g"]94assert lf.collect().shape == (27, 4)959697def test_lazy_count_s3(s3: str) -> None:98lf = pl.scan_parquet(99"s3://bucket/foods*.parquet", storage_options={"endpoint_url": s3}100).select(pl.len())101102assert "FAST_COUNT" in lf.explain()103expected = pl.DataFrame({"len": [54]}, schema={"len": pl.UInt32})104assert_frame_equal(lf.collect(), expected)105106107