GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/cloud/test_aws.py

from __future__ import annotations

import multiprocessing
from typing import TYPE_CHECKING, Any

import boto3
import pytest
from moto.server import ThreadedMotoServer

import polars as pl
from polars.testing import assert_frame_equal
from tests.conftest import PlMonkeyPatch

if TYPE_CHECKING:
    from collections.abc import Callable, Iterator
    from pathlib import Path

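# Module-level marks: every test in this file is currently skipped because of
# intermittent CI failures (see the issue linked below), pinned to a single
# pytest-xdist worker via the "aws" group, and marked slow.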
pytestmark = [
    pytest.mark.skip(
        reason="Causes intermittent failures in CI. See: "
        "https://github.com/pola-rs/polars/issues/16910"
    ),
    pytest.mark.xdist_group("aws"),
    pytest.mark.slow(),
]


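# Module scope lets the AWS credential environment variables set in `s3_base`
# persist across every test in this file instead of being reset per-test.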
@pytest.fixture(scope="module")
def monkeypatch_module() -> Any:
    """Allow module-scoped monkeypatching."""
    with PlMonkeyPatch.context() as mp:
        yield mp


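# Starts a local moto S3 server in a separate "spawn" process and yields its
# endpoint URL; the server process is killed when the module's tests finish.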
@pytest.fixture(scope="module")
def s3_base(monkeypatch_module: Any) -> Iterator[str]:
    monkeypatch_module.setenv("AWS_ACCESS_KEY_ID", "accesskey")
    monkeypatch_module.setenv("AWS_SECRET_ACCESS_KEY", "secretkey")
    monkeypatch_module.setenv("AWS_DEFAULT_REGION", "us-east-1")

    host = "127.0.0.1"
    port = 5000
    moto_server = ThreadedMotoServer(host, port)
    # Start in a separate process to avoid deadlocks
    mp = multiprocessing.get_context("spawn")
    p = mp.Process(target=moto_server._server_entry, daemon=True)
    p.start()
    print("server up")
    yield f"http://{host}:{port}"
    print("moto done")
    p.kill()


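# Creates a test bucket on the moto server and uploads the `foods` sample
# files from the shared IO fixtures directory, then returns the endpoint URL.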
@pytest.fixture
def s3(s3_base: str, io_files_path: Path) -> str:
    region = "us-east-1"
    client = boto3.client("s3", region_name=region, endpoint_url=s3_base)
    client.create_bucket(Bucket="bucket")

    files = ["foods1.csv", "foods1.ipc", "foods1.parquet", "foods2.parquet"]
    for file in files:
        client.upload_file(io_files_path / file, Bucket="bucket", Key=file)
    return s3_base


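# Eager reads over s3:// URIs: only the endpoint is overridden through
# `storage_options`; the credentials come from the env vars set in `s3_base`.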
@pytest.mark.parametrize(
    ("function", "extension"),
    [
        (pl.read_csv, "csv"),
        (pl.read_ipc, "ipc"),
    ],
)
def test_read_s3(s3: str, function: Callable[..., Any], extension: str) -> None:
    storage_options = {"endpoint_url": s3}
    df = function(
        f"s3://bucket/foods1.{extension}",
        storage_options=storage_options,
    )
    assert df.columns == ["category", "calories", "fats_g", "sugars_g"]
    assert df.shape == (27, 4)

    # ensure we aren't modifying the original user dictionary (ref #15859)
    assert storage_options == {"endpoint_url": s3}


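# Lazy counterparts of the reads above: the schema is resolved via
# `collect_schema()` before the frame itself is collected.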
@pytest.mark.parametrize(
    ("function", "extension"),
    [(pl.scan_ipc, "ipc"), (pl.scan_parquet, "parquet")],
)
def test_scan_s3(s3: str, function: Callable[..., Any], extension: str) -> None:
    lf = function(
        f"s3://bucket/foods1.{extension}",
        storage_options={"endpoint_url": s3},
    )
    assert lf.collect_schema().names() == ["category", "calories", "fats_g", "sugars_g"]
    assert lf.collect().shape == (27, 4)


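# A bare `select(pl.len())` over a glob of two 27-row parquet files should
# plan a FAST_COUNT node (counting rows without materializing the data) and
# return a total of 54.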
def test_lazy_count_s3(s3: str) -> None:
    lf = pl.scan_parquet(
        "s3://bucket/foods*.parquet", storage_options={"endpoint_url": s3}
    ).select(pl.len())

    assert "FAST_COUNT" in lf.explain()
    expected = pl.DataFrame({"len": [54]}, schema={"len": pl.UInt32})
    assert_frame_equal(lf.collect(), expected)


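# Reads the parquet footer's key/value metadata over S3 and checks for the
# embedded "ARROW:schema" entry.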
def test_read_parquet_metadata(s3: str) -> None:
    metadata = pl.read_parquet_metadata(
        "s3://bucket/foods1.parquet", storage_options={"endpoint_url": s3}
    )
    assert "ARROW:schema" in metadata