Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/io/cloud/test_aws.py
6939 views
1
from __future__ import annotations
2
3
import multiprocessing
4
from typing import TYPE_CHECKING, Any, Callable
5
6
import boto3
7
import pytest
8
from moto.server import ThreadedMotoServer
9
10
import polars as pl
11
from polars.testing import assert_frame_equal
12
13
if TYPE_CHECKING:
14
from collections.abc import Iterator
15
from pathlib import Path
16
17
# Marks applied to every test in this module.
pytestmark = [
    # The moto-backed S3 tests are currently disabled because they are flaky.
    pytest.mark.skip(
        reason=(
            "Causes intermittent failures in CI. "
            "See: https://github.com/pola-rs/polars/issues/16910"
        )
    ),
    # Keep all tests in this module on the same xdist worker ("aws" group).
    pytest.mark.xdist_group("aws"),
    pytest.mark.slow(),
]
25
26
27
@pytest.fixture(scope="module")
def monkeypatch_module() -> Iterator[pytest.MonkeyPatch]:
    """Allow module-scoped monkeypatching.

    Yields a ``pytest.MonkeyPatch`` created via ``MonkeyPatch.context()``, so
    module-scoped fixtures (e.g. ``s3_base``) can patch the environment.
    Every change made through it is undone when the context exits, i.e. once
    all tests in this module have finished.
    """
    with pytest.MonkeyPatch.context() as mp:
        yield mp
32
33
34
@pytest.fixture(scope="module")
def s3_base(monkeypatch_module: pytest.MonkeyPatch) -> Iterator[str]:
    """Run a moto S3 server in a child process for the duration of the module.

    Dummy AWS credentials and a default region are exported through the
    module-scoped monkeypatch. The fixture waits until the server accepts TCP
    connections before yielding, so tests never race the server's startup.

    Yields
    ------
    str
        The HTTP endpoint URL of the running moto server.

    Raises
    ------
    RuntimeError
        If the server process exits before it starts listening.
    TimeoutError
        If the server does not start listening within the startup timeout.
    """
    import socket
    import time

    monkeypatch_module.setenv("AWS_ACCESS_KEY_ID", "accesskey")
    monkeypatch_module.setenv("AWS_SECRET_ACCESS_KEY", "secretkey")
    monkeypatch_module.setenv("AWS_DEFAULT_REGION", "us-east-1")

    host = "127.0.0.1"
    port = 5000
    startup_timeout = 30.0  # seconds

    moto_server = ThreadedMotoServer(host, port)
    # Start in a separate process to avoid deadlocks
    mp = multiprocessing.get_context("spawn")
    p = mp.Process(target=moto_server._server_entry, daemon=True)
    p.start()
    try:
        # `p.start()` returns as soon as the child is spawned, not when the
        # server is listening; poll the port so the first request can't race.
        deadline = time.monotonic() + startup_timeout
        while True:
            try:
                socket.create_connection((host, port), timeout=0.5).close()
            except OSError:
                if not p.is_alive():
                    msg = f"moto server process exited early (exit code {p.exitcode})"
                    raise RuntimeError(msg) from None
                if time.monotonic() > deadline:
                    msg = f"moto server did not start listening on {host}:{port}"
                    raise TimeoutError(msg) from None
                time.sleep(0.05)
            else:
                break

        yield f"http://{host}:{port}"
    finally:
        # Always stop the server, and reap the child so it does not linger
        # as a zombie process.
        p.kill()
        p.join()
51
52
53
@pytest.fixture
def s3(s3_base: str, io_files_path: Path) -> str:
    """Create the ``bucket`` bucket on the moto server and fill it with test files.

    Uploads ``foods1.csv``, ``foods1.ipc``, ``foods1.parquet`` and
    ``foods2.parquet`` from ``io_files_path`` (under the same key names), then
    returns the server endpoint URL unchanged.
    """
    client = boto3.client("s3", region_name="us-east-1", endpoint_url=s3_base)
    client.create_bucket(Bucket="bucket")

    for name in ("foods1.csv", "foods1.ipc", "foods1.parquet", "foods2.parquet"):
        client.upload_file(io_files_path / name, Bucket="bucket", Key=name)

    return s3_base
63
64
65
@pytest.mark.parametrize(
    ("function", "extension"),
    [(pl.read_csv, "csv"), (pl.read_ipc, "ipc")],
)
def test_read_s3(s3: str, function: Callable[..., Any], extension: str) -> None:
    """Eagerly read ``foods1.<extension>`` from the mocked bucket via ``function``.

    Also verifies that the caller's ``storage_options`` dict is not mutated.
    """
    options = {"endpoint_url": s3}
    snapshot = dict(options)

    df = function(f"s3://bucket/foods1.{extension}", storage_options=options)

    assert df.columns == ["category", "calories", "fats_g", "sugars_g"]
    assert df.shape == (27, 4)

    # ensure we aren't modifying the original user dictionary (ref #15859)
    assert options == snapshot
83
84
85
@pytest.mark.parametrize(
    ("function", "extension"),
    [
        (pl.scan_ipc, "ipc"),
        (pl.scan_parquet, "parquet"),
    ],
)
def test_scan_s3(s3: str, function: Callable[..., Any], extension: str) -> None:
    """Lazily scan ``foods1.<extension>`` from the mocked bucket via ``function``."""
    source = f"s3://bucket/foods1.{extension}"
    lf = function(source, storage_options={"endpoint_url": s3})

    expected_columns = ["category", "calories", "fats_g", "sugars_g"]
    assert lf.collect_schema().names() == expected_columns
    assert lf.collect().shape == (27, 4)
96
97
98
def test_lazy_count_s3(s3: str) -> None:
    """Check ``select(pl.len())`` over a parquet glob on S3.

    The plan must contain ``FAST_COUNT``, and the result must be the total
    row count (54) across all files matched by ``foods*.parquet``.
    """
    scan = pl.scan_parquet(
        "s3://bucket/foods*.parquet",
        storage_options={"endpoint_url": s3},
    )
    lf = scan.select(pl.len())

    plan = lf.explain()
    assert "FAST_COUNT" in plan

    expected = pl.DataFrame({"len": [54]}, schema={"len": pl.UInt32})
    assert_frame_equal(lf.collect(), expected)
106
107