Path: blob/main/py-polars/tests/unit/io/cloud/test_aws.py
8424 views
from __future__ import annotations12import multiprocessing3from typing import TYPE_CHECKING, Any45import boto36import pytest7from moto.server import ThreadedMotoServer89import polars as pl10from polars.testing import assert_frame_equal11from tests.conftest import PlMonkeyPatch1213if TYPE_CHECKING:14from collections.abc import Callable, Iterator15from pathlib import Path1617pytestmark = [18pytest.mark.skip(19reason="Causes intermittent failures in CI. See: "20"https://github.com/pola-rs/polars/issues/16910"21),22pytest.mark.xdist_group("aws"),23pytest.mark.slow(),24]252627@pytest.fixture(scope="module")28def monkeypatch_module() -> Any:29"""Allow module-scoped monkeypatching."""30with PlMonkeyPatch.context() as mp:31yield mp323334@pytest.fixture(scope="module")35def s3_base(monkeypatch_module: Any) -> Iterator[str]:36monkeypatch_module.setenv("AWS_ACCESS_KEY_ID", "accesskey")37monkeypatch_module.setenv("AWS_SECRET_ACCESS_KEY", "secretkey")38monkeypatch_module.setenv("AWS_DEFAULT_REGION", "us-east-1")3940host = "127.0.0.1"41port = 500042moto_server = ThreadedMotoServer(host, port)43# Start in a separate process to avoid deadlocks44mp = multiprocessing.get_context("spawn")45p = mp.Process(target=moto_server._server_entry, daemon=True)46p.start()47print("server up")48yield f"http://{host}:{port}"49print("moto done")50p.kill()515253@pytest.fixture54def s3(s3_base: str, io_files_path: Path) -> str:55region = "us-east-1"56client = boto3.client("s3", region_name=region, endpoint_url=s3_base)57client.create_bucket(Bucket="bucket")5859files = ["foods1.csv", "foods1.ipc", "foods1.parquet", "foods2.parquet"]60for file in files:61client.upload_file(io_files_path / file, Bucket="bucket", Key=file)62return s3_base636465@pytest.mark.parametrize(66("function", "extension"),67[68(pl.read_csv, "csv"),69(pl.read_ipc, "ipc"),70],71)72def test_read_s3(s3: str, function: Callable[..., Any], extension: str) -> None:73storage_options = {"endpoint_url": s3}74df = function(75f"s3://bucket/foods1.{extension}",76storage_options=storage_options,77)78assert df.columns == ["category", "calories", "fats_g", "sugars_g"]79assert df.shape == (27, 4)8081# ensure we aren't modifying the original user dictionary (ref #15859)82assert storage_options == {"endpoint_url": s3}838485@pytest.mark.parametrize(86("function", "extension"),87[(pl.scan_ipc, "ipc"), (pl.scan_parquet, "parquet")],88)89def test_scan_s3(s3: str, function: Callable[..., Any], extension: str) -> None:90lf = function(91f"s3://bucket/foods1.{extension}",92storage_options={"endpoint_url": s3},93)94assert lf.collect_schema().names() == ["category", "calories", "fats_g", "sugars_g"]95assert lf.collect().shape == (27, 4)969798def test_lazy_count_s3(s3: str) -> None:99lf = pl.scan_parquet(100"s3://bucket/foods*.parquet", storage_options={"endpoint_url": s3}101).select(pl.len())102103assert "FAST_COUNT" in lf.explain()104expected = pl.DataFrame({"len": [54]}, schema={"len": pl.UInt32})105assert_frame_equal(lf.collect(), expected)106107108def test_read_parquet_metadata(s3: str) -> None:109metadata = pl.read_parquet_metadata(110"s3://bucket/foods1.parquet", storage_options={"endpoint_url": s3}111)112assert "ARROW:schema" in metadata113114115