# Path: py-polars/tests/unit/operations/test_cut.py
"""Tests for ``cut`` on polars Series and expressions."""

from __future__ import annotations

import pytest

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal

inf = float("inf")


def test_cut() -> None:
    """Values are binned into right-closed intervals labelled as strings."""
    s = pl.Series("a", [-2, -1, 0, 1, 2])

    result = s.cut([-1, 1])

    expected = pl.Series(
        "a",
        [
            "(-inf, -1]",
            "(-inf, -1]",
            "(-1, 1]",
            "(-1, 1]",
            "(1, inf]",
        ],
        dtype=pl.Categorical,
    )
    assert_series_equal(result, expected, categorical_as_str=True)


def test_cut_lazy_schema() -> None:
    """``Expr.cut`` in a lazy query yields the same categorical labels."""
    lf = pl.LazyFrame({"a": [-2, -1, 0, 1, 2]})

    result = lf.select(pl.col("a").cut([-1, 1]))

    expected = pl.LazyFrame(
        {"a": ["(-inf, -1]", "(-inf, -1]", "(-1, 1]", "(-1, 1]", "(1, inf]"]},
        schema={"a": pl.Categorical},
    )
    assert_frame_equal(result, expected, categorical_as_str=True)


def test_cut_include_breaks() -> None:
    """``include_breaks=True`` returns a struct of upper breakpoint and label."""
    s = pl.Series("a", [-2, -1, 0, 1, 2])

    out = s.cut([-1.5, 0.25, 1.0], labels=["a", "b", "c", "d"], include_breaks=True)

    # Each value maps to the upper bound of its bin; the last bin is open to +inf.
    expected = pl.DataFrame(
        {
            "breakpoint": [-1.5, 0.25, 0.25, 1.0, inf],
            "category": ["a", "b", "b", "c", "d"],
        },
        schema_overrides={"category": pl.Categorical},
    ).to_struct("a")
    assert_series_equal(out, expected, categorical_as_str=True)


# https://github.com/pola-rs/polars/issues/11255
def test_cut_include_breaks_lazy_schema() -> None:
    """The lazy struct output of ``cut`` unnests into breakpoint/category."""
    lf = pl.LazyFrame({"a": [-2, -1, 0, 1, 2]})

    result = lf.select(
        pl.col("a").cut([-1, 1], include_breaks=True).alias("cut")
    ).unnest("cut")

    expected = pl.LazyFrame(
        {
            "breakpoint": [-1.0, -1.0, 1.0, 1.0, inf],
            "category": ["(-inf, -1]", "(-inf, -1]", "(-1, 1]", "(-1, 1]", "(1, inf]"],
        },
        schema_overrides={"category": pl.Categorical},
    )
    assert_frame_equal(result, expected, categorical_as_str=True)


def test_cut_null_values() -> None:
    """Null inputs stay null; non-null inputs get their custom bin label."""
    s = pl.Series([-1.0, None, 1.0, 2.0, None, 8.0, 4.0])

    result = s.cut([1.5, 5.0], labels=["a", "b", "c"])

    expected = pl.Series(["a", None, "a", "b", None, "c", "b"], dtype=pl.Categorical)
    assert_series_equal(result, expected, categorical_as_str=True)


def test_cut_bin_name_in_agg_context() -> None:
    """``cut``/``qcut`` with ``include_breaks`` keep their struct schema in ``over``."""
    df = pl.DataFrame({"a": [1]}).select(
        cut=pl.col("a").cut([1, 2], include_breaks=True).over(1),
        qcut=pl.col("a").qcut([1], include_breaks=True).over(1),
        qcut_uniform=pl.col("a").qcut(1, include_breaks=True).over(1),
    )
    schema = pl.Struct({"breakpoint": pl.Float64, "category": pl.Categorical()})
    assert df.schema == {"cut": schema, "qcut": schema, "qcut_uniform": schema}


@pytest.mark.parametrize(
    ("breaks", "expected_labels", "expected_unique"),
    [
        (
            [2, 4],
            pl.Series("x", ["(-inf, 2]", "(-inf, 2]", "(2, 4]", "(2, 4]", "(4, inf]"]),
            3,
        ),
        (
            [99, 101],
            pl.Series("x", 5 * ["(-inf, 99]"]),
            1,
        ),
    ],
)
def test_cut_fast_unique_15981(
    breaks: list[int],
    expected_labels: pl.Series,
    expected_unique: int,
) -> None:
    """Unique counts and grouping on ``cut`` output must match the real labels.

    Checked both for the plain categorical output and for the ``category``
    field extracted from the ``include_breaks=True`` struct output. (The
    ``_15981`` suffix presumably refers to the GitHub issue being guarded.)
    """
    s = pl.Series("x", [1, 2, 3, 4, 5])

    cut_plain = s.cut(breaks, include_breaks=False)
    cut_from_struct = (
        s.cut(breaks, include_breaks=True).struct.field("category").alias("x")
    )

    for s_cut in (cut_plain, cut_from_struct):
        assert_series_equal(s_cut.cast(pl.String), expected_labels)
        assert s_cut.n_unique() == s_cut.to_physical().n_unique() == expected_unique
        # Grouping by the categorical column must produce exactly one group per
        # distinct label; previously this result was computed but never checked.
        n_groups = s_cut.to_frame().group_by(s.name).len().height
        assert n_groups == expected_unique