# Path: blob/main/py-polars/tests/unit/operations/namespaces/test_categorical.py
# 8445 views
from __future__ import annotations12from io import BytesIO3from typing import TYPE_CHECKING45import pytest67import polars as pl8from polars.testing import assert_frame_equal, assert_series_equal910if TYPE_CHECKING:11from polars._typing import PolarsDataType121314def test_categorical_lexical_sort() -> None:15df = pl.DataFrame(16{"cats": ["z", "z", "k", "a", "b"], "vals": [3, 1, 2, 2, 3]}17).with_columns(18pl.col("cats").cast(pl.Categorical()),19)2021out = df.sort(["cats"])22assert out["cats"].dtype == pl.Categorical23expected = pl.DataFrame(24{"cats": ["a", "b", "k", "z", "z"], "vals": [2, 3, 2, 3, 1]}25)26assert_frame_equal(out.with_columns(pl.col("cats").cast(pl.String)), expected)27out = df.sort(["cats", "vals"])28expected = pl.DataFrame(29{"cats": ["a", "b", "k", "z", "z"], "vals": [2, 3, 2, 1, 3]}30)31assert_frame_equal(out.with_columns(pl.col("cats").cast(pl.String)), expected)32out = df.sort(["vals", "cats"])3334expected = pl.DataFrame(35{"cats": ["z", "a", "k", "b", "z"], "vals": [1, 2, 2, 3, 3]}36)37assert_frame_equal(out.with_columns(pl.col("cats").cast(pl.String)), expected)3839s = pl.Series(["a", "c", "a", "b", "a"], dtype=pl.Categorical())40assert s.sort().cast(pl.String).to_list() == [41"a",42"a",43"a",44"b",45"c",46]474849def test_categorical_lexical_ordering_after_concat() -> None:50ldf1 = (51pl.DataFrame([pl.Series("key1", [8, 5]), pl.Series("key2", ["fox", "baz"])])52.lazy()53.with_columns(pl.col("key2").cast(pl.Categorical()))54)55ldf2 = (56pl.DataFrame(57[pl.Series("key1", [6, 8, 6]), pl.Series("key2", ["fox", "foo", "bar"])]58)59.lazy()60.with_columns(pl.col("key2").cast(pl.Categorical()))61)62df = pl.concat([ldf1, ldf2]).select(pl.col("key2")).collect()6364assert df.sort("key2").to_dict(as_series=False) == {65"key2": ["bar", "baz", "foo", "fox", "fox"]66}676869def test_sort_categoricals_6014_lexical() -> None:70# create lexically-ordered categorical71df = pl.DataFrame({"key": ["bbb", "aaa", 
"ccc"]}).with_columns(72pl.col("key").cast(pl.Categorical())73)7475out = df.sort("key")76assert out.to_dict(as_series=False) == {"key": ["aaa", "bbb", "ccc"]}777879def test_categorical_get_categories() -> None:80s = pl.Series("cats", ["foo", "bar", "foo", "foo", "ham"], dtype=pl.Categorical)81assert set(s.cat.get_categories().to_list()) >= {"foo", "bar", "ham"}828384def test_cat_to_local() -> None:85s = pl.Series(["a", "b", "a"], dtype=pl.Categorical)86assert_series_equal(s, s.cat.to_local())878889def test_cat_uses_lexical_ordering() -> None:90with pytest.warns(DeprecationWarning, match="ordering parameter"):91physical_cat = pl.Categorical(ordering="physical")9293for dtype in [pl.Categorical, pl.Categorical(), physical_cat]:94s = pl.Series(["a", "b", None, "b"]).cast(dtype) # type: ignore[arg-type]9596with pytest.warns(97DeprecationWarning,98match="Categoricals are now always ordered lexically",99):100assert s.cat.uses_lexical_ordering()101102103@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])104def test_cat_len_bytes(dtype: PolarsDataType) -> None:105# test Series106values = ["Café", None, "Café", "345", "東京"]107if dtype == pl.Enum:108dtype = pl.Enum(list({x for x in values if x is not None}))109s = pl.Series("a", values, dtype=dtype)110result = s.cat.len_bytes()111expected = pl.Series("a", [5, None, 5, 3, 6], dtype=pl.UInt32)112assert_series_equal(result, expected)113114# test DataFrame expr115df = pl.DataFrame(s)116result_df = df.select(pl.col("a").cat.len_bytes())117expected_df = pl.DataFrame(expected)118assert_frame_equal(result_df, expected_df)119120# test LazyFrame expr121result_lf = df.lazy().select(pl.col("a").cat.len_bytes()).collect()122assert_frame_equal(result_lf, expected_df)123124# test GroupBy125result_df = (126pl.LazyFrame({"key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], "value": s.extend(s)})127.group_by("key", maintain_order=True)128.agg(pl.col("value").cat.len_bytes().alias("len_bytes"))129.explode("len_bytes")130.collect()131)132expected_df 
= pl.DataFrame(133{134"key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],135"len_bytes": pl.Series(136[5, None, 5, 3, 6, 5, None, 5, 3, 6], dtype=pl.UInt32137),138}139)140assert_frame_equal(result_df, expected_df)141142143@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])144def test_cat_len_chars(dtype: PolarsDataType) -> None:145values = ["Café", None, "Café", "345", "東京"]146if dtype == pl.Enum:147dtype = pl.Enum(list({x for x in values if x is not None}))148# test Series149s = pl.Series("a", values, dtype=dtype)150result = s.cat.len_chars()151expected = pl.Series("a", [4, None, 4, 3, 2], dtype=pl.UInt32)152assert_series_equal(result, expected)153154# test DataFrame expr155df = pl.DataFrame(s)156result_df = df.select(pl.col("a").cat.len_chars())157expected_df = pl.DataFrame(expected)158assert_frame_equal(result_df, expected_df)159160# test LazyFrame expr161result_lf = df.lazy().select(pl.col("a").cat.len_chars()).collect()162assert_frame_equal(result_lf, expected_df)163164# test GroupBy165result_df = (166pl.LazyFrame({"key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], "value": s.extend(s)})167.group_by("key", maintain_order=True)168.agg(pl.col("value").cat.len_chars().alias("len_bytes"))169.explode("len_bytes")170.collect()171)172expected_df = pl.DataFrame(173{174"key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],175"len_bytes": pl.Series(176[4, None, 4, 3, 2, 4, None, 4, 3, 2], dtype=pl.UInt32177),178}179)180assert_frame_equal(result_df, expected_df)181182183@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])184def test_starts_ends_with(dtype: PolarsDataType) -> None:185values = ["hamburger_with_tomatoes", "nuts", "nuts", "lollypop", None]186if dtype == pl.Enum:187dtype = pl.Enum(list({x for x in values if x is not None}))188s = pl.Series("a", values, dtype=dtype)189assert_series_equal(190s.cat.ends_with("pop"), pl.Series("a", [False, False, False, True, None])191)192assert_series_equal(193s.cat.starts_with("nu"), pl.Series("a", [False, True, True, False, None])194)195196with 
pytest.raises(TypeError, match="'prefix' must be a string; found"):197s.cat.starts_with(None) # type: ignore[arg-type]198199with pytest.raises(TypeError, match="'suffix' must be a string; found"):200s.cat.ends_with(None) # type: ignore[arg-type]201202df = pl.DataFrame({"a": pl.Series(values, dtype=dtype)})203204expected = {205"ends_pop": [False, False, False, True, None],206"starts_ham": [True, False, False, False, None],207}208209assert (210df.select(211pl.col("a").cat.ends_with("pop").alias("ends_pop"),212pl.col("a").cat.starts_with("ham").alias("starts_ham"),213).to_dict(as_series=False)214== expected215)216217with pytest.raises(TypeError, match="'prefix' must be a string; found"):218df.select(pl.col("a").cat.starts_with(None)) # type: ignore[arg-type]219220with pytest.raises(TypeError, match="'suffix' must be a string; found"):221df.select(pl.col("a").cat.ends_with(None)) # type: ignore[arg-type]222223224@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])225def test_cat_slice(dtype: PolarsDataType) -> None:226values = ["foobar", "barfoo", "foobar", "x", None]227if dtype == pl.Enum:228dtype = pl.Enum(list({x for x in values if x is not None}))229df = pl.DataFrame({"a": pl.Series(values, dtype=dtype)})230assert df["a"].cat.slice(-3).to_list() == ["bar", "foo", "bar", "x", None]231assert df.select([pl.col("a").cat.slice(2, 4)])["a"].to_list() == [232"obar",233"rfoo",234"obar",235"",236None,237]238239240def test_cat_order_flag_csv_read_23823() -> None:241data = BytesIO(b"colx,coly\nabc,123\n#not_a_row\nxyz,456")242lf = pl.scan_csv(243source=data,244comment_prefix="#",245schema_overrides={"colx": pl.Categorical},246)247expected = pl.DataFrame(248{"colx": ["abc", "xyz"], "coly": [123, 456]},249schema_overrides={"colx": pl.Categorical},250)251assert_frame_equal(expected, lf.sort("colx", descending=False).collect())252253254