# Path: blob/main/py-polars/tests/unit/operations/test_is_in.py
# 8424 views
"""Tests for ``is_in`` / ``list.contains`` membership semantics, incl. ``nulls_equal``."""

from __future__ import annotations

from collections.abc import Collection
from datetime import date
from decimal import Decimal as D
from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from collections.abc import Iterator

    from polars._typing import PolarsDataType


def test_struct_logical_is_in() -> None:
    """Struct values with a logical (Date) field are matched field-by-field."""
    df1 = pl.DataFrame(
        {
            "x": pl.date_range(date(2022, 1, 1), date(2022, 1, 7), eager=True),
            "y": [0, 4, 6, 2, 3, 4, 5],
        }
    )
    df2 = pl.DataFrame(
        {
            "x": pl.date_range(date(2022, 1, 3), date(2022, 1, 9), eager=True),
            "y": [6, 2, 3, 4, 5, 0, 1],
        }
    )

    s1 = df1.select(pl.struct(["x", "y"])).to_series()
    s2 = df2.select(pl.struct(["x", "y"])).to_series()
    assert s1.is_in(s2).to_list() == [False, False, True, True, True, True, True]


def test_struct_logical_is_in_nonullpropagate() -> None:
    """Outer-null structs follow ``nulls_equal``; null-field structs are plain values."""
    s = pl.Series([date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3), None])
    df1 = pl.DataFrame(
        {
            "x": s,
            "y": [0, 4, 6, None],
        }
    )
    s = pl.Series([date(2022, 2, 1), date(2022, 1, 2), date(2022, 2, 3), None])
    df2 = pl.DataFrame(
        {
            "x": s,
            "y": [6, 4, 3, None],
        }
    )

    # Left has no (outer) nulls, right has nulls
    s1 = df1.select(pl.struct(["x", "y"])).to_series()
    s1 = s1.extend_constant(s1[0], 1)
    s2 = df2.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1)
    assert s1.is_in(s2, nulls_equal=False).to_list() == [
        False,
        True,
        False,
        True,
        False,
    ]
    assert s1.is_in(s2, nulls_equal=True).to_list() == [
        False,
        True,
        False,
        True,
        False,
    ]

    # Left has (outer) nulls, right has no nulls
    s1 = df1.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1)
    s2 = df2.select(pl.struct(["x", "y"])).to_series()
    s2 = s2.extend_constant(s2[0], 1)
    assert s1.is_in(s2, nulls_equal=False).to_list() == [
        False,
        True,
        False,
        True,
        None,
    ]
    assert s1.is_in(s2, nulls_equal=True).to_list() == [
        False,
        True,
        False,
        True,
        False,
    ]

    # Both have nulls
    # {None, None} is a valid element unaffected by the `nulls_equal` parameter.
    s1 = df1.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1)
    s2 = df2.select(pl.struct(["x", "y"])).to_series().extend_constant(None, 1)
    assert s1.is_in(s2, nulls_equal=False).to_list() == [
        False,
        True,
        False,
        True,
        None,
    ]
    assert s1.is_in(s2, nulls_equal=True).to_list() == [
        False,
        True,
        False,
        True,
        True,
    ]


@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_bool(nulls_equal: bool) -> None:
    """Boolean column checked against a list that contains a null."""
    vals = [True, None]
    df = pl.DataFrame({"A": [True, False, None]})
    missing_value = True if nulls_equal else None
    assert df.select(pl.col("A").is_in(vals, nulls_equal=nulls_equal)).to_dict(
        as_series=False
    ) == {"A": [True, False, missing_value]}


def test_is_in_bool_11216() -> None:
    """Regression test for issue 11216: a null in the list must not break matches."""
    s = pl.Series([False]).is_in([False, None])
    expected = pl.Series([True])
    assert_series_equal(s, expected)


@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_empty_list_4559(nulls_equal: bool) -> None:
    """An empty right-hand list yields False for non-null values."""
    assert pl.Series(["a"]).is_in([], nulls_equal=nulls_equal).to_list() == [False]


def test_is_in_empty_list_4639() -> None:
    """An empty list yields False for values and null for nulls (default mode)."""
    df = pl.DataFrame({"a": [1, None]})
    empty_list: list[int] = []

    result = df.with_columns([pl.col("a").is_in(empty_list).alias("a_in_list")])
    expected = pl.DataFrame({"a": [1, None], "a_in_list": [False, None]})
    assert_frame_equal(result, expected)


def test_is_in_struct() -> None:
    """A struct column is checked row-wise against a list-of-struct column."""
    df = pl.DataFrame(
        {
            "struct_elem": [{"a": 1, "b": 11}, {"a": 1, "b": 90}],
            "struct_list": [
                [{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 3, "b": 13}],
                [{"a": 3, "b": 3}],
            ],
        }
    )

    assert df.filter(pl.col("struct_elem").is_in("struct_list")).to_dict(
        as_series=False
    ) == {
        "struct_elem": [{"a": 1, "b": 11}],
        "struct_list": [[{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 3, "b": 13}]],
    }


def test_is_in_null_prop() -> None:
    """Outer nulls propagate to null; a struct whose field is null is a real value."""
    assert pl.Series([None], dtype=pl.Float32).is_in(pl.Series([42])).item() is None
    assert pl.Series([{"a": None}, None], dtype=pl.Struct({"a": pl.Float32})).is_in(
        pl.Series([{"a": 42}], dtype=pl.Struct({"a": pl.Float32}))
    ).to_list() == [False, None]

    assert pl.Series([{"a": None}, None], dtype=pl.Struct({"a": pl.Boolean})).is_in(
        pl.Series([{"a": 42}], dtype=pl.Struct({"a": pl.Boolean}))
    ).to_list() == [False, None]


def test_is_in_9070() -> None:
    """An integer does not match a nearby non-integral float."""
    assert not pl.Series([1]).is_in(pl.Series([1.99])).item()


def test_is_in_large_uint64_21966() -> None:
    """Large-integer comparisons stay exact instead of going through Float64."""
    # https://github.com/pola-rs/polars/issues/21966
    # Large integers beyond Float64 precision (2^53) should compare exactly,
    # not lose precision by casting to Float64.

    # Original issue: values differing only beyond float64 precision
    s = pl.Series([58830407606777880], dtype=pl.UInt64)
    assert not s.is_in([58830407606777883]).item()
    assert s.is_in([58830407606777880]).item()

    # Values at and beyond the float64 precision boundary (2^53)
    boundary = 2**53
    s = pl.Series([boundary, boundary + 1, boundary + 2], dtype=pl.UInt64)
    assert s.is_in([boundary]).to_list() == [True, False, False]
    assert s.is_in([boundary + 1]).to_list() == [False, True, False]

    # UInt64 vs Int64: should use Int128 supertype to preserve precision
    val = 2**53 + 1000
    s = pl.Series([val], dtype=pl.UInt64)
    assert s.is_in(pl.Series([val], dtype=pl.Int64)).item()
    assert not s.is_in(pl.Series([val + 1], dtype=pl.Int64)).item()

    # Int64 vs UInt64 (reverse direction)
    s = pl.Series([val], dtype=pl.Int64)
    assert s.is_in(pl.Series([val], dtype=pl.UInt64)).item()
    assert not s.is_in(pl.Series([val + 1], dtype=pl.UInt64)).item()

    # Negative values in signed list vs unsigned series (uses Int128 supertype)
    s = pl.Series([100], dtype=pl.UInt64)
    assert s.is_in(pl.Series([-1, 100, 200], dtype=pl.Int64)).item()
    assert not s.is_in(pl.Series([-1, 99, 200], dtype=pl.Int64)).item()

    # Smaller integer type combinations that have lossless supertypes
    s = pl.Series([100, 200], dtype=pl.UInt32)
    assert s.is_in(pl.Series([100, 300], dtype=pl.Int32)).to_list() == [True, False]

    s = pl.Series([100, 200], dtype=pl.Int16)
    assert s.is_in(pl.Series([100, 300], dtype=pl.UInt16)).to_list() == [True, False]

    # UInt64 max value (no lossless supertype with Int64)
    s = pl.Series([2**64 - 1], dtype=pl.UInt64)
    assert s.is_in(pl.Series([2**64 - 1], dtype=pl.UInt64)).item()
    assert not s.is_in(pl.Series([2**64 - 2], dtype=pl.UInt64)).item()

    # Fallback to try_get_supertype for types without lossless supertype
    s = pl.Series([100], dtype=pl.UInt128)
    assert s.is_in(pl.Series([100], dtype=pl.Int64)).item()
    assert not s.is_in(pl.Series([99], dtype=pl.Int64)).item()


def test_is_in_float_list_10764() -> None:
    """A float column is checked row-wise against a list-of-float column."""
    df = pl.DataFrame(
        {
            "lst": [[1.0, 2.0, 3.0, 4.0, 5.0], [3.14, 5.28]],
            "n": [3.0, 2.0],
        }
    )
    assert df.select(pl.col("n").is_in("lst").alias("is_in")).to_dict(
        as_series=False
    ) == {"is_in": [True, False]}


def test_is_in_df() -> None:
    """Expression-level ``is_in`` against a plain Python list."""
    df = pl.DataFrame({"a": [1, 2, 3]})
    assert df.select(pl.col("a").is_in([1, 2]))["a"].to_list() == [True, True, False]


def test_is_in_series() -> None:
    """Series ``is_in`` with lists, sets, columns, empty input and bad dtypes."""
    s = pl.Series(["a", "b", "c"])

    out = s.is_in(["a", "b"])
    assert out.to_list() == [True, True, False]

    # Check if empty list is converted to pl.String
    out = s.is_in([])
    # Compare against the input length so a wrongly-sized result is caught.
    assert out.to_list() == [False] * s.len()

    for x_y_z in (["x", "y", "z"], {"x", "y", "z"}):
        out = s.is_in(x_y_z)
        assert out.to_list() == [False, False, False]

    df = pl.DataFrame({"a": [1.0, 2.0], "b": [1, 4], "c": ["e", "d"]})
    assert df.select(pl.col("a").is_in(pl.col("b"))).to_series().to_list() == [
        True,
        False,
    ]
    assert df.select(pl.col("b").is_in([])).to_series().to_list() == [False] * df.height

    with pytest.raises(
        InvalidOperationError,
        match=r"'is_in' cannot check for List\(String\) values in Int64 data",
    ):
        df.select(pl.col("b").is_in(["x", "x"]))

    # check we don't shallow-copy and accidentally modify 'a' (see: #10072)
    a = pl.Series("a", [1, 2])
    b = pl.Series("b", [1, 3]).is_in(a)

    assert a.name == "a"
    assert_series_equal(b, pl.Series("b", [True, False]))


@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_null(nulls_equal: bool) -> None:
    """A Null-dtype Series yields Boolean output that follows ``nulls_equal``."""
    # No nulls in right
    s = pl.Series([None, None], dtype=pl.Null)
    result = s.is_in([1, 2], nulls_equal=nulls_equal)
    missing_value = False if nulls_equal else None
    expected = pl.Series([missing_value, missing_value], dtype=pl.Boolean)
    assert_series_equal(result, expected)

    # Nulls in right
    s = pl.Series([None, None], dtype=pl.Null)
    result = s.is_in([None, None], nulls_equal=nulls_equal)
    missing_value = True if nulls_equal else None
    expected = pl.Series([missing_value, missing_value], dtype=pl.Boolean)
    assert_series_equal(result, expected)


@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_boolean(nulls_equal: bool) -> None:
    """Boolean ``is_in`` for every placement of nulls on the left/right."""
    # Nulls in neither left nor right
    s = pl.Series([True, False])
    result = s.is_in([True, False], nulls_equal=nulls_equal)
    expected = pl.Series([True, True])
    assert_series_equal(result, expected)

    # Nulls in left only
    s = pl.Series([True, None])
    result = s.is_in([False, False], nulls_equal=nulls_equal)
    missing_value = False if nulls_equal else None
    expected = pl.Series([False, missing_value])
    assert_series_equal(result, expected)

    # Nulls in right only
    s = pl.Series([True, False])
    result = s.is_in([True, None], nulls_equal=nulls_equal)
    expected = pl.Series([True, False])
    assert_series_equal(result, expected)

    # Nulls in both
    s = pl.Series([True, False, None])
    result = s.is_in([True, None], nulls_equal=nulls_equal)
    missing_value = True if nulls_equal else None
    expected = pl.Series([True, False, missing_value])
    assert_series_equal(result, expected)


@pytest.mark.parametrize("dtype", [pl.List(pl.Boolean), pl.Array(pl.Boolean, 2)])
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_boolean_list(dtype: PolarsDataType, nulls_equal: bool) -> None:
    """Boolean column checked row-wise against List/Array(Boolean) rows."""
    # A null in "a" yields null when nulls_equal=False; with nulls_equal=True it
    # matches only rows whose list contains a null element.
    df = pl.DataFrame(
        {
            "a": [True, False, None, None, None],
            "b": pl.Series(
                [
                    [True, False],
                    [True, True],
                    [None, True],
                    [False, True],
                    [True, True],
                ],
                dtype=dtype,
            ),
        }
    )
    missing_true = True if nulls_equal else None
    missing_false = False if nulls_equal else None
    result = df.select(pl.col("a").is_in("b", nulls_equal=nulls_equal))["a"]
    expected = pl.Series("a", [True, False, missing_true, missing_false, missing_false])
    assert_series_equal(result, expected)


def test_is_in_invalid_shape() -> None:
    """A nested-list right-hand side of mismatched shape raises."""
    with pytest.raises(InvalidOperationError):
        pl.Series("a", [1, 2, 3]).is_in([[], []])


def test_is_in_list_rhs() -> None:
    """Row-wise membership against a List Series; null rows yield null."""
    assert_series_equal(
        pl.Series([1, 2, 3, 4, 5]).is_in(pl.Series([[1], [2, 9], [None], None, None])),
        pl.Series([True, True, False, None, None]),
    )


@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_is_in_float(dtype: PolarsDataType) -> None:
    """NaN matches NaN and 0.0 matches -0.0, regardless of sign."""
    s = pl.Series([float("nan"), 0.0], dtype=dtype)
    result = s.is_in([-0.0, -float("nan")])
    expected = pl.Series([True, True], dtype=pl.Boolean)
    assert_series_equal(result, expected)


@pytest.mark.parametrize(
    ("df", "matches", "expected_error"),
    [
        (
            pl.DataFrame({"a": [1, 2], "b": [[1.0, 2.5], [3.0, 4.0]]}),
            [True, False],
            None,
        ),
        (
            pl.DataFrame({"a": [2.5, 3.0], "b": [[1, 2], [3, 4]]}),
            [False, True],
            None,
        ),
        (
            pl.DataFrame(
                {"a": [None, None], "b": [[1, 2], [3, 4]]},
                schema_overrides={"a": pl.Null},
            ),
            [None, None],
            None,
        ),
        (
            pl.DataFrame({"a": ["1", "2"], "b": [[1, 2], [3, 4]]}),
            None,
            r"'is_in' cannot check for List\(Int64\) values in String data",
        ),
        (
            pl.DataFrame({"a": [date.today(), None], "b": [[1, 2], [3, 4]]}),
            None,
            r"'is_in' cannot check for List\(Int64\) values in Date data",
        ),
    ],
)
def test_is_in_expr_list_series(
    df: pl.DataFrame,
    matches: list[bool | None] | None,
    expected_error: str | None,
) -> None:
    """Column vs List column: numeric casting works, incompatible dtypes raise."""
    expr_is_in = pl.col("a").is_in(pl.col("b"))
    # Branch on the error explicitly rather than on the truthiness of `matches`.
    if expected_error is None:
        assert df.select(expr_is_in).to_series().to_list() == matches
    else:
        with pytest.raises(InvalidOperationError, match=expected_error):
            df.select(expr_is_in)


@pytest.mark.parametrize(
    ("df", "matches"),
    [
        (
            pl.DataFrame({"a": [1, None], "b": [[1.0, 2.5, 4.0], [3.0, 4.0, 5.0]]}),
            [True, False],
        ),
        (
            pl.DataFrame({"a": [1, None], "b": [[0.0, 2.5, None], [3.0, 4.0, None]]}),
            [False, True],
        ),
        (
            pl.DataFrame(
                {"a": [None, None], "b": [[1, 2], [3, 4]]},
                schema_overrides={"a": pl.Null},
            ),
            [False, False],
        ),
        (
            pl.DataFrame(
                {"a": [None, None], "b": [[1, 2], [3, None]]},
                schema_overrides={"a": pl.Null},
            ),
            [False, True],
        ),
    ],
)
def test_is_in_expr_list_series_nonullpropagate(
    df: pl.DataFrame, matches: list[bool]
) -> None:
    """With ``nulls_equal=True`` a null value matches a null list element."""
    expr_is_in = pl.col("a").is_in(pl.col("b"), nulls_equal=True)
    assert df.select(expr_is_in).to_series().to_list() == matches


@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_null_series(nulls_equal: bool) -> None:
    """String column vs ``[None]``: only the null row can match."""
    df = pl.DataFrame({"a": ["a", "b", None]})
    result = df.select(pl.col("a").is_in([None], nulls_equal=nulls_equal))
    missing_value = True if nulls_equal else None
    expected = pl.DataFrame({"a": [False, False, missing_value]})
    assert_frame_equal(result, expected)


def test_is_in_int_range() -> None:
    """``is_in`` on both lazy and eager ``int_range`` output."""
    r = pl.int_range(0, 3, eager=False)
    out = pl.select(r.is_in([1, 2])).to_series()
    assert out.to_list() == [False, True, True]

    r = pl.int_range(0, 3, eager=True)  # type: ignore[assignment]
    out = r.is_in([1, 2])  # type: ignore[assignment]
    assert out.to_list() == [False, True, True]


def test_is_in_date_range() -> None:
    """``is_in`` on both lazy and eager ``date_range`` output."""
    r = pl.date_range(date(2023, 1, 1), date(2023, 1, 3), eager=False)
    out = pl.select(r.is_in([date(2023, 1, 2), date(2023, 1, 3)])).to_series()
    assert out.to_list() == [False, True, True]

    r = pl.date_range(date(2023, 1, 1), date(2023, 1, 3), eager=True)  # type: ignore[assignment]
    out = r.is_in([date(2023, 1, 2), date(2023, 1, 3)])  # type: ignore[assignment]
    assert out.to_list() == [False, True, True]


@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])])
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_series(dtype: pl.DataType, nulls_equal: bool) -> None:
    """Categorical/Enum vs a Series of the same dtype, and vs a String Series."""
    s = pl.Series(["a", "b", "c", None], dtype=dtype)
    s2 = pl.Series(["b", "c"], dtype=dtype)
    missing_value = False if nulls_equal else None
    expected = pl.Series([False, True, True, missing_value])
    assert_series_equal(s.is_in(s2, nulls_equal=nulls_equal), expected)

    s2_str = s2.cast(pl.String)
    assert_series_equal(s.is_in(s2_str, nulls_equal=nulls_equal), expected)


@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_series_non_existent(nulls_equal: bool) -> None:
    """Categorical vs values that do not occur in the left-hand data."""
    dtype = pl.Categorical
    s = pl.Series(["a", "b", "c", None], dtype=dtype)
    s2 = pl.Series(["a", "d", "e"], dtype=dtype)
    missing_value = False if nulls_equal else None
    expected = pl.Series([True, False, False, missing_value])
    assert_series_equal(s.is_in(s2, nulls_equal=nulls_equal), expected)

    s2_str = s2.cast(pl.String)
    assert_series_equal(s.is_in(s2_str, nulls_equal=nulls_equal), expected)


@pytest.mark.parametrize("nulls_equal", [False, True])
def test_enum_is_in_series_non_existent(nulls_equal: bool) -> None:
    """Enum vs strings outside its categories raises; in-category strings work."""
    dtype = pl.Enum(["a", "b", "c"])
    missing_value = False if nulls_equal else None
    s = pl.Series(["a", "b", "c", None], dtype=dtype)
    s2_str = pl.Series(["a", "d", "e"])
    expected = pl.Series([True, False, False, missing_value])

    with pytest.raises(InvalidOperationError):
        s.is_in(s2_str, nulls_equal=nulls_equal)
    with pytest.raises(InvalidOperationError):
        s.is_in(["a", "d", "e"], nulls_equal=nulls_equal)

    out = s.is_in(["a"], nulls_equal=nulls_equal)
    assert_series_equal(out, expected)


@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])])
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_with_lit_str(dtype: pl.DataType, nulls_equal: bool) -> None:
    """Categorical/Enum vs a Python list of strings."""
    missing_value = False if nulls_equal else None
    s = pl.Series(["a", "b", "c", None], dtype=dtype)
    lit = ["b"]
    expected = pl.Series([False, True, False, missing_value])

    assert_series_equal(s.is_in(lit, nulls_equal=nulls_equal), expected)


@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_with_lit_str_non_existent(nulls_equal: bool) -> None:
    """Categorical vs a string literal that does not occur in the data."""
    dtype = pl.Categorical()
    missing_value = False if nulls_equal else None
    s = pl.Series(["a", "b", "c", None], dtype=dtype)
    lit = ["d"]
    expected = pl.Series([False, False, False, missing_value])

    assert_series_equal(s.is_in(lit, nulls_equal=nulls_equal), expected)


@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])])
def test_cat_is_in_with_lit_str_cache_setup(dtype: pl.DataType) -> None:
    """Single-value Categorical/Enum Series match their own string literal."""
    # init the global cache
    _ = pl.Series(["c", "b", "a"], dtype=dtype)

    assert_series_equal(pl.Series(["a"], dtype=dtype).is_in(["a"]), pl.Series([True]))
    assert_series_equal(pl.Series(["b"], dtype=dtype).is_in(["b"]), pl.Series([True]))
    assert_series_equal(pl.Series(["c"], dtype=dtype).is_in(["c"]), pl.Series([True]))


def test_is_in_with_wildcard_13809() -> None:
    """``is_in`` applied through the ``pl.all()`` wildcard."""
    out = pl.DataFrame({"A": ["B"]}).select(pl.all().is_in(["C"]))
    expected = pl.DataFrame({"A": [False]})
    assert_frame_equal(out, expected)


@pytest.mark.parametrize(
    "dtype",
    [
        pl.Categorical,
        pl.Enum(["a", "b", "c", "d"]),
    ],
)
def test_cat_is_in_from_str(dtype: pl.DataType) -> None:
    """String Series checked against a Categorical/Enum Series."""
    s = pl.Series(["c", "c", "b"], dtype=dtype)

    # String left-hand side, categorical/enum right-hand side
    assert_series_equal(
        pl.Series(["a", "d", "b"]).is_in(s),
        pl.Series([False, False, True]),
    )


@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c", "d"])])
def test_cat_list_is_in_from_cat(dtype: pl.DataType) -> None:
    """``list.contains`` on List(Categorical/Enum) with a per-row needle column."""
    df = pl.DataFrame(
        [
            (["a", "b"], "c"),
            (["a", "b"], "a"),
            (["a", None], None),
            (["a", "c"], None),
            (["a"], "d"),
        ],
        schema={"li": pl.List(dtype), "x": dtype},
        orient="row",
    )
    res = df.select(pl.col("li").list.contains(pl.col("x")))
    expected_df = pl.DataFrame({"li": [False, True, True, False, False]})
    assert_frame_equal(res, expected_df)


@pytest.mark.parametrize(
    ("val", "expected"),
    [
        ("b", [True, False, False, None, True]),
        (None, [False, False, True, None, False]),
        ("e", [False, False, False, None, False]),
    ],
)
def test_cat_list_is_in_from_cat_single(
    val: str | None, expected: list[bool | None]
) -> None:
    """``list.contains`` on List(Categorical) with a single Categorical literal."""
    df = pl.Series(
        "li",
        [["a", "b"], ["a", "c"], ["a", None], None, ["b"]],
        dtype=pl.List(pl.Categorical),
    ).to_frame()
    res = df.select(pl.col("li").list.contains(pl.lit(val, dtype=pl.Categorical)))
    expected_df = pl.DataFrame({"li": expected})
    assert_frame_equal(res, expected_df)


def test_cat_list_is_in_from_str() -> None:
    """``list.contains`` on List(Categorical) with a per-row String needle column."""
    df = pl.DataFrame(
        [
            (["a", "b"], "c"),
            (["a", "b"], "a"),
            (["a", None], None),
            (["a", "c"], None),
            (["a"], "d"),
        ],
        schema={"li": pl.List(pl.Categorical), "x": pl.String},
        orient="row",
    )
    res = df.select(pl.col("li").list.contains(pl.col("x")))
    expected_df = pl.DataFrame({"li": [False, True, True, False, False]})
    assert_frame_equal(res, expected_df)


@pytest.mark.parametrize(
    ("val", "expected"),
    [
        ("b", [True, False, False, None, True]),
        (None, [False, False, True, None, False]),
        ("e", [False, False, False, None, False]),
    ],
)
def test_cat_list_is_in_from_single_str(
    val: str | None, expected: list[bool | None]
) -> None:
    """``list.contains`` on List(Categorical) with a single String literal."""
    df = pl.Series(
        "li",
        [["a", "b"], ["a", "c"], ["a", None], None, ["b"]],
        dtype=pl.List(pl.Categorical),
    ).to_frame()
    res = df.select(pl.col("li").list.contains(pl.lit(val, dtype=pl.String)))
    expected_df = pl.DataFrame({"li": expected})
    assert_frame_equal(res, expected_df)


@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_struct_enum_17618(nulls_equal: bool) -> None:
    """Struct-of-Enum ``is_in`` on an empty frame returns an empty result."""
    df = pl.DataFrame()
    dtype = pl.Enum(categories=["HBS"])
    df = df.insert_column(0, pl.Series("category", [], dtype=dtype))
    assert df.filter(
        pl.struct("category").is_in(
            pl.Series(
                [{"category": "HBS"}],
                dtype=pl.Struct({"category": df["category"].dtype}),
            ),
            nulls_equal=nulls_equal,
        )
    ).shape == (0, 1)


@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_decimal(nulls_equal: bool) -> None:
    """Decimal column checked against float, Decimal and int lists."""
    assert pl.DataFrame({"a": [D("0.0"), D("0.2"), D("0.1")]}).select(
        pl.col("a").is_in([0.0, 0.1], nulls_equal=nulls_equal)
    )["a"].to_list() == [True, False, True]
    assert pl.DataFrame({"a": [D("0.0"), D("0.2"), D("0.1")]}).select(
        pl.col("a").is_in([D("0.0"), D("0.1")], nulls_equal=nulls_equal)
    )["a"].to_list() == [True, False, True]
    assert pl.DataFrame({"a": [D("0.0"), D("0.2"), D("0.1")]}).select(
        pl.col("a").is_in([1, 0, 2], nulls_equal=nulls_equal)
    )["a"].to_list() == [True, False, False]
    missing_value = True if nulls_equal else None
    assert pl.DataFrame({"a": [D("0.0"), D("0.2"), None]}).select(
        pl.col("a").is_in([0.0, 0.1, None], nulls_equal=nulls_equal)
    )["a"].to_list() == [True, False, missing_value]
    missing_value = False if nulls_equal else None
    assert pl.DataFrame({"a": [D("0.0"), D("0.2"), None]}).select(
        pl.col("a").is_in([0.0, 0.1], nulls_equal=nulls_equal)
    )["a"].to_list() == [True, False, missing_value]


def test_is_in_collection() -> None:
    """Any ``Collection`` (set, frozenset, custom class) is accepted as the RHS."""
    df = pl.DataFrame(
        {
            "lbl": ["aa", "bb", "cc", "dd", "ee"],
            "val": [0, 1, 2, 3, 4],
        }
    )

    class CustomCollection(Collection[int]):
        def __init__(self, vals: Collection[int]) -> None:
            super().__init__()
            self.vals = vals

        def __contains__(self, x: object) -> bool:
            return x in self.vals

        def __iter__(self) -> Iterator[int]:
            yield from self.vals

        def __len__(self) -> int:
            return len(self.vals)

    for constraint_values in (
        {3, 2, 1},
        frozenset({3, 2, 1}),
        CustomCollection([3, 2, 1]),
    ):
        res = df.filter(pl.col("val").is_in(constraint_values))
        assert set(res["lbl"]) == {"bb", "cc", "dd"}


@pytest.mark.parametrize("nulls_equal", [False, True])
def test_null_propagate_all_paths(nulls_equal: bool) -> None:
    """Integer ``is_in`` for every placement of nulls on the left/right."""
    # No nulls in either
    s = pl.Series([1, 2, 3])
    result = s.is_in([1, 3, 8], nulls_equal=nulls_equal)
    expected = pl.Series([True, False, True])
    assert_series_equal(result, expected)

    # Nulls in left only
    s = pl.Series([1, 2, None])
    result = s.is_in([1, 3, 8], nulls_equal=nulls_equal)
    missing_value = False if nulls_equal else None
    expected = pl.Series([True, False, missing_value])
    assert_series_equal(result, expected)

    # Nulls in right only
    s = pl.Series([1, 2, 3])
    result = s.is_in([1, 3, None], nulls_equal=nulls_equal)
    expected = pl.Series([True, False, True])
    assert_series_equal(result, expected)

    # Nulls in both
    s = pl.Series([1, 2, None])
    result = s.is_in([1, 3, None], nulls_equal=nulls_equal)
    missing_value = True if nulls_equal else None
    expected = pl.Series([True, False, missing_value])
    assert_series_equal(result, expected)


@pytest.mark.parametrize("dtype", [pl.String, pl.Categorical])
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_null_propagate_all_paths_cat(
    dtype: PolarsDataType, nulls_equal: bool
) -> None:
    """String/Categorical ``is_in`` for every placement of nulls on the left/right."""
    # The test previously never used a categorical dtype despite its name; it
    # now runs for both plain strings and Categorical.

    # No nulls in either
    s = pl.Series(["1", "2", "3"], dtype=dtype)
    result = s.is_in(["1", "3", "8"], nulls_equal=nulls_equal)
    expected = pl.Series([True, False, True])
    assert_series_equal(result, expected)

    # Nulls in left only
    s = pl.Series(["1", "2", None], dtype=dtype)
    result = s.is_in(["1", "3", "8"], nulls_equal=nulls_equal)
    missing_value = False if nulls_equal else None
    expected = pl.Series([True, False, missing_value])
    assert_series_equal(result, expected)

    # Nulls in right only
    s = pl.Series(["1", "2", "3"], dtype=dtype)
    result = s.is_in(["1", "3", None], nulls_equal=nulls_equal)
    expected = pl.Series([True, False, True])
    assert_series_equal(result, expected)

    # Nulls in both
    s = pl.Series(["1", "2", None], dtype=dtype)
    result = s.is_in(["1", "3", None], nulls_equal=nulls_equal)
    missing_value = True if nulls_equal else None
    expected = pl.Series([True, False, missing_value])
    assert_series_equal(result, expected)