# Path: blob/main/py-polars/tests/unit/lazyframe/test_schema.py
# 8431 views
"""Tests for ``pl.Schema`` and for ``collect_schema`` agreeing with collected schemas."""

import pickle
from datetime import datetime
from typing import Any

import pytest

import polars as pl
from polars.datatypes.group import NUMERIC_DTYPES, TEMPORAL_DTYPES
from polars.testing.asserts.frame import assert_frame_equal

# Used by test_lazy_collect_schema_matches_computed_schema
_TEST_COLLECT_SCHEMA_M_DTYPES = sorted(
    ({pl.Boolean, pl.String} | NUMERIC_DTYPES | TEMPORAL_DTYPES) - {pl.Decimal},
    key=repr,
)


def test_schema() -> None:
    """Check basic ``pl.Schema`` accessors and rejection of non-specified dtypes."""
    s = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})

    assert s["foo"] == pl.Int8()
    assert s["bar"] == pl.String()
    assert s.len() == 2
    assert s.names() == ["foo", "bar"]
    assert s.dtypes() == [pl.Int8(), pl.String()]

    with pytest.raises(
        TypeError,
        match="dtypes must be fully-specified, got: List",
    ):
        pl.Schema({"foo": pl.String, "bar": pl.List})


@pytest.mark.parametrize(
    "schema",
    [
        pl.Schema(),
        pl.Schema({"foo": pl.Int8()}),
        pl.Schema({"foo": pl.Datetime("us"), "bar": pl.String()}),
        pl.Schema(
            {
                "foo": pl.UInt32(),
                "bar": pl.Categorical(),
                "baz": pl.Struct({"x": pl.Int64(), "y": pl.Float64()}),
            }
        ),
    ],
)
def test_schema_empty_frame(schema: pl.Schema) -> None:
    """``Schema.to_frame()`` equals an empty ``DataFrame`` built from the schema."""
    assert_frame_equal(
        schema.to_frame(),
        pl.DataFrame(schema=schema),
    )


def test_schema_equality() -> None:
    """Schemas are equal only if names, column order and full dtypes all match."""
    s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()})
    s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})
    s3 = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()})

    assert s1 == s1
    assert s2 == s2
    assert s3 == s3
    assert s1 != s2
    assert s1 != s3
    assert s2 != s3

    # Differing time units, and a plain dict of unparameterized dtypes.
    s4 = pl.Schema({"foo": pl.Datetime("us"), "bar": pl.Duration("ns")})
    s5 = pl.Schema({"foo": pl.Datetime("ns"), "bar": pl.Duration("us")})
    s6 = {"foo": pl.Datetime, "bar": pl.Duration}

    assert s4 != s5
    assert s4 != s6


def test_schema_parse_python_dtypes() -> None:
    """Python types given to ``pl.Schema`` become polars dtypes and round-trip."""
    cardinal_directions = pl.Enum(["north", "south", "east", "west"])

    s = pl.Schema({"foo": pl.List(pl.Int32), "bar": int, "baz": cardinal_directions})  # type: ignore[arg-type]
    s["ham"] = datetime

    assert s["foo"] == pl.List(pl.Int32)
    assert s["bar"] == pl.Int64
    assert s["baz"] == cardinal_directions
    assert s["ham"] == pl.Datetime("us")

    assert s.len() == 4
    assert s.names() == ["foo", "bar", "baz", "ham"]
    assert s.dtypes() == [pl.List, pl.Int64, cardinal_directions, pl.Datetime("us")]

    assert list(s.to_python().values()) == [list, int, str, datetime]
    assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime]


def test_schema_picklable() -> None:
    """A schema containing nested dtypes survives a pickle round-trip."""
    s = pl.Schema(
        {
            "foo": pl.Int8(),
            "bar": pl.String(),
            "ham": pl.Struct({"x": pl.List(pl.Date)}),
        }
    )
    pickled = pickle.dumps(s)
    s2 = pickle.loads(pickled)
    assert s == s2


def test_schema_python() -> None:
    """``Schema.to_python`` works for dict, items-view and list-of-pairs inputs."""
    # Named `input_schema` rather than `input` to avoid shadowing the builtin.
    input_schema = {
        "foo": pl.Int8(),
        "bar": pl.String(),
        "baz": pl.Categorical(),
        "ham": pl.Object(),
        "spam": pl.Struct({"time": pl.List(pl.Duration), "dist": pl.Float64}),
    }
    expected = {
        "foo": int,
        "bar": str,
        "baz": str,
        "ham": object,
        "spam": dict,
    }
    for schema in (input_schema, input_schema.items(), list(input_schema.items())):
        s = pl.Schema(schema)
        assert expected == s.to_python()


def test_schema_in_map_elements_returns_scalar() -> None:
    """A scalar ``return_dtype`` in ``map_elements`` inside ``agg`` is not list-wrapped."""
    schema = pl.Schema([("portfolio", pl.String()), ("irr", pl.Float64())])

    ldf = pl.LazyFrame(
        {
            "portfolio": ["A", "A", "B", "B"],
            "amounts": [100.0, -110.0] * 2,
        }
    )
    q = ldf.group_by("portfolio").agg(
        pl.col("amounts")
        .implode()
        .map_elements(lambda x: float(x.sum()), return_dtype=pl.Float64)
        .alias("irr")
    )
    assert q.collect_schema() == schema
    assert q.collect().schema == schema


@pytest.mark.slow
@pytest.mark.parametrize(
    "expr",
    [
        # TODO: Add more (bitwise) operators once their types are resolved correctly
        pl.col("col0") > pl.col("col1"),
        pl.col("col0") >= pl.col("col1"),
        pl.col("col0") < pl.col("col1"),
        pl.col("col0") <= pl.col("col1"),
        pl.col("col0") == pl.col("col1"),
        pl.col("col0") != pl.col("col1"),
        pl.col("col0") + pl.col("col1"),
        pl.col("col0") - pl.col("col1"),
        pl.col("col0") * pl.col("col1"),
        pl.col("col0") / pl.col("col1"),
        pl.col("col0").truediv(pl.col("col1")),
        pl.col("col0") // pl.col("col1"),
        pl.col("col0") % pl.col("col1"),
    ],
)
@pytest.mark.parametrize("dtype1", _TEST_COLLECT_SCHEMA_M_DTYPES)
@pytest.mark.parametrize("dtype2", _TEST_COLLECT_SCHEMA_M_DTYPES)
def test_lazy_collect_schema_matches_computed_schema(
    expr: pl.Expr, dtype1: pl.DataType, dtype2: pl.DataType
) -> None:
    """``collect_schema`` agrees with the collected schema for each operator/dtype pair.

    Dtype combinations that the operator rejects at runtime are skipped.
    """
    df = pl.DataFrame(
        {
            "col0": [None],
            "col1": [None],
        },
        schema={
            "col0": dtype1,
            "col1": dtype2,
        },
    )
    lazy_df = df.lazy().select(expr)

    try:
        expected_schema = lazy_df.collect().schema
    except (
        # Applying the operator to these dtypes results in an error,
        # so their output dtype is undefined and there is nothing to compare.
        pl.exceptions.InvalidOperationError,
        pl.exceptions.SchemaError,
        pl.exceptions.ComputeError,
    ):
        return

    actual_schema = lazy_df.collect_schema()
    assert actual_schema == expected_schema, (
        f"{expr} on {df.dtypes} results in {actual_schema} instead of {expected_schema}\n"
        f"result of computation is:\n{lazy_df.collect()}\n"
    )


def test_ir_cache_unique_18198() -> None:
    """Calling ``collect_schema`` before self-concatenation leaves the result intact."""
    lf = pl.LazyFrame({"a": [1]})
    lf.collect_schema()
    assert pl.concat([lf, lf]).collect().to_dict(as_series=False) == {"a": [1, 1]}


def test_schema_functions_in_agg_with_literal_arg_19011() -> None:
    """Functions taking a literal argument inside a rolling ``agg`` yield list dtypes."""
    q = (
        pl.LazyFrame({"a": [1, 2, 3, None, 5]})
        .rolling(index_column=pl.int_range(pl.len()).alias("idx"), period="3i")
        .agg(pl.col("a").fill_null(0).alias("a_1"), pl.col("a").pow(2.0).alias("a_2"))
    )
    assert q.collect_schema() == pl.Schema(
        [("idx", pl.Int64), ("a_1", pl.List(pl.Int64)), ("a_2", pl.List(pl.Float64))]
    )


def test_lazy_explode_in_agg_schema_19562() -> None:
    """``explode`` in a group-by ``agg`` strips nesting levels before list-wrapping.

    Each ``explode`` removes one list level until the leaf dtype is reached; the
    aggregation then wraps the per-group result in a single ``List``.
    """

    def new_df_check_schema(
        value: dict[str, Any], schema: dict[str, Any]
    ) -> pl.DataFrame:
        """Build a DataFrame from ``value`` and assert its schema equals ``schema``."""
        df = pl.DataFrame(value)
        assert df.schema == schema
        return df

    lf = pl.LazyFrame({"a": [1], "b": [[1]]})

    q = lf.group_by("a").agg(pl.col("b"))
    schema = {"a": pl.Int64, "b": pl.List(pl.List(pl.Int64))}

    assert q.collect_schema() == schema
    assert_frame_equal(
        q.collect(), new_df_check_schema({"a": [1], "b": [[[1]]]}, schema)
    )

    q = lf.group_by("a").agg(pl.col("b").explode())
    schema = {"a": pl.Int64, "b": pl.List(pl.Int64)}

    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema))

    q = lf.group_by("a").agg(pl.col("b").explode().explode())
    schema = {"a": pl.Int64, "b": pl.List(pl.Int64)}

    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema))

    # 2x nested
    lf = pl.LazyFrame({"a": [1], "b": [[[1]]]})

    q = lf.group_by("a").agg(pl.col("b"))
    schema = {
        "a": pl.Int64,
        "b": pl.List(pl.List(pl.List(pl.Int64))),
    }

    assert q.collect_schema() == schema
    assert_frame_equal(
        q.collect(), new_df_check_schema({"a": [1], "b": [[[[1]]]]}, schema)
    )

    q = lf.group_by("a").agg(pl.col("b").explode())
    schema = {"a": pl.Int64, "b": pl.List(pl.List(pl.Int64))}

    assert q.collect_schema() == schema
    assert_frame_equal(
        q.collect(), new_df_check_schema({"a": [1], "b": [[[1]]]}, schema)
    )

    q = lf.group_by("a").agg(pl.col("b").explode().explode())
    schema = {"a": pl.Int64, "b": pl.List(pl.Int64)}

    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema))


def test_lazy_nested_function_expr_agg_schema() -> None:
    """A comparison over nested function expressions in ``agg`` is ``List(Boolean)``."""
    q = (
        pl.LazyFrame({"k": [1, 1, 2]})
        .group_by(pl.first(), maintain_order=True)
        .agg(o=pl.int_range(pl.len()).reverse() < 1)
    )

    assert q.collect_schema() == {"k": pl.Int64, "o": pl.List(pl.Boolean)}
    assert_frame_equal(
        q.collect(), pl.DataFrame({"k": [1, 2], "o": [[False, True], [True]]})
    )


def test_lazy_agg_scalar_return_schema() -> None:
    """``null_count`` in ``agg`` yields one index-typed scalar per group."""
    q = pl.LazyFrame({"k": [1]}).group_by("k").agg(pl.col("k").null_count().alias("o"))

    schema = {"k": pl.Int64, "o": pl.get_index_type()}
    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": 0}, schema=schema))


def test_lazy_agg_nested_expr_schema() -> None:
    """A final ``sum`` after order-changing ops in ``agg`` yields a scalar ``Int64``."""
    q = (
        pl.LazyFrame({"k": [1]})
        .group_by("k")
        .agg(
            (
                (
                    (pl.col("k").reverse().shuffle() + 1)
                    + pl.col("k").shuffle().reverse()
                )
                .shuffle()
                .reverse()
                .sum()
                * 0
            ).alias("o")
        )
    )

    schema = {"k": pl.Int64, "o": pl.Int64}
    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": 0}, schema=schema))


def test_lazy_agg_lit_explode() -> None:
    """An exploded literal in ``agg`` is collected into a ``List`` per group."""
    q = (
        pl.LazyFrame({"k": [1]})
        .group_by("k")
        .agg(pl.lit(1, dtype=pl.Int64).explode().alias("o"))
    )

    schema = {"k": pl.Int64, "o": pl.List(pl.Int64)}
    assert q.collect_schema() == schema
    assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": [[1]]}, schema=schema))  # type: ignore[arg-type]


@pytest.mark.parametrize(
    "expr_op", [
        "approx_n_unique", "arg_max", "arg_min", "bitwise_and", "bitwise_or",
        "bitwise_xor", "count", "entropy", "first", "has_nulls", "implode", "kurtosis",
        "last", "len", "lower_bound", "max", "mean", "median", "min", "n_unique", "nan_max",
        "nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound",
        "var"
    ]
)  # fmt: skip
@pytest.mark.parametrize("lhs", [pl.col("b"), pl.lit(1, dtype=pl.Int64).alias("b")])
def test_lazy_agg_to_scalar_schema_19752(lhs: pl.Expr, expr_op: str) -> None:
    """Each listed ``Expr`` method after order-changing ops in ``agg`` keeps schemas in sync."""
    op = getattr(pl.Expr, expr_op)

    lf = pl.LazyFrame({"a": 1, "b": 1})

    q = lf.group_by("a").agg(lhs.reverse().pipe(op))
    assert q.collect_schema() == q.collect().collect_schema()

    q = lf.group_by("a").agg(lhs.shuffle().reverse().pipe(op))

    assert q.collect_schema() == q.collect().collect_schema()


def test_lazy_agg_schema_after_elementwise_19984() -> None:
    """Elementwise ops applied after ``item()`` in ``agg`` keep schemas in sync."""
    lf = pl.LazyFrame({"a": 1, "b": 1})

    q = lf.group_by("a").agg(pl.col("b").item().fill_null(0))
    assert q.collect_schema() == q.collect().collect_schema()

    q = lf.group_by("a").agg(pl.col("b").item().fill_null(0).fill_null(0))
    assert q.collect_schema() == q.collect().collect_schema()

    q = lf.group_by("a").agg(pl.col("b").item() + 1)
    assert q.collect_schema() == q.collect().collect_schema()

    q = lf.group_by("a").agg(1 + pl.col("b").item())
    assert q.collect_schema() == q.collect().collect_schema()


@pytest.mark.parametrize(
    "expr", [pl.col("b"), pl.col("b").sum(), pl.col("b").reverse()]
)
@pytest.mark.parametrize("mapping_strategy", ["explode", "join", "group_to_rows"])
def test_lazy_window_schema(expr: pl.Expr, mapping_strategy: str) -> None:
    """``over`` with each ``mapping_strategy`` reports the same lazy and collected schema."""
    q = pl.LazyFrame({"a": 1, "b": 1}).select(
        expr.over("a", mapping_strategy=mapping_strategy)  # type: ignore[arg-type]
    )

    assert q.collect_schema() == q.collect().collect_schema()


def test_lazy_explode_schema() -> None:
    """Exploding ``Array``/``List`` columns (expression or frame level) yields the inner dtype."""
    lf = pl.LazyFrame({"k": [1], "x": pl.Series([[1]], dtype=pl.Array(pl.Int64, 1))})

    q = lf.select(pl.col("x").explode())
    assert q.collect_schema() == {"x": pl.Int64}

    q = lf.select(pl.col("x").arr.explode())
    assert q.collect_schema() == {"x": pl.Int64}

    lf = pl.LazyFrame({"k": [1], "x": pl.Series([[1]], dtype=pl.List(pl.Int64))})

    q = lf.select(pl.col("x").explode())
    assert q.collect_schema() == {"x": pl.Int64}

    q = lf.select(pl.col("x").list.explode())
    assert q.collect_schema() == {"x": pl.Int64}

    # `LazyFrame.explode()` goes through a different codepath than `Expr.explode`
    lf = pl.LazyFrame().with_columns(
        pl.Series([[1]], dtype=pl.List(pl.Int64)).alias("list"),
        pl.Series([[1]], dtype=pl.Array(pl.Int64, 1)).alias("array"),
    )

    q = lf.explode("*")
    assert q.collect_schema() == {"list": pl.Int64, "array": pl.Int64}

    q = lf.explode("list")
    assert q.collect_schema() == {"list": pl.Int64, "array": pl.Array(pl.Int64, 1)}


def test_raise_subnodes_18787() -> None:
    """A missing column referenced inside a struct-field filter raises an error."""
    df = pl.DataFrame({"a": [1], "b": [2]})

    with pytest.raises(pl.exceptions.ColumnNotFoundError):
        (
            df.select(pl.struct(pl.all())).select(
                pl.first().struct.field("a", "b").filter(pl.col("foo") == 1)
            )
        )


def test_scalar_agg_schema_20044() -> None:
    """``mean`` of a broadcast ``max`` in a group-by on an empty frame is ``Float64``."""
    assert (
        pl.DataFrame(None, schema={"a": pl.Int64, "b": pl.String, "c": pl.String})
        .with_columns(d=pl.col("a").max())
        .group_by("c")
        .agg(pl.col("d").mean())
    ).schema == pl.Schema([("c", pl.String), ("d", pl.Float64)])


@pytest.mark.parametrize(
    "df",
    [
        pl.DataFrame({"a": [None, True, False], "b": 3 * [128]}),
        pl.DataFrame(
            {"a": [[None, True, False]], "b": [3 * [128]]},
            schema={"a": pl.Array(pl.Boolean, 3), "b": pl.Array(pl.Int64, 3)},
        ),
        pl.DataFrame(
            {"a": [[None, True, False]], "b": [3 * [128]]},
            schema={"a": pl.List(pl.Boolean), "b": pl.List(pl.Int64)},
        ),
    ],
)
def test_div_collect_schema_matches_23993(df: pl.DataFrame) -> None:
    """Boolean / Int64 division (plain, Array, List) has matching lazy and collected schema."""
    q = df.lazy().select(pl.col("a") / pl.col("b"))
    expected = q.collect().schema
    actual = q.collect_schema()
    assert actual == expected