# Path: blob/main/py-polars/tests/unit/operations/namespaces/test_categorical.py
# 6940 views
"""Tests for Categorical/Enum data: lexical sorting and the ``.cat`` namespace."""

from __future__ import annotations

from io import BytesIO
from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import PolarsDataType


def _enum_of(values: list[str | None]) -> pl.Enum:
    """Build an Enum dtype from the distinct non-null ``values`` in first-seen order.

    ``dict.fromkeys`` preserves insertion order, so the category order is the same
    on every run (iterating a ``set`` of strings depends on hash randomization).
    """
    return pl.Enum(list(dict.fromkeys(x for x in values if x is not None)))


def test_categorical_lexical_sort() -> None:
    """Lexical categoricals sort by string value, alone and with other sort keys."""
    df = pl.DataFrame(
        {"cats": ["z", "z", "k", "a", "b"], "vals": [3, 1, 2, 2, 3]}
    ).with_columns(
        pl.col("cats").cast(pl.Categorical("lexical")),
    )

    out = df.sort(["cats"])
    assert out["cats"].dtype == pl.Categorical
    expected = pl.DataFrame(
        {"cats": ["a", "b", "k", "z", "z"], "vals": [2, 3, 2, 3, 1]}
    )
    assert_frame_equal(out.with_columns(pl.col("cats").cast(pl.String)), expected)
    out = df.sort(["cats", "vals"])
    expected = pl.DataFrame(
        {"cats": ["a", "b", "k", "z", "z"], "vals": [2, 3, 2, 1, 3]}
    )
    assert_frame_equal(out.with_columns(pl.col("cats").cast(pl.String)), expected)
    out = df.sort(["vals", "cats"])

    expected = pl.DataFrame(
        {"cats": ["z", "a", "k", "b", "z"], "vals": [1, 2, 2, 3, 3]}
    )
    assert_frame_equal(out.with_columns(pl.col("cats").cast(pl.String)), expected)

    s = pl.Series(["a", "c", "a", "b", "a"], dtype=pl.Categorical("lexical"))
    assert s.sort().cast(pl.String).to_list() == [
        "a",
        "a",
        "a",
        "b",
        "c",
    ]


def test_categorical_lexical_ordering_after_concat() -> None:
    """Lexical ordering still applies after concatenating two lazily-cast frames."""
    ldf1 = (
        pl.DataFrame([pl.Series("key1", [8, 5]), pl.Series("key2", ["fox", "baz"])])
        .lazy()
        .with_columns(pl.col("key2").cast(pl.Categorical("lexical")))
    )
    ldf2 = (
        pl.DataFrame(
            [pl.Series("key1", [6, 8, 6]), pl.Series("key2", ["fox", "foo", "bar"])]
        )
        .lazy()
        .with_columns(pl.col("key2").cast(pl.Categorical("lexical")))
    )
    df = pl.concat([ldf1, ldf2]).select(pl.col("key2")).collect()

    assert df.sort("key2").to_dict(as_series=False) == {
        "key2": ["bar", "baz", "foo", "fox", "fox"]
    }


def test_sort_categoricals_6014_lexical() -> None:
    """Sorting a lexical categorical column orders it by string value (case 6014)."""
    # create lexically-ordered categorical
    df = pl.DataFrame({"key": ["bbb", "aaa", "ccc"]}).with_columns(
        pl.col("key").cast(pl.Categorical("lexical"))
    )

    out = df.sort("key")
    assert out.to_dict(as_series=False) == {"key": ["aaa", "bbb", "ccc"]}


def test_categorical_get_categories() -> None:
    """``cat.get_categories()`` contains at least every value present in the data."""
    s = pl.Series("cats", ["foo", "bar", "foo", "foo", "ham"], dtype=pl.Categorical)
    # Superset check: presumably the categorical mapping can also hold categories
    # registered by other series, so only containment is asserted.
    assert set(s.cat.get_categories().to_list()) >= {"foo", "bar", "ham"}


def test_cat_to_local() -> None:
    """``cat.to_local()`` leaves the series equal to the original."""
    s = pl.Series(["a", "b", "a"], dtype=pl.Categorical)
    assert_series_equal(s, s.cat.to_local())


def test_cat_uses_lexical_ordering() -> None:
    """Default, lexical and (deprecated) physical categoricals all report lexical."""
    s = pl.Series(["a", "b", None, "b"]).cast(pl.Categorical)
    assert s.cat.uses_lexical_ordering()

    s = s.cast(pl.Categorical("lexical"))
    assert s.cat.uses_lexical_ordering()

    with pytest.warns(DeprecationWarning):
        s = s.cast(pl.Categorical("physical"))  # Deprecated.
    assert s.cat.uses_lexical_ordering()


@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])
def test_cat_len_bytes(dtype: PolarsDataType) -> None:
    """``cat.len_bytes()`` yields UTF-8 byte lengths as UInt32 in every context."""
    values = ["Café", None, "Café", "345", "東京"]
    if dtype == pl.Enum:
        dtype = _enum_of(values)

    # test Series
    s = pl.Series("a", values, dtype=dtype)
    result = s.cat.len_bytes()
    expected = pl.Series("a", [5, None, 5, 3, 6], dtype=pl.UInt32)
    assert_series_equal(result, expected)

    # test DataFrame expr
    df = pl.DataFrame(s)
    result_df = df.select(pl.col("a").cat.len_bytes())
    expected_df = pl.DataFrame(expected)
    assert_frame_equal(result_df, expected_df)

    # test LazyFrame expr
    result_lf = df.lazy().select(pl.col("a").cat.len_bytes()).collect()
    assert_frame_equal(result_lf, expected_df)

    # test GroupBy; pl.concat builds the doubled column without mutating `s`
    # (Series.extend would append to `s` in place).
    result_df = (
        pl.LazyFrame(
            {"key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], "value": pl.concat([s, s])}
        )
        .group_by("key", maintain_order=True)
        .agg(pl.col("value").cat.len_bytes().alias("len_bytes"))
        .explode("len_bytes")
        .collect()
    )
    # An element-wise agg keeps the function's dtype, matching the Series-level
    # UInt32 asserted above (not the build-dependent index type).
    expected_df = pl.DataFrame(
        {
            "key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
            "len_bytes": pl.Series(
                [5, None, 5, 3, 6, 5, None, 5, 3, 6], dtype=pl.UInt32
            ),
        }
    )
    assert_frame_equal(result_df, expected_df)


@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])
def test_cat_len_chars(dtype: PolarsDataType) -> None:
    """``cat.len_chars()`` yields character counts as UInt32 in every context."""
    values = ["Café", None, "Café", "345", "東京"]
    if dtype == pl.Enum:
        dtype = _enum_of(values)

    # test Series
    s = pl.Series("a", values, dtype=dtype)
    result = s.cat.len_chars()
    expected = pl.Series("a", [4, None, 4, 3, 2], dtype=pl.UInt32)
    assert_series_equal(result, expected)

    # test DataFrame expr
    df = pl.DataFrame(s)
    result_df = df.select(pl.col("a").cat.len_chars())
    expected_df = pl.DataFrame(expected)
    assert_frame_equal(result_df, expected_df)

    # test LazyFrame expr
    result_lf = df.lazy().select(pl.col("a").cat.len_chars()).collect()
    assert_frame_equal(result_lf, expected_df)

    # test GroupBy; pl.concat builds the doubled column without mutating `s`.
    result_df = (
        pl.LazyFrame(
            {"key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], "value": pl.concat([s, s])}
        )
        .group_by("key", maintain_order=True)
        .agg(pl.col("value").cat.len_chars().alias("len_chars"))
        .explode("len_chars")
        .collect()
    )
    expected_df = pl.DataFrame(
        {
            "key": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
            "len_chars": pl.Series(
                [4, None, 4, 3, 2, 4, None, 4, 3, 2], dtype=pl.UInt32
            ),
        }
    )
    assert_frame_equal(result_df, expected_df)


@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])
def test_starts_ends_with(dtype: PolarsDataType) -> None:
    """``cat.starts_with``/``ends_with`` match string affixes and reject non-strings."""
    values = ["hamburger_with_tomatoes", "nuts", "nuts", "lollypop", None]
    if dtype == pl.Enum:
        dtype = _enum_of(values)
    s = pl.Series("a", values, dtype=dtype)
    assert_series_equal(
        s.cat.ends_with("pop"), pl.Series("a", [False, False, False, True, None])
    )
    assert_series_equal(
        s.cat.starts_with("nu"), pl.Series("a", [False, True, True, False, None])
    )

    with pytest.raises(TypeError, match="'prefix' must be a string; found"):
        s.cat.starts_with(None)  # type: ignore[arg-type]

    with pytest.raises(TypeError, match="'suffix' must be a string; found"):
        s.cat.ends_with(None)  # type: ignore[arg-type]

    df = pl.DataFrame({"a": pl.Series(values, dtype=dtype)})

    expected = {
        "ends_pop": [False, False, False, True, None],
        "starts_ham": [True, False, False, False, None],
    }

    assert (
        df.select(
            pl.col("a").cat.ends_with("pop").alias("ends_pop"),
            pl.col("a").cat.starts_with("ham").alias("starts_ham"),
        ).to_dict(as_series=False)
        == expected
    )

    with pytest.raises(TypeError, match="'prefix' must be a string; found"):
        df.select(pl.col("a").cat.starts_with(None))  # type: ignore[arg-type]

    with pytest.raises(TypeError, match="'suffix' must be a string; found"):
        df.select(pl.col("a").cat.ends_with(None))  # type: ignore[arg-type]


@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum])
def test_cat_slice(dtype: PolarsDataType) -> None:
    """``cat.slice`` takes a substring of each value; nulls stay null."""
    values = ["foobar", "barfoo", "foobar", "x", None]
    if dtype == pl.Enum:
        dtype = _enum_of(values)
    df = pl.DataFrame({"a": pl.Series(values, dtype=dtype)})
    assert df["a"].cat.slice(-3).to_list() == ["bar", "foo", "bar", "x", None]
    assert df.select([pl.col("a").cat.slice(2, 4)])["a"].to_list() == [
        "obar",
        "rfoo",
        "obar",
        "",
        None,
    ]


def test_cat_order_flag_csv_read_23823() -> None:
    """Sorting a Categorical column scanned from CSV with comment lines (case 23823)."""
    data = BytesIO(b"colx,coly\nabc,123\n#not_a_row\nxyz,456")
    lf = pl.scan_csv(
        source=data,
        comment_prefix="#",
        schema_overrides={"colx": pl.Categorical},
    )
    expected = pl.DataFrame(
        {"colx": ["abc", "xyz"], "coly": [123, 456]},
        schema_overrides={"colx": pl.Categorical},
    )
    assert_frame_equal(lf.sort("colx", descending=False).collect(), expected)