Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_is_in.py
8424 views
1
from __future__ import annotations
2
3
from collections.abc import Collection
4
from datetime import date
5
from decimal import Decimal as D
6
from typing import TYPE_CHECKING
7
8
import pytest
9
10
import polars as pl
11
from polars.exceptions import InvalidOperationError
12
from polars.testing import assert_frame_equal, assert_series_equal
13
14
if TYPE_CHECKING:
15
from collections.abc import Iterator
16
17
from polars._typing import PolarsDataType
18
19
20
def test_struct_logical_is_in() -> None:
    """Check `is_in` between struct series that contain a Date field."""
    left_dates = pl.date_range(date(2022, 1, 1), date(2022, 1, 7), eager=True)
    right_dates = pl.date_range(date(2022, 1, 3), date(2022, 1, 9), eager=True)
    left = pl.DataFrame({"x": left_dates, "y": [0, 4, 6, 2, 3, 4, 5]})
    right = pl.DataFrame({"x": right_dates, "y": [6, 2, 3, 4, 5, 0, 1]})

    fields = pl.struct(["x", "y"])
    needles = left.select(fields).to_series()
    haystack = right.select(fields).to_series()

    # The first two left rows (Jan 1 and Jan 2) have no counterpart on the right.
    membership = needles.is_in(haystack).to_list()
    assert membership == [False, False, True, True, True, True, True]
37
38
39
def test_struct_logical_is_in_nonullpropagate() -> None:
    """Check struct `is_in` under both `nulls_equal` settings per null placement."""

    def as_struct(frame: pl.DataFrame) -> pl.Series:
        # Pack the "x" and "y" columns into one struct series.
        return frame.select(pl.struct(["x", "y"])).to_series()

    left_dates = pl.Series(
        [date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3), None]
    )
    left_frame = pl.DataFrame({"x": left_dates, "y": [0, 4, 6, None]})
    right_dates = pl.Series(
        [date(2022, 2, 1), date(2022, 1, 2), date(2022, 2, 3), None]
    )
    right_frame = pl.DataFrame({"x": right_dates, "y": [6, 4, 3, None]})

    # Null struct on the right only; the left is padded with a copy of its
    # own first row instead.
    needles = as_struct(left_frame)
    needles = needles.extend_constant(needles[0], 1)
    haystack = as_struct(right_frame).extend_constant(None, 1)
    without_null_eq = needles.is_in(haystack, nulls_equal=False).to_list()
    assert without_null_eq == [False, True, False, True, False]
    with_null_eq = needles.is_in(haystack, nulls_equal=True).to_list()
    assert with_null_eq == [False, True, False, True, False]

    # Null struct on the left only; the right is padded with a copy of its
    # own first row instead.
    needles = as_struct(left_frame).extend_constant(None, 1)
    haystack = as_struct(right_frame)
    haystack = haystack.extend_constant(haystack[0], 1)
    without_null_eq = needles.is_in(haystack, nulls_equal=False).to_list()
    assert without_null_eq == [False, True, False, True, None]
    with_null_eq = needles.is_in(haystack, nulls_equal=True).to_list()
    assert with_null_eq == [False, True, False, True, False]

    # Null struct on both sides.
    # The fourth row, a struct whose fields are all null, is a valid element
    # and matches regardless of `nulls_equal`.
    needles = as_struct(left_frame).extend_constant(None, 1)
    haystack = as_struct(right_frame).extend_constant(None, 1)
    without_null_eq = needles.is_in(haystack, nulls_equal=False).to_list()
    assert without_null_eq == [False, True, False, True, None]
    with_null_eq = needles.is_in(haystack, nulls_equal=True).to_list()
    assert with_null_eq == [False, True, False, True, True]
111
112
113
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_bool(nulls_equal: bool) -> None:
    """Check Boolean column membership against a list that contains a null."""
    frame = pl.DataFrame({"A": [True, False, None]})
    candidates = [True, None]
    # The null row matches the null candidate only when nulls compare equal.
    null_outcome = True if nulls_equal else None

    membership = pl.col("A").is_in(candidates, nulls_equal=nulls_equal)
    observed = frame.select(membership).to_dict(as_series=False)
    assert observed == {"A": [True, False, null_outcome]}
121
122
123
def test_is_in_bool_11216() -> None:
    """Check that `False` is found in a haystack that also contains a null."""
    haystack = [False, None]
    observed = pl.Series([False]).is_in(haystack)
    assert_series_equal(observed, pl.Series([True]))
127
128
129
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_empty_list_4559(nulls_equal: bool) -> None:
    """Check that nothing is a member of an empty list."""
    observed = pl.Series(["a"]).is_in([], nulls_equal=nulls_equal)
    assert observed.to_list() == [False]
132
133
134
def test_is_in_empty_list_4639() -> None:
    """Check `is_in` with an empty list on a column that contains a null."""
    frame = pl.DataFrame({"a": [1, None]})
    nothing: list[int] = []

    flag = pl.col("a").is_in(nothing).alias("a_in_list")
    observed = frame.with_columns([flag])

    # The non-null row is not a member; the null row stays null.
    wanted = pl.DataFrame({"a": [1, None], "a_in_list": [False, None]})
    assert_frame_equal(observed, wanted)
141
142
143
def test_is_in_struct() -> None:
    """Check row-wise membership of a struct in a list-of-structs column."""
    matching_list = [{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 3, "b": 13}]
    frame = pl.DataFrame(
        {
            "struct_elem": [{"a": 1, "b": 11}, {"a": 1, "b": 90}],
            "struct_list": [matching_list, [{"a": 3, "b": 3}]],
        }
    )

    # Only the first row's struct appears in its own list.
    kept = frame.filter(pl.col("struct_elem").is_in("struct_list"))
    assert kept.to_dict(as_series=False) == {
        "struct_elem": [{"a": 1, "b": 11}],
        "struct_list": [
            [{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 3, "b": 13}]
        ],
    }
160
161
162
def test_is_in_null_prop() -> None:
    """Check null handling for a null float and for null structs."""
    null_float = pl.Series([None], dtype=pl.Float32)
    assert null_float.is_in(pl.Series([42])).item() is None

    # A struct with a null field yields False, while a null struct yields
    # null; the same holds for a Float32 and a Boolean field.
    for field_dtype in (pl.Float32, pl.Boolean):
        needles = pl.Series(
            [{"a": None}, None], dtype=pl.Struct({"a": field_dtype})
        )
        haystack = pl.Series([{"a": 42}], dtype=pl.Struct({"a": field_dtype}))
        assert needles.is_in(haystack).to_list() == [False, None]
171
172
173
def test_is_in_9070() -> None:
    """Check that the integer 1 is not reported as a member of `[1.99]`."""
    needle = pl.Series([1])
    haystack = pl.Series([1.99])
    assert not needle.is_in(haystack).item()
175
176
177
def test_is_in_large_uint64_21966() -> None:
    """Check that `is_in` compares large integers exactly.

    See https://github.com/pola-rs/polars/issues/21966: integers beyond
    Float64 precision (2^53) should compare exactly rather than lose
    precision through a cast to Float64.
    """
    # Original issue: values that differ only beyond Float64 precision.
    reported = pl.Series([58830407606777880], dtype=pl.UInt64)
    assert not reported.is_in([58830407606777883]).item()
    assert reported.is_in([58830407606777880]).item()

    # Values at and beyond the Float64 precision boundary (2^53).
    limit = 2**53
    around_limit = pl.Series([limit, limit + 1, limit + 2], dtype=pl.UInt64)
    assert around_limit.is_in([limit]).to_list() == [True, False, False]
    assert around_limit.is_in([limit + 1]).to_list() == [False, True, False]

    # UInt64 vs Int64 (should use the Int128 supertype to preserve
    # precision), followed by the reverse direction, Int64 vs UInt64.
    big = 2**53 + 1000
    for needle_dtype, haystack_dtype in (
        (pl.UInt64, pl.Int64),
        (pl.Int64, pl.UInt64),
    ):
        needle = pl.Series([big], dtype=needle_dtype)
        assert needle.is_in(pl.Series([big], dtype=haystack_dtype)).item()
        assert not needle.is_in(pl.Series([big + 1], dtype=haystack_dtype)).item()

    # Negative values in a signed list vs an unsigned series (uses the
    # Int128 supertype).
    unsigned = pl.Series([100], dtype=pl.UInt64)
    assert unsigned.is_in(pl.Series([-1, 100, 200], dtype=pl.Int64)).item()
    assert not unsigned.is_in(pl.Series([-1, 99, 200], dtype=pl.Int64)).item()

    # Smaller integer type combinations that have lossless supertypes.
    for needle_dtype, haystack_dtype in (
        (pl.UInt32, pl.Int32),
        (pl.Int16, pl.UInt16),
    ):
        needles = pl.Series([100, 200], dtype=needle_dtype)
        haystack = pl.Series([100, 300], dtype=haystack_dtype)
        assert needles.is_in(haystack).to_list() == [True, False]

    # UInt64 max value (no lossless supertype with Int64).
    top = pl.Series([2**64 - 1], dtype=pl.UInt64)
    assert top.is_in(pl.Series([2**64 - 1], dtype=pl.UInt64)).item()
    assert not top.is_in(pl.Series([2**64 - 2], dtype=pl.UInt64)).item()

    # Fallback to try_get_supertype for types without a lossless supertype.
    wide = pl.Series([100], dtype=pl.UInt128)
    assert wide.is_in(pl.Series([100], dtype=pl.Int64)).item()
    assert not wide.is_in(pl.Series([99], dtype=pl.Int64)).item()
225
226
227
def test_is_in_float_list_10764() -> None:
    """Check row-wise membership of a float column in a list-of-floats column."""
    frame = pl.DataFrame(
        {
            "lst": [[1.0, 2.0, 3.0, 4.0, 5.0], [3.14, 5.28]],
            "n": [3.0, 2.0],
        }
    )
    flags = frame.select(pl.col("n").is_in("lst").alias("is_in"))
    assert flags.to_dict(as_series=False) == {"is_in": [True, False]}
237
238
239
def test_is_in_df() -> None:
    """Check `is_in` as an expression evaluated through `DataFrame.select`."""
    frame = pl.DataFrame({"a": [1, 2, 3]})
    flags = frame.select(pl.col("a").is_in([1, 2]))["a"]
    assert flags.to_list() == [True, True, False]
242
243
244
def test_is_in_series() -> None:
    """Check `Series.is_in` and `Expr.is_in` against lists, sets and series."""
    letters = pl.Series(["a", "b", "c"])

    flags = letters.is_in(["a", "b"])
    assert flags.to_list() == [True, True, False]

    # Check if empty list is converted to pl.String
    flags = letters.is_in([])
    assert flags.to_list() == [False] * flags.len()

    # A list and a set with the same non-matching values behave alike.
    for x_y_z in (["x", "y", "z"], {"x", "y", "z"}):
        flags = letters.is_in(x_y_z)
        assert flags.to_list() == [False, False, False]

    frame = pl.DataFrame({"a": [1.0, 2.0], "b": [1, 4], "c": ["e", "d"]})
    float_in_int = frame.select(pl.col("a").is_in(pl.col("b"))).to_series()
    assert float_in_int.to_list() == [True, False]
    int_in_empty = frame.select(pl.col("b").is_in([])).to_series()
    assert int_in_empty.to_list() == [False] * frame.height

    # Looking up strings in an integer column must be rejected.
    with pytest.raises(
        InvalidOperationError,
        match=r"'is_in' cannot check for List\(String\) values in Int64 data",
    ):
        frame.select(pl.col("b").is_in(["x", "x"]))

    # check we don't shallow-copy and accidentally modify 'a' (see: #10072)
    a = pl.Series("a", [1, 2])
    b = pl.Series("b", [1, 3]).is_in(a)

    assert a.name == "a"
    assert_series_equal(b, pl.Series("b", [True, False]))
277
278
279
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_null(nulls_equal: bool) -> None:
    """Check `is_in` on a Null-dtype series, with and without nulls on the right."""
    # No nulls in right: a null is never found unless nulls propagate.
    nulls = pl.Series([None, None], dtype=pl.Null)
    observed = nulls.is_in([1, 2], nulls_equal=nulls_equal)
    null_outcome = False if nulls_equal else None
    wanted = pl.Series([null_outcome] * 2, dtype=pl.Boolean)
    assert_series_equal(observed, wanted)

    # Nulls in right: a null is found when nulls compare equal.
    nulls = pl.Series([None, None], dtype=pl.Null)
    observed = nulls.is_in([None, None], nulls_equal=nulls_equal)
    null_outcome = True if nulls_equal else None
    wanted = pl.Series([null_outcome] * 2, dtype=pl.Boolean)
    assert_series_equal(observed, wanted)
294
295
296
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_boolean(nulls_equal: bool) -> None:
    """Check Boolean `is_in` for every combination of null placement."""
    # Nulls in neither left nor right
    needles = pl.Series([True, False])
    observed = needles.is_in([True, False], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([True, True]))

    # Nulls in left only
    needles = pl.Series([True, None])
    null_outcome = False if nulls_equal else None
    observed = needles.is_in([False, False], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([False, null_outcome]))

    # Nulls in right only
    needles = pl.Series([True, False])
    observed = needles.is_in([True, None], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([True, False]))

    # Nulls in both
    needles = pl.Series([True, False, None])
    null_outcome = True if nulls_equal else None
    observed = needles.is_in([True, None], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([True, False, null_outcome]))
323
324
325
@pytest.mark.parametrize("dtype", [pl.List(pl.Boolean), pl.Array(pl.Boolean, 2)])
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_boolean_list(dtype: PolarsDataType, nulls_equal: bool) -> None:
    """Check row-wise Boolean membership in a List or Array column."""
    # Note list is_in does not propagate nulls.
    nested = pl.Series(
        [
            [True, False],
            [True, True],
            [None, True],
            [False, True],
            [True, True],
        ],
        dtype=dtype,
    )
    frame = pl.DataFrame({"a": [True, False, None, None, None], "b": nested})

    # A null in "a" is found only in the row whose list itself holds a null.
    null_found = True if nulls_equal else None
    null_absent = False if nulls_equal else None

    observed = frame.select(pl.col("a").is_in("b", nulls_equal=nulls_equal))["a"]
    wanted = pl.Series("a", [True, False, null_found, null_absent, null_absent])
    assert_series_equal(observed, wanted)
349
350
351
def test_is_in_invalid_shape() -> None:
    """Check that a nested haystack of mismatched length is rejected."""
    needles = pl.Series("a", [1, 2, 3])
    # Three needles against two (empty) lists.
    with pytest.raises(InvalidOperationError):
        needles.is_in([[], []])
354
355
356
def test_is_in_list_rhs() -> None:
    """Check row-wise membership against a list series that holds null rows."""
    needles = pl.Series([1, 2, 3, 4, 5])
    haystacks = pl.Series([[1], [2, 9], [None], None, None])
    # A null list row gives a null result; a list holding only a null gives False.
    wanted = pl.Series([True, True, False, None, None])
    assert_series_equal(needles.is_in(haystacks), wanted)
361
362
363
@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_is_in_float(dtype: PolarsDataType) -> None:
    """Check that NaN matches a negated NaN and that 0.0 matches -0.0."""
    nan = float("nan")
    values = pl.Series([nan, 0.0], dtype=dtype)
    observed = values.is_in([-0.0, -nan])
    assert_series_equal(observed, pl.Series([True, True], dtype=pl.Boolean))
369
370
371
@pytest.mark.parametrize(
    ("df", "matches", "expected_error"),
    [
        (
            pl.DataFrame({"a": [1, 2], "b": [[1.0, 2.5], [3.0, 4.0]]}),
            [True, False],
            None,
        ),
        (
            pl.DataFrame({"a": [2.5, 3.0], "b": [[1, 2], [3, 4]]}),
            [False, True],
            None,
        ),
        (
            pl.DataFrame(
                {"a": [None, None], "b": [[1, 2], [3, 4]]},
                schema_overrides={"a": pl.Null},
            ),
            [None, None],
            None,
        ),
        (
            pl.DataFrame({"a": ["1", "2"], "b": [[1, 2], [3, 4]]}),
            None,
            r"'is_in' cannot check for List\(Int64\) values in String data",
        ),
        (
            pl.DataFrame({"a": [date.today(), None], "b": [[1, 2], [3, 4]]}),
            None,
            r"'is_in' cannot check for List\(Int64\) values in Date data",
        ),
    ],
)
def test_is_in_expr_list_series(
    df: pl.DataFrame, matches: list[bool] | None, expected_error: str | None
) -> None:
    """Check row-wise `is_in` of column "a" against list column "b".

    Each case supplies either `matches`, the expected per-row result, or
    `expected_error`, a regex the raised `InvalidOperationError` must match.
    """
    expr_is_in = pl.col("a").is_in(pl.col("b"))
    # Test for None explicitly: a truthiness check would send an empty
    # expected list to the error branch.
    if matches is not None:
        assert df.select(expr_is_in).to_series().to_list() == matches
    else:
        with pytest.raises(InvalidOperationError, match=expected_error):
            df.select(expr_is_in)
413
414
415
@pytest.mark.parametrize(
    ("df", "matches"),
    [
        (
            pl.DataFrame({"a": [1, None], "b": [[1.0, 2.5, 4.0], [3.0, 4.0, 5.0]]}),
            [True, False],
        ),
        (
            pl.DataFrame({"a": [1, None], "b": [[0.0, 2.5, None], [3.0, 4.0, None]]}),
            [False, True],
        ),
        (
            pl.DataFrame(
                {"a": [None, None], "b": [[1, 2], [3, 4]]},
                schema_overrides={"a": pl.Null},
            ),
            [False, False],
        ),
        (
            pl.DataFrame(
                {"a": [None, None], "b": [[1, 2], [3, None]]},
                schema_overrides={"a": pl.Null},
            ),
            [False, True],
        ),
    ],
)
def test_is_in_expr_list_series_nonullpropagate(
    df: pl.DataFrame, matches: list[bool]
) -> None:
    """Check row-wise `is_in` against a list column with `nulls_equal=True`.

    In these cases a null in "a" matches only a list that itself holds a null.
    """
    membership = pl.col("a").is_in(pl.col("b"), nulls_equal=True)
    observed = df.select(membership).to_series().to_list()
    assert observed == matches
447
448
449
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_null_series(nulls_equal: bool) -> None:
    """Check a string column against a haystack that holds only a null."""
    frame = pl.DataFrame({"a": ["a", "b", None]})
    # Only the null row can match, and only when nulls compare equal.
    null_outcome = True if nulls_equal else None

    observed = frame.select(pl.col("a").is_in([None], nulls_equal=nulls_equal))
    wanted = pl.DataFrame({"a": [False, False, null_outcome]})
    assert_frame_equal(observed, wanted)
456
457
458
def test_is_in_int_range() -> None:
    """Check `is_in` on `int_range` output in both lazy and eager form."""
    # Expression form, evaluated through `pl.select`.
    lazy_range = pl.int_range(0, 3, eager=False)
    lazy_flags = pl.select(lazy_range.is_in([1, 2])).to_series()
    assert lazy_flags.to_list() == [False, True, True]

    # Series form; separate names keep the two static types apart.
    eager_range = pl.int_range(0, 3, eager=True)
    eager_flags = eager_range.is_in([1, 2])
    assert eager_flags.to_list() == [False, True, True]
466
467
468
def test_is_in_date_range() -> None:
    """Check `is_in` on `date_range` output in both lazy and eager form."""
    # Expression form, evaluated through `pl.select`.
    lazy_range = pl.date_range(date(2023, 1, 1), date(2023, 1, 3), eager=False)
    lazy_flags = pl.select(
        lazy_range.is_in([date(2023, 1, 2), date(2023, 1, 3)])
    ).to_series()
    assert lazy_flags.to_list() == [False, True, True]

    # Series form; separate names keep the two static types apart.
    eager_range = pl.date_range(date(2023, 1, 1), date(2023, 1, 3), eager=True)
    eager_flags = eager_range.is_in([date(2023, 1, 2), date(2023, 1, 3)])
    assert eager_flags.to_list() == [False, True, True]
476
477
478
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])])
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_series(dtype: pl.DataType, nulls_equal: bool) -> None:
    """Check Categorical/Enum membership in a same-dtype and a String series."""
    needles = pl.Series(["a", "b", "c", None], dtype=dtype)
    typed_haystack = pl.Series(["b", "c"], dtype=dtype)
    # No null on the right, so the null needle is never found.
    null_outcome = False if nulls_equal else None
    wanted = pl.Series([False, True, True, null_outcome])

    # Same expectation for the typed haystack and its String cast.
    for haystack in (typed_haystack, typed_haystack.cast(pl.String)):
        observed = needles.is_in(haystack, nulls_equal=nulls_equal)
        assert_series_equal(observed, wanted)
489
490
491
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_series_non_existent(nulls_equal: bool) -> None:
    """Check Categorical membership when the haystack holds unseen values."""
    cat = pl.Categorical
    needles = pl.Series(["a", "b", "c", None], dtype=cat)
    # "d" and "e" never occur among the needles.
    typed_haystack = pl.Series(["a", "d", "e"], dtype=cat)
    null_outcome = False if nulls_equal else None
    wanted = pl.Series([True, False, False, null_outcome])

    # Same expectation for the typed haystack and its String cast.
    for haystack in (typed_haystack, typed_haystack.cast(pl.String)):
        observed = needles.is_in(haystack, nulls_equal=nulls_equal)
        assert_series_equal(observed, wanted)
502
503
504
@pytest.mark.parametrize(
    "nulls_equal",
    [False, True],
)
def test_enum_is_in_series_non_existent(nulls_equal: bool) -> None:
    """Check that Enum `is_in` rejects haystack values outside its categories."""
    enum_dtype = pl.Enum(["a", "b", "c"])
    needles = pl.Series(["a", "b", "c", None], dtype=enum_dtype)
    unknown_strings = pl.Series(["a", "d", "e"])

    # "d" and "e" are not Enum categories, as a series or as a plain list.
    with pytest.raises(InvalidOperationError):
        needles.is_in(unknown_strings, nulls_equal=nulls_equal)
    with pytest.raises(InvalidOperationError):
        needles.is_in(["a", "d", "e"], nulls_equal=nulls_equal)

    # A haystack made up of valid categories works as usual.
    null_outcome = False if nulls_equal else None
    wanted = pl.Series([True, False, False, null_outcome])
    observed = needles.is_in(["a"], nulls_equal=nulls_equal)
    assert_series_equal(observed, wanted)
522
523
524
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])])
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_with_lit_str(dtype: pl.DataType, nulls_equal: bool) -> None:
    """Check Categorical/Enum membership in a plain list of strings."""
    needles = pl.Series(["a", "b", "c", None], dtype=dtype)
    haystack = ["b"]
    null_outcome = False if nulls_equal else None
    wanted = pl.Series([False, True, False, null_outcome])

    observed = needles.is_in(haystack, nulls_equal=nulls_equal)
    assert_series_equal(observed, wanted)
533
534
535
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_with_lit_str_non_existent(nulls_equal: bool) -> None:
    """Check Categorical membership in a list holding only an unseen string."""
    needles = pl.Series(["a", "b", "c", None], dtype=pl.Categorical())
    haystack = ["d"]
    null_outcome = False if nulls_equal else None
    wanted = pl.Series([False, False, False, null_outcome])

    observed = needles.is_in(haystack, nulls_equal=nulls_equal)
    assert_series_equal(observed, wanted)
544
545
546
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])])
def test_cat_is_in_with_lit_str_cache_setup(dtype: pl.DataType) -> None:
    """Check single-value lookups after a series with other values was built."""
    # init the global cache
    _ = pl.Series(["c", "b", "a"], dtype=dtype)

    found = pl.Series([True])
    for label in ("a", "b", "c"):
        single = pl.Series([label], dtype=dtype)
        assert_series_equal(single.is_in([label]), found)
554
555
556
def test_is_in_with_wildcard_13809() -> None:
    """Check `is_in` applied through the `pl.all()` wildcard selector."""
    frame = pl.DataFrame({"A": ["B"]})
    observed = frame.select(pl.all().is_in(["C"]))
    assert_frame_equal(observed, pl.DataFrame({"A": [False]}))
560
561
562
@pytest.mark.parametrize(
    "dtype",
    [
        pl.Categorical,
        pl.Enum(["a", "b", "c", "d"]),
    ],
)
def test_cat_is_in_from_str(dtype: pl.DataType) -> None:
    """Check String needles against a Categorical/Enum haystack."""
    haystack = pl.Series(["c", "c", "b"], dtype=dtype)

    # test local
    needles = pl.Series(["a", "d", "b"])
    wanted = pl.Series([False, False, True])
    assert_series_equal(needles.is_in(haystack), wanted)
577
578
579
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c", "d"])])
def test_cat_list_is_in_from_cat(dtype: pl.DataType) -> None:
    """Check `list.contains` on a Categorical/Enum list with a same-dtype column."""
    rows = [
        (["a", "b"], "c"),
        (["a", "b"], "a"),
        (["a", None], None),
        (["a", "c"], None),
        (["a"], "d"),
    ]
    frame = pl.DataFrame(
        rows,
        schema={"li": pl.List(dtype), "x": dtype},
        orient="row",
    )

    observed = frame.select(pl.col("li").list.contains(pl.col("x")))
    # A null "x" is found only in the list that holds a null.
    wanted = pl.DataFrame({"li": [False, True, True, False, False]})
    assert_frame_equal(observed, wanted)
595
596
597
@pytest.mark.parametrize(
    ("val", "expected"),
    [
        ("b", [True, False, False, None, True]),
        (None, [False, False, True, None, False]),
        ("e", [False, False, False, None, False]),
    ],
)
def test_cat_list_is_in_from_cat_single(val: str | None, expected: list[bool]) -> None:
    """Check `list.contains` on a Categorical list with a Categorical literal."""
    lists = pl.Series(
        "li",
        [["a", "b"], ["a", "c"], ["a", None], None, ["b"]],
        dtype=pl.List(pl.Categorical),
    )
    needle = pl.lit(val, dtype=pl.Categorical)

    observed = lists.to_frame().select(pl.col("li").list.contains(needle))
    assert_frame_equal(observed, pl.DataFrame({"li": expected}))
614
615
616
def test_cat_list_is_in_from_str() -> None:
    """Check `list.contains` on a Categorical list with a String column."""
    rows = [
        (["a", "b"], "c"),
        (["a", "b"], "a"),
        (["a", None], None),
        (["a", "c"], None),
        (["a"], "d"),
    ]
    frame = pl.DataFrame(
        rows,
        schema={"li": pl.List(pl.Categorical), "x": pl.String},
        orient="row",
    )

    observed = frame.select(pl.col("li").list.contains(pl.col("x")))
    # A null "x" is found only in the list that holds a null.
    wanted = pl.DataFrame({"li": [False, True, True, False, False]})
    assert_frame_equal(observed, wanted)
631
632
633
@pytest.mark.parametrize(
    ("val", "expected"),
    [
        ("b", [True, False, False, None, True]),
        (None, [False, False, True, None, False]),
        ("e", [False, False, False, None, False]),
    ],
)
def test_cat_list_is_in_from_single_str(val: str | None, expected: list[bool]) -> None:
    """Check `list.contains` on a Categorical list with a String literal."""
    lists = pl.Series(
        "li",
        [["a", "b"], ["a", "c"], ["a", None], None, ["b"]],
        dtype=pl.List(pl.Categorical),
    )
    needle = pl.lit(val, dtype=pl.String)

    observed = lists.to_frame().select(pl.col("li").list.contains(needle))
    assert_frame_equal(observed, pl.DataFrame({"li": expected}))
650
651
652
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_struct_enum_17618(nulls_equal: bool) -> None:
    """Check struct `is_in` with an Enum field on an empty frame."""
    enum_dtype = pl.Enum(categories=["HBS"])
    frame = pl.DataFrame()
    frame = frame.insert_column(0, pl.Series("category", [], dtype=enum_dtype))

    # The haystack struct takes its field dtype from the frame's own column.
    haystack = pl.Series(
        [{"category": "HBS"}],
        dtype=pl.Struct({"category": frame["category"].dtype}),
    )
    membership = pl.struct("category").is_in(haystack, nulls_equal=nulls_equal)

    # No rows to keep, but the single column survives the filter.
    assert frame.filter(membership).shape == (0, 1)
666
667
668
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_decimal(nulls_equal: bool) -> None:
    """Check Decimal column membership in float, Decimal and int haystacks."""
    col = pl.col("a")
    plain = pl.DataFrame({"a": [D("0.0"), D("0.2"), D("0.1")]})

    # Float haystack.
    in_floats = plain.select(col.is_in([0.0, 0.1], nulls_equal=nulls_equal))["a"]
    assert in_floats.to_list() == [True, False, True]

    # Decimal haystack.
    in_decimals = plain.select(
        col.is_in([D("0.0"), D("0.1")], nulls_equal=nulls_equal)
    )["a"]
    assert in_decimals.to_list() == [True, False, True]

    # Integer haystack: only 0 coincides with a column value.
    in_ints = plain.select(col.is_in([1, 0, 2], nulls_equal=nulls_equal))["a"]
    assert in_ints.to_list() == [True, False, False]

    with_null = pl.DataFrame({"a": [D("0.0"), D("0.2"), None]})

    # Null on both sides.
    null_outcome = True if nulls_equal else None
    both_null = with_null.select(
        col.is_in([0.0, 0.1, None], nulls_equal=nulls_equal)
    )["a"]
    assert both_null.to_list() == [True, False, null_outcome]

    # Null on the left only.
    null_outcome = False if nulls_equal else None
    left_null = with_null.select(col.is_in([0.0, 0.1], nulls_equal=nulls_equal))["a"]
    assert left_null.to_list() == [True, False, null_outcome]
687
688
689
def test_is_in_collection() -> None:
    """Check that `is_in` accepts a set, a frozenset and a custom `Collection`."""

    class CustomCollection(Collection[int]):
        # Minimal Collection that delegates every operation to the wrapped one.
        def __init__(self, vals: Collection[int]) -> None:
            super().__init__()
            self.vals = vals

        def __len__(self) -> int:
            return len(self.vals)

        def __iter__(self) -> Iterator[int]:
            return iter(self.vals)

        def __contains__(self, x: object) -> bool:
            return x in self.vals

    frame = pl.DataFrame(
        {
            "lbl": ["aa", "bb", "cc", "dd", "ee"],
            "val": [0, 1, 2, 3, 4],
        }
    )

    # Each collection holds the same values, so the same rows must be kept.
    for constraint_values in (
        {3, 2, 1},
        frozenset({3, 2, 1}),
        CustomCollection([3, 2, 1]),
    ):
        kept = frame.filter(pl.col("val").is_in(constraint_values))
        assert set(kept["lbl"]) == {"bb", "cc", "dd"}
718
719
720
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_null_propagate_all_paths(nulls_equal: bool) -> None:
    """Check integer `is_in` for every combination of null placement."""
    # No nulls in either
    needles = pl.Series([1, 2, 3])
    observed = needles.is_in([1, 3, 8], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([True, False, True]))

    # Nulls in left only
    needles = pl.Series([1, 2, None])
    null_outcome = False if nulls_equal else None
    observed = needles.is_in([1, 3, 8], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([True, False, null_outcome]))

    # Nulls in right only
    needles = pl.Series([1, 2, 3])
    observed = needles.is_in([1, 3, None], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([True, False, True]))

    # Nulls in both
    needles = pl.Series([1, 2, None])
    null_outcome = True if nulls_equal else None
    observed = needles.is_in([1, 3, None], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([True, False, null_outcome]))
747
748
749
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_null_propagate_all_paths_cat(nulls_equal: bool) -> None:
    """Check `is_in` on series of strings for every combination of null placement."""
    # No nulls in either
    needles = pl.Series(["1", "2", "3"])
    observed = needles.is_in(["1", "3", "8"], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([True, False, True]))

    # Nulls in left only
    needles = pl.Series(["1", "2", None])
    null_outcome = False if nulls_equal else None
    observed = needles.is_in(["1", "3", "8"], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([True, False, null_outcome]))

    # Nulls in right only
    needles = pl.Series(["1", "2", "3"])
    observed = needles.is_in(["1", "3", None], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([True, False, True]))

    # Nulls in both
    needles = pl.Series(["1", "2", None])
    null_outcome = True if nulls_equal else None
    observed = needles.is_in(["1", "3", None], nulls_equal=nulls_equal)
    assert_series_equal(observed, pl.Series([True, False, null_outcome]))
776
777