# Path: blob/main/py-polars/tests/unit/sql/test_structs.py
# 6939 views
from __future__ import annotations

import pytest

import polars as pl
from polars.exceptions import (
    SQLInterfaceError,
    SQLSyntaxError,
    StructFieldNotFoundError,
)
from polars.testing import assert_frame_equal


@pytest.fixture
def df_struct() -> pl.DataFrame:
    """Return a single-column frame whose ``json_msg`` struct wraps all fields.

    The struct has the fields ``id``, ``name``, ``age`` and ``other``, where
    ``other`` is itself a struct with one (nullable) field ``n``.
    """
    return pl.DataFrame(
        {
            "id": [200, 300, 400],
            "name": ["Bob", "David", "Zoe"],
            "age": [45, 19, 45],
            "other": [{"n": 1.5}, {"n": None}, {"n": -0.5}],
        }
    ).select(pl.struct(pl.all()).alias("json_msg"))


def test_struct_field_nested_dot_notation_22107() -> None:
    """Check that dotted names resolve at the intended level of nesting.

    The frame has top-level ``id``/``name`` columns *and* an ``author`` struct
    that carries fields with the same names, so each query must pick the
    right one.
    """
    # ensure dot-notation references the given name at the right level of nesting
    df = pl.DataFrame(
        {
            "id": ["012345", "987654"],
            "name": ["A Book", "Another Book"],
            "author": [
                {"id": "888888", "name": "Iain M. Banks"},
                {"id": "444444", "name": "Dan Abnett"},
            ],
        }
    )

    res = df.sql("SELECT id, author.id AS author_id FROM self ORDER BY id")
    assert res.to_dict(as_series=False) == {
        "id": ["012345", "987654"],
        "author_id": ["888888", "444444"],
    }

    # struct field, with and without the table qualifier
    for name in ("author.name", "self.author.name"):
        res = df.sql(f"SELECT {name} FROM self ORDER BY id")
        assert res.to_dict(as_series=False) == {"name": ["Iain M. Banks", "Dan Abnett"]}

    # top-level column, with and without the table qualifier
    for name in ("name", "self.name"):
        res = df.sql(f"SELECT {name} FROM self ORDER BY self.id DESC")
        assert res.to_dict(as_series=False) == {"name": ["Another Book", "A Book"]}

    # expected errors
    with pytest.raises(
        SQLInterfaceError,
        match="no table or struct column named 'foo' found",
    ):
        df.sql("SELECT foo.id FROM self ORDER BY id")

    with pytest.raises(
        SQLInterfaceError,
        match="no column named 'foo' found",
    ):
        df.sql("SELECT self.foo FROM self ORDER BY id")


@pytest.mark.parametrize(
    "order_by",
    [
        "ORDER BY json_msg.id DESC",
        "ORDER BY 2 DESC",
        "",
    ],
)
def test_struct_field_selection(order_by: str, df_struct: pl.DataFrame) -> None:
    """Select, filter and order by struct fields, including a nested one."""
    res = df_struct.sql(
        f"""
        SELECT
          -- validate table alias resolution
          frame.json_msg.id AS ID,
          self.json_msg.name AS NAME,
          json_msg.age AS AGE
        FROM
          self AS frame
        WHERE
          json_msg.age > 20 AND
          json_msg.other.n IS NOT NULL  -- note: nested struct field
        {order_by}
        """
    )
    if not order_by:
        # no ORDER BY clause in the query: sort here so that the comparison
        # against `expected` does not depend on the returned row order
        res = res.sort(by="ID", descending=True)

    expected = pl.DataFrame({"ID": [400, 200], "NAME": ["Zoe", "Bob"], "AGE": [45, 45]})
    assert_frame_equal(expected, res)


def test_struct_field_group_by(df_struct: pl.DataFrame) -> None:
    """Group by a struct field and aggregate other struct fields."""
    # NOTE: the query refers to the frame by its Python variable name
    # (`df_struct`), so the fixture parameter must keep that name.
    res = pl.sql(
        """
        SELECT
          COUNT(json_msg.age) AS n,
          ARRAY_AGG(json_msg.name) AS names
        FROM df_struct
        GROUP BY json_msg.age
        ORDER BY 1 DESC
        """
    ).collect()

    expected = pl.DataFrame(
        data={"n": [2, 1], "names": [["Bob", "Zoe"], ["David"]]},
        schema_overrides={"n": pl.UInt32},
    )
    assert_frame_equal(expected, res)


def test_struct_field_group_by_errors(df_struct: pl.DataFrame) -> None:
    """Selecting a non-grouped, non-aggregated struct field must raise."""
    with pytest.raises(
        SQLSyntaxError,
        match="'name' should participate in the GROUP BY clause or an aggregate function",
    ):
        # NOTE: `df_struct` is referenced by variable name inside the query.
        pl.sql(
            """
            SELECT
              json_msg.name,
              SUM(json_msg.age) AS sum_age
            FROM df_struct
            GROUP BY json_msg.age
            """
        ).collect()


@pytest.mark.parametrize(
    ("expr", "expected"),
    [
        ("nested #> '{c,1}'", 2),
        ("nested #> '{c,-1}'", 1),
        ("nested #>> '{c,0}'", "3"),
        ("nested -> '0' -> 0", "baz"),
        ("nested -> 'c' -> -1", 1),
        ("nested -> 'c' ->> 2", "1"),
    ],
)
def test_struct_field_operator_access(expr: str, expected: int | str) -> None:
    """Access struct fields / list items with the ``->``, ``->>``, ``#>``, ``#>>`` operators.

    Per the parametrized cases, the ``>>`` variants are expected to yield the
    value as a string, the others the value itself.
    """
    df = pl.DataFrame(
        {
            "nested": {
                "0": ["baz"],
                "b": ["foo", "bar"],
                "c": [3, 2, 1],
            },
        },
    )
    assert df.sql(f"SELECT {expr} FROM self").item() == expected


@pytest.mark.parametrize(
    ("fields", "excluding", "rename"),
    [
        ("json_msg.*", "age", {}),
        ("json_msg.*", "name", {"other": "misc"}),
        ("self.json_msg.*", "(age,other)", {"name": "ident"}),
        ("json_msg.other.*", "", {"n": "num"}),
        ("self.json_msg.other.*", "", {}),
        ("self.json_msg.other.*", "n", {}),
    ],
)
def test_struct_field_selection_wildcards(
    fields: str,
    excluding: str,
    rename: dict[str, str],
    df_struct: pl.DataFrame,
) -> None:
    """Expand struct wildcards, optionally with EXCLUDE and/or RENAME modifiers."""
    # build the optional wildcard modifiers; empty inputs yield an empty clause
    exclude_cols = f"EXCLUDE {excluding}" if excluding else ""
    rename_cols = (
        f"RENAME ({','.join(f'{k} AS {v}' for k, v in rename.items())})"
        if rename
        else ""
    )
    res = df_struct.sql(
        f"""
        SELECT {fields} {exclude_cols} {rename_cols}
        FROM self ORDER BY json_msg.id
        """
    )

    # compute the expected frame with the equivalent native operations
    expected = df_struct.unnest("json_msg")
    if fields.endswith(".other.*"):
        expected = expected["other"].struct.unnest()
    if excluding:
        # "(age,other)" -> ["age", "other"]; a bare "age" -> ["age"]
        expected = expected.drop(excluding.strip(")(").split(","))
    if rename:
        expected = expected.rename(rename)

    assert_frame_equal(expected, res)


@pytest.mark.parametrize(
    ("invalid_column", "error_type"),
    [
        ("json_msg.invalid_column", StructFieldNotFoundError),
        ("json_msg.other.invalid_column", StructFieldNotFoundError),
        ("self.json_msg.other.invalid_column", StructFieldNotFoundError),
        ("json_msg.other -> invalid_column", SQLSyntaxError),
        ("json_msg -> DATE '2020-09-11'", SQLSyntaxError),
    ],
)
def test_struct_field_selection_errors(
    invalid_column: str,
    error_type: type[Exception],
    df_struct: pl.DataFrame,
) -> None:
    """Invalid struct field references must raise the expected error type."""
    # operator-based access is expected to report a path-extract error;
    # dotted access is expected to mention the missing field name
    error_msg = (
        "invalid json/struct path-extract"
        if ("->" in invalid_column)
        else "invalid_column"
    )
    with pytest.raises(error_type, match=error_msg):
        df_struct.sql(f"SELECT {invalid_column} FROM self")