Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_is_in.py
6939 views
1
from __future__ import annotations
2
3
from collections.abc import Collection
4
from datetime import date
5
from decimal import Decimal as D
6
from typing import TYPE_CHECKING
7
8
import pytest
9
10
import polars as pl
11
from polars.exceptions import InvalidOperationError
12
from polars.testing import assert_frame_equal, assert_series_equal
13
14
if TYPE_CHECKING:
15
from collections.abc import Iterator
16
17
from polars._typing import PolarsDataType
18
19
20
def test_struct_logical_is_in() -> None:
    """Check `is_in` between two struct Series built from a Date and an int field."""
    left_dates = pl.date_range(date(2022, 1, 1), date(2022, 1, 7), eager=True)
    left = pl.DataFrame({"x": left_dates, "y": [0, 4, 6, 2, 3, 4, 5]})
    right_dates = pl.date_range(date(2022, 1, 3), date(2022, 1, 9), eager=True)
    right = pl.DataFrame({"x": right_dates, "y": [6, 2, 3, 4, 5, 0, 1]})

    needles = left.select(pl.struct(["x", "y"])).to_series()
    haystack = right.select(pl.struct(["x", "y"])).to_series()

    # The first two left-hand dates (1 and 2 January) precede the right-hand range.
    membership = needles.is_in(haystack).to_list()
    assert membership == [False, False, True, True, True, True, True]
37
38
39
def test_struct_logical_is_in_nonullpropagate() -> None:
    """Struct `is_in` with outer nulls on the right, the left, or both sides."""
    left_frame = pl.DataFrame(
        {
            "x": pl.Series(
                [date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3), None]
            ),
            "y": [0, 4, 6, None],
        }
    )
    right_frame = pl.DataFrame(
        {
            "x": pl.Series(
                [date(2022, 2, 1), date(2022, 1, 2), date(2022, 2, 3), None]
            ),
            "y": [6, 4, 3, None],
        }
    )

    def as_structs(frame: pl.DataFrame) -> pl.Series:
        # Rebuilt for every scenario so each one starts from a fresh Series.
        return frame.select(pl.struct(["x", "y"])).to_series()

    # Left has no nulls, right has nulls
    needles = as_structs(left_frame)
    needles = needles.extend_constant(needles[0], 1)
    haystack = as_structs(right_frame).extend_constant(None, 1)
    same_either_way = [False, True, False, True, False]
    assert needles.is_in(haystack, nulls_equal=False).to_list() == same_either_way
    assert needles.is_in(haystack, nulls_equal=True).to_list() == same_either_way

    # Left has nulls, right has no nulls
    needles = as_structs(left_frame).extend_constant(None, 1)
    haystack = as_structs(right_frame)
    haystack = haystack.extend_constant(haystack[0], 1)
    strict = needles.is_in(haystack, nulls_equal=False).to_list()
    assert strict == [False, True, False, True, None]
    lenient = needles.is_in(haystack, nulls_equal=True).to_list()
    assert lenient == [False, True, False, True, False]

    # Both have nulls
    # {None, None} is a valid element unaffected by the missing parameter.
    needles = as_structs(left_frame).extend_constant(None, 1)
    haystack = as_structs(right_frame).extend_constant(None, 1)
    strict = needles.is_in(haystack, nulls_equal=False).to_list()
    assert strict == [False, True, False, True, None]
    lenient = needles.is_in(haystack, nulls_equal=True).to_list()
    assert lenient == [False, True, False, True, True]
111
112
113
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_bool(nulls_equal: bool) -> None:
    """Boolean column vs. `[True, None]`: the null row is True only if nulls_equal."""
    frame = pl.DataFrame({"A": [True, False, None]})
    candidates = [True, None]

    membership = pl.col("A").is_in(candidates, nulls_equal=nulls_equal)
    observed = frame.select(membership).to_dict(as_series=False)

    null_outcome = True if nulls_equal else None
    assert observed == {"A": [True, False, null_outcome]}
121
122
123
def test_is_in_bool_11216() -> None:
    """Regression check (11216): `False` is reported as present in `[False, None]`."""
    haystack = [False, None]
    outcome = pl.Series([False]).is_in(haystack)
    assert_series_equal(outcome, pl.Series([True]))
127
128
129
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_empty_list_4559(nulls_equal: bool) -> None:
    """Regression check (4559): a string value is never found in an empty list."""
    single = pl.Series(["a"])
    membership = single.is_in([], nulls_equal=nulls_equal)
    assert membership.to_list() == [False]
132
133
134
def test_is_in_empty_list_4639() -> None:
    """Regression check (4639): empty candidates give False, and null stays null."""
    no_candidates: list[int] = []
    frame = pl.DataFrame({"a": [1, None]})
    flag = pl.col("a").is_in(no_candidates).alias("a_in_list")

    observed = frame.with_columns([flag])

    wanted = pl.DataFrame({"a": [1, None], "a_in_list": [False, None]})
    assert_frame_equal(observed, wanted)
141
142
143
def test_is_in_struct() -> None:
    """Filter on `struct_elem.is_in("struct_list")` and expect only row one to stay."""
    three_structs = [{"a": 1, "b": 11}, {"a": 2, "b": 12}, {"a": 3, "b": 13}]
    frame = pl.DataFrame(
        {
            "struct_elem": [{"a": 1, "b": 11}, {"a": 1, "b": 90}],
            "struct_list": [three_structs, [{"a": 3, "b": 3}]],
        }
    )

    kept = frame.filter(pl.col("struct_elem").is_in("struct_list"))

    # {"a": 1, "b": 90} does not occur in the second row's list, so it is dropped.
    assert kept.to_dict(as_series=False) == {
        "struct_elem": [{"a": 1, "b": 11}],
        "struct_list": [three_structs],
    }
160
161
162
def test_is_in_null_prop() -> None:
    """An outer null yields null; a struct holding a null field yields False."""
    null_float = pl.Series([None], dtype=pl.Float32)
    assert null_float.is_in(pl.Series([42])).item() is None

    # Same expectation for a Float32 field and for a Boolean field.
    for field_dtype in (pl.Float32, pl.Boolean):
        needles = pl.Series(
            [{"a": None}, None], dtype=pl.Struct({"a": field_dtype})
        )
        haystack = pl.Series([{"a": 42}], dtype=pl.Struct({"a": field_dtype}))
        assert needles.is_in(haystack).to_list() == [False, None]
171
172
173
def test_is_in_9070() -> None:
    """Regression check (9070): integer 1 must not be found among the float 1.99."""
    found = pl.Series([1]).is_in(pl.Series([1.99])).item()
    assert not found
175
176
177
def test_is_in_float_list_10764() -> None:
    """Regression check (10764): float column `n` tested against list column `lst`."""
    frame = pl.DataFrame(
        {
            "lst": [[1.0, 2.0, 3.0, 4.0, 5.0], [3.14, 5.28]],
            "n": [3.0, 2.0],
        }
    )
    membership = pl.col("n").is_in("lst").alias("is_in")
    observed = frame.select(membership).to_dict(as_series=False)
    assert observed == {"is_in": [True, False]}
187
188
189
def test_is_in_df() -> None:
    """`Expr.is_in` with a Python list flags the matching rows of an int column."""
    frame = pl.DataFrame({"a": [1, 2, 3]})
    flags = frame.select(pl.col("a").is_in([1, 2]))["a"]
    assert flags.to_list() == [True, True, False]
192
193
194
def test_is_in_series() -> None:
    """Cover `is_in` on Series and expressions: hits, misses, empties, bad dtype."""
    letters = pl.Series(["a", "b", "c"])

    flags = letters.is_in(["a", "b"])
    assert flags.to_list() == [True, True, False]

    # Check if empty list is converted to pl.String
    flags = letters.is_in([])
    assert flags.to_list() == [False] * flags.len()

    # Unrelated strings never match, whether given as a list or as a set.
    for unrelated in (["x", "y", "z"], {"x", "y", "z"}):
        flags = letters.is_in(unrelated)
        assert flags.to_list() == [False, False, False]

    frame = pl.DataFrame({"a": [1.0, 2.0], "b": [1, 4], "c": ["e", "d"]})
    a_in_b = frame.select(pl.col("a").is_in(pl.col("b"))).to_series()
    assert a_in_b.to_list() == [True, False]
    b_in_nothing = frame.select(pl.col("b").is_in([])).to_series()
    assert b_in_nothing.to_list() == [False] * frame.height

    # String candidates against an integer column must be rejected.
    with pytest.raises(
        InvalidOperationError,
        match=r"'is_in' cannot check for List\(String\) values in Int64 data",
    ):
        frame.select(pl.col("b").is_in(["x", "x"]))

    # check we don't shallow-copy and accidentally modify 'a' (see: #10072)
    haystack = pl.Series("a", [1, 2])
    membership = pl.Series("b", [1, 3]).is_in(haystack)

    assert haystack.name == "a"
    assert_series_equal(membership, pl.Series("b", [True, False]))
227
228
229
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_null(nulls_equal: bool) -> None:
    """`is_in` on a Null-dtype Series, without and with nulls among the candidates."""
    # No nulls in right
    all_null = pl.Series([None, None], dtype=pl.Null)
    observed = all_null.is_in([1, 2], nulls_equal=nulls_equal)
    fill = False if nulls_equal else None
    assert_series_equal(observed, pl.Series([fill, fill], dtype=pl.Boolean))

    # Nulls in right
    all_null = pl.Series([None, None], dtype=pl.Null)
    observed = all_null.is_in([None, None], nulls_equal=nulls_equal)
    fill = True if nulls_equal else None
    assert_series_equal(observed, pl.Series([fill, fill], dtype=pl.Boolean))
244
245
246
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_boolean(nulls_equal: bool) -> None:
    """Boolean `is_in` for each combination of nulls on the left and on the right."""
    # Nulls in neither left nor right
    lhs = pl.Series([True, False])
    got = lhs.is_in([True, False], nulls_equal=nulls_equal)
    assert_series_equal(got, pl.Series([True, True]))

    # Nulls in left only
    lhs = pl.Series([True, None])
    got = lhs.is_in([False, False], nulls_equal=nulls_equal)
    null_fill = False if nulls_equal else None
    assert_series_equal(got, pl.Series([False, null_fill]))

    # Nulls in right only
    lhs = pl.Series([True, False])
    got = lhs.is_in([True, None], nulls_equal=nulls_equal)
    assert_series_equal(got, pl.Series([True, False]))

    # Nulls in both
    lhs = pl.Series([True, False, None])
    got = lhs.is_in([True, None], nulls_equal=nulls_equal)
    null_fill = True if nulls_equal else None
    assert_series_equal(got, pl.Series([True, False, null_fill]))
273
274
275
@pytest.mark.parametrize("dtype", [pl.List(pl.Boolean), pl.Array(pl.Boolean, 2)])
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_boolean_list(dtype: PolarsDataType, nulls_equal: bool) -> None:
    """Boolean column `a` tested row by row against nested Boolean column `b`."""
    # Note list is_in does not propagate nulls.
    containers = pl.Series(
        [
            [True, False],
            [True, True],
            [None, True],
            [False, True],
            [True, True],
        ],
        dtype=dtype,
    )
    frame = pl.DataFrame({"a": [True, False, None, None, None], "b": containers})

    observed = frame.select(pl.col("a").is_in("b", nulls_equal=nulls_equal))["a"]

    # Under nulls_equal a null in `a` is a hit only where row `b` holds a null.
    null_hit = True if nulls_equal else None
    null_miss = False if nulls_equal else None
    wanted = pl.Series("a", [True, False, null_hit, null_miss, null_miss])
    assert_series_equal(observed, wanted)
299
300
301
def test_is_in_invalid_shape() -> None:
    """Passing two empty lists to a three-element Series must raise."""
    two_empty_lists: list[list[int]] = [[], []]
    with pytest.raises(InvalidOperationError):
        pl.Series("a", [1, 2, 3]).is_in(two_empty_lists)
304
305
306
def test_is_in_list_rhs() -> None:
    """Membership against a list Series; rows whose list is null come out null."""
    values = pl.Series([1, 2, 3, 4, 5])
    per_row_lists = pl.Series([[1], [2, 9], [None], None, None])
    wanted = pl.Series([True, True, False, None, None])
    assert_series_equal(values.is_in(per_row_lists), wanted)
311
312
313
@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_is_in_float(dtype: PolarsDataType) -> None:
    """NaN is expected to match -NaN, and 0.0 to match -0.0."""
    values = pl.Series([float("nan"), 0.0], dtype=dtype)
    negated_candidates = [-0.0, -float("nan")]

    observed = values.is_in(negated_candidates)

    assert_series_equal(observed, pl.Series([True, True], dtype=pl.Boolean))
319
320
321
@pytest.mark.parametrize(
    ("df", "matches", "expected_error"),
    [
        (
            pl.DataFrame({"a": [1, 2], "b": [[1.0, 2.5], [3.0, 4.0]]}),
            [True, False],
            None,
        ),
        (
            pl.DataFrame({"a": [2.5, 3.0], "b": [[1, 2], [3, 4]]}),
            [False, True],
            None,
        ),
        (
            pl.DataFrame(
                {"a": [None, None], "b": [[1, 2], [3, 4]]},
                schema_overrides={"a": pl.Null},
            ),
            [None, None],
            None,
        ),
        (
            pl.DataFrame({"a": ["1", "2"], "b": [[1, 2], [3, 4]]}),
            None,
            r"'is_in' cannot check for List\(Int64\) values in String data",
        ),
        (
            pl.DataFrame({"a": [date.today(), None], "b": [[1, 2], [3, 4]]}),
            None,
            r"'is_in' cannot check for List\(Int64\) values in Date data",
        ),
    ],
)
def test_is_in_expr_list_series(
    df: pl.DataFrame,
    matches: list[bool | None] | None,
    expected_error: str | None,
) -> None:
    """Check `col("a").is_in(col("b"))` where `b` is a list column.

    Each case provides either the expected ``matches`` (with ``expected_error``
    set to ``None``) or a regex that the raised ``InvalidOperationError`` must
    match (with ``matches`` set to ``None``).
    """
    expr_is_in = pl.col("a").is_in(pl.col("b"))
    # Branch on the error pattern, not on the truthiness of `matches`: an empty
    # expected list would otherwise be mistaken for an error case.
    if expected_error is None:
        assert df.select(expr_is_in).to_series().to_list() == matches
    else:
        with pytest.raises(InvalidOperationError, match=expected_error):
            df.select(expr_is_in)
363
364
365
@pytest.mark.parametrize(
    ("df", "matches"),
    [
        # Integer/null column against float lists without nulls.
        (
            pl.DataFrame({"a": [1, None], "b": [[1.0, 2.5, 4.0], [3.0, 4.0, 5.0]]}),
            [True, False],
        ),
        # Same column against float lists that each end in a null.
        (
            pl.DataFrame({"a": [1, None], "b": [[0.0, 2.5, None], [3.0, 4.0, None]]}),
            [False, True],
        ),
        # Null-dtype column against integer lists without nulls.
        (
            pl.DataFrame(
                {"a": [None, None], "b": [[1, 2], [3, 4]]},
                schema_overrides={"a": pl.Null},
            ),
            [False, False],
        ),
        # Null-dtype column where only the second list holds a null.
        (
            pl.DataFrame(
                {"a": [None, None], "b": [[1, 2], [3, None]]},
                schema_overrides={"a": pl.Null},
            ),
            [False, True],
        ),
    ],
)
def test_is_in_expr_list_series_nonullpropagate(
    df: pl.DataFrame, matches: list[bool]
) -> None:
    """With `nulls_equal=True`, every row must yield the Boolean given in `matches`."""
    membership = pl.col("a").is_in(pl.col("b"), nulls_equal=True)
    observed = df.select(membership).to_series().to_list()
    assert observed == matches
397
398
399
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_null_series(nulls_equal: bool) -> None:
    """String column vs. `[None]`: only the null row can match, and only if equal."""
    frame = pl.DataFrame({"a": ["a", "b", None]})
    only_null = pl.col("a").is_in([None], nulls_equal=nulls_equal)

    observed = frame.select(only_null)

    null_outcome = True if nulls_equal else None
    assert_frame_equal(observed, pl.DataFrame({"a": [False, False, null_outcome]}))
406
407
408
def test_is_in_int_range() -> None:
    """`is_in` works on `pl.int_range` output in both its lazy and eager forms."""
    # Lazy form: the range is evaluated through `pl.select`.
    lazy_range = pl.int_range(0, 3, eager=False)
    lazy_out = pl.select(lazy_range.is_in([1, 2])).to_series()
    assert lazy_out.to_list() == [False, True, True]

    # Eager form: `is_in` is called on the range directly. Distinct names keep
    # each variable at a single type, so no `type: ignore` is needed.
    eager_range = pl.int_range(0, 3, eager=True)
    eager_out = eager_range.is_in([1, 2])
    assert eager_out.to_list() == [False, True, True]
416
417
418
def test_is_in_date_range() -> None:
    """`is_in` works on `pl.date_range` output in both its lazy and eager forms."""
    wanted_days = [date(2023, 1, 2), date(2023, 1, 3)]

    # Lazy form: the range is evaluated through `pl.select`.
    lazy_range = pl.date_range(date(2023, 1, 1), date(2023, 1, 3), eager=False)
    lazy_out = pl.select(lazy_range.is_in(wanted_days)).to_series()
    assert lazy_out.to_list() == [False, True, True]

    # Eager form: `is_in` is called on the range directly. Distinct names keep
    # each variable at a single type, so no `type: ignore` is needed.
    eager_range = pl.date_range(date(2023, 1, 1), date(2023, 1, 3), eager=True)
    eager_out = eager_range.is_in(wanted_days)
    assert eager_out.to_list() == [False, True, True]
426
427
428
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])])
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_series(dtype: pl.DataType, nulls_equal: bool) -> None:
    """Categorical/Enum Series vs. candidates of the same dtype, then as String."""
    needles = pl.Series(["a", "b", "c", None], dtype=dtype)
    typed_candidates = pl.Series(["b", "c"], dtype=dtype)
    null_outcome = False if nulls_equal else None
    wanted = pl.Series([False, True, True, null_outcome])

    observed = needles.is_in(typed_candidates, nulls_equal=nulls_equal)
    assert_series_equal(observed, wanted)

    # The same candidates cast to String must give the same answer.
    string_candidates = typed_candidates.cast(pl.String)
    observed = needles.is_in(string_candidates, nulls_equal=nulls_equal)
    assert_series_equal(observed, wanted)
439
440
441
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_series_non_existent(nulls_equal: bool) -> None:
    """Categorical candidates may include labels ("d", "e") absent from the left."""
    cat = pl.Categorical
    needles = pl.Series(["a", "b", "c", None], dtype=cat)
    typed_candidates = pl.Series(["a", "d", "e"], dtype=cat)
    null_outcome = False if nulls_equal else None
    wanted = pl.Series([True, False, False, null_outcome])

    observed = needles.is_in(typed_candidates, nulls_equal=nulls_equal)
    assert_series_equal(observed, wanted)

    # The same candidates cast to String must give the same answer.
    string_candidates = typed_candidates.cast(pl.String)
    observed = needles.is_in(string_candidates, nulls_equal=nulls_equal)
    assert_series_equal(observed, wanted)
452
453
454
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_enum_is_in_series_non_existent(nulls_equal: bool) -> None:
    """Enum `is_in` raises for candidates "d"/"e", which are not Enum categories."""
    enum_dtype = pl.Enum(["a", "b", "c"])
    needles = pl.Series(["a", "b", "c", None], dtype=enum_dtype)
    null_outcome = False if nulls_equal else None
    string_candidates = pl.Series(["a", "d", "e"])
    wanted = pl.Series([True, False, False, null_outcome])

    # Rejected both as a String Series and as a plain Python list.
    with pytest.raises(InvalidOperationError):
        needles.is_in(string_candidates, nulls_equal=nulls_equal)
    with pytest.raises(InvalidOperationError):
        needles.is_in(["a", "d", "e"], nulls_equal=nulls_equal)

    # A candidate list holding only the valid category "a" is accepted.
    observed = needles.is_in(["a"], nulls_equal=nulls_equal)
    assert_series_equal(observed, wanted)
472
473
474
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])])
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_with_lit_str(dtype: pl.DataType, nulls_equal: bool) -> None:
    """Categorical/Enum Series tested against a plain list holding the label "b"."""
    needles = pl.Series(["a", "b", "c", None], dtype=dtype)
    candidates = ["b"]

    observed = needles.is_in(candidates, nulls_equal=nulls_equal)

    null_outcome = False if nulls_equal else None
    assert_series_equal(observed, pl.Series([False, True, False, null_outcome]))
483
484
485
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_cat_is_in_with_lit_str_non_existent(nulls_equal: bool) -> None:
    """A label absent from the Categorical Series ("d") matches no row."""
    cat = pl.Categorical()
    needles = pl.Series(["a", "b", "c", None], dtype=cat)
    candidates = ["d"]

    observed = needles.is_in(candidates, nulls_equal=nulls_equal)

    null_outcome = False if nulls_equal else None
    assert_series_equal(observed, pl.Series([False, False, False, null_outcome]))
494
495
496
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c"])])
def test_cat_is_in_with_lit_str_cache_setup(dtype: pl.DataType) -> None:
    """Each single-label Series must find its own label after a prior Series."""
    # init the global cache
    _ = pl.Series(["c", "b", "a"], dtype=dtype)

    for label in ("a", "b", "c"):
        single = pl.Series([label], dtype=dtype)
        assert_series_equal(single.is_in([label]), pl.Series([True]))
504
505
506
def test_is_in_with_wildcard_13809() -> None:
    """Regression check (13809): `pl.all().is_in(...)` keeps the column name "A"."""
    frame = pl.DataFrame({"A": ["B"]})
    observed = frame.select(pl.all().is_in(["C"]))
    assert_frame_equal(observed, pl.DataFrame({"A": [False]}))
510
511
512
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c", "d"])])
def test_cat_is_in_from_str(dtype: pl.DataType) -> None:
    """A String Series on the left is tested against Categorical/Enum candidates."""
    typed_candidates = pl.Series(["c", "c", "b"], dtype=dtype)

    # test local
    strings = pl.Series(["a", "d", "b"])
    wanted = pl.Series([False, False, True])
    assert_series_equal(strings.is_in(typed_candidates), wanted)
527
528
529
@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c", "d"])])
def test_cat_list_is_in_from_cat(dtype: pl.DataType) -> None:
    """`list.contains` with list column `li` and item column `x` of the same dtype."""
    rows = [
        (["a", "b"], "c"),
        (["a", "b"], "a"),
        (["a", None], None),
        (["a", "c"], None),
        (["a"], "d"),
    ]
    frame = pl.DataFrame(
        rows,
        schema={"li": pl.List(dtype), "x": dtype},
        orient="row",
    )

    observed = frame.select(pl.col("li").list.contains(pl.col("x")))

    wanted = pl.DataFrame({"li": [False, True, True, False, False]})
    assert_frame_equal(observed, wanted)
545
546
547
@pytest.mark.parametrize(
    ("val", "expected"),
    [
        ("b", [True, False, False, None, True]),
        (None, [False, False, True, None, False]),
        ("e", [False, False, False, None, False]),
    ],
)
def test_cat_list_is_in_from_cat_single(val: str | None, expected: list[bool]) -> None:
    """`list.contains` on a Categorical list column with one Categorical literal."""
    lists = pl.Series(
        "li",
        [["a", "b"], ["a", "c"], ["a", None], None, ["b"]],
        dtype=pl.List(pl.Categorical),
    )
    needle = pl.lit(val, dtype=pl.Categorical)

    observed = lists.to_frame().select(pl.col("li").list.contains(needle))

    assert_frame_equal(observed, pl.DataFrame({"li": expected}))
564
565
566
def test_cat_list_is_in_from_str() -> None:
    """`list.contains` on a Categorical list column with a String item column."""
    rows = [
        (["a", "b"], "c"),
        (["a", "b"], "a"),
        (["a", None], None),
        (["a", "c"], None),
        (["a"], "d"),
    ]
    frame = pl.DataFrame(
        rows,
        schema={"li": pl.List(pl.Categorical), "x": pl.String},
        orient="row",
    )

    observed = frame.select(pl.col("li").list.contains(pl.col("x")))

    wanted = pl.DataFrame({"li": [False, True, True, False, False]})
    assert_frame_equal(observed, wanted)
581
582
583
@pytest.mark.parametrize(
    ("val", "expected"),
    [
        ("b", [True, False, False, None, True]),
        (None, [False, False, True, None, False]),
        ("e", [False, False, False, None, False]),
    ],
)
def test_cat_list_is_in_from_single_str(val: str | None, expected: list[bool]) -> None:
    """`list.contains` on a Categorical list column with one String literal."""
    lists = pl.Series(
        "li",
        [["a", "b"], ["a", "c"], ["a", None], None, ["b"]],
        dtype=pl.List(pl.Categorical),
    )
    needle = pl.lit(val, dtype=pl.String)

    observed = lists.to_frame().select(pl.col("li").list.contains(needle))

    assert_frame_equal(observed, pl.DataFrame({"li": expected}))
600
601
602
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_struct_enum_17618(nulls_equal: bool) -> None:
    """Regression check (17618): struct-of-Enum `is_in` on a frame with no rows."""
    frame = pl.DataFrame()
    enum_dtype = pl.Enum(categories=["HBS"])
    frame = frame.insert_column(0, pl.Series("category", [], dtype=enum_dtype))

    candidates = pl.Series(
        [{"category": "HBS"}],
        dtype=pl.Struct({"category": frame["category"].dtype}),
    )
    predicate = pl.struct("category").is_in(candidates, nulls_equal=nulls_equal)

    # No rows go in, so none come out; the single column is preserved.
    assert frame.filter(predicate).shape == (0, 1)
616
617
618
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_is_in_decimal(nulls_equal: bool) -> None:
    """Decimal column membership against float, Decimal and integer candidates."""
    complete = pl.DataFrame({"a": [D("0.0"), D("0.2"), D("0.1")]})

    in_floats = pl.col("a").is_in([0.0, 0.1], nulls_equal=nulls_equal)
    assert complete.select(in_floats)["a"].to_list() == [True, False, True]

    in_decimals = pl.col("a").is_in([D("0.0"), D("0.1")], nulls_equal=nulls_equal)
    assert complete.select(in_decimals)["a"].to_list() == [True, False, True]

    in_ints = pl.col("a").is_in([1, 0, 2], nulls_equal=nulls_equal)
    assert complete.select(in_ints)["a"].to_list() == [True, False, False]

    # Null on the left, with and then without a null among the candidates.
    with_null_row = pl.DataFrame({"a": [D("0.0"), D("0.2"), None]})

    null_outcome = True if nulls_equal else None
    in_floats_or_null = pl.col("a").is_in([0.0, 0.1, None], nulls_equal=nulls_equal)
    observed = with_null_row.select(in_floats_or_null)["a"].to_list()
    assert observed == [True, False, null_outcome]

    null_outcome = False if nulls_equal else None
    observed = with_null_row.select(in_floats)["a"].to_list()
    assert observed == [True, False, null_outcome]
637
638
639
def test_is_in_collection() -> None:
    """`is_in` accepts a set, a frozenset and a user-defined `Collection` subclass."""

    class CustomCollection(Collection[int]):
        """Minimal `Collection` that forwards each protocol method to `vals`."""

        def __init__(self, vals: Collection[int]) -> None:
            super().__init__()
            self.vals = vals

        def __len__(self) -> int:
            return len(self.vals)

        def __iter__(self) -> Iterator[int]:
            yield from self.vals

        def __contains__(self, x: object) -> bool:
            return x in self.vals

    frame = pl.DataFrame(
        {
            "lbl": ["aa", "bb", "cc", "dd", "ee"],
            "val": [0, 1, 2, 3, 4],
        }
    )

    accepted = ({3, 2, 1}, frozenset({3, 2, 1}), CustomCollection([3, 2, 1]))
    for constraint_values in accepted:
        kept = frame.filter(pl.col("val").is_in(constraint_values))
        # Values 1, 2 and 3 belong to labels "bb", "cc" and "dd".
        assert set(kept["lbl"]) == {"bb", "cc", "dd"}
668
669
670
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_null_propagate_all_paths(nulls_equal: bool) -> None:
    """Integer `is_in` for each combination of nulls on the left and on the right."""
    # No nulls in either
    lhs = pl.Series([1, 2, 3])
    got = lhs.is_in([1, 3, 8], nulls_equal=nulls_equal)
    assert_series_equal(got, pl.Series([True, False, True]))

    # Nulls in left only
    lhs = pl.Series([1, 2, None])
    got = lhs.is_in([1, 3, 8], nulls_equal=nulls_equal)
    null_fill = False if nulls_equal else None
    assert_series_equal(got, pl.Series([True, False, null_fill]))

    # Nulls in right only
    lhs = pl.Series([1, 2, 3])
    got = lhs.is_in([1, 3, None], nulls_equal=nulls_equal)
    assert_series_equal(got, pl.Series([True, False, True]))

    # Nulls in both
    lhs = pl.Series([1, 2, None])
    got = lhs.is_in([1, 3, None], nulls_equal=nulls_equal)
    null_fill = True if nulls_equal else None
    assert_series_equal(got, pl.Series([True, False, null_fill]))
697
698
699
@pytest.mark.parametrize("nulls_equal", [False, True])
def test_null_propagate_all_paths_cat(nulls_equal: bool) -> None:
    """String-valued `is_in` for each combination of nulls on the left and right."""
    # NOTE(review): despite the `_cat` suffix no dtype is passed below, so these
    # Series are built from plain Python strings - confirm this is intended.

    # No nulls in either
    lhs = pl.Series(["1", "2", "3"])
    got = lhs.is_in(["1", "3", "8"], nulls_equal=nulls_equal)
    assert_series_equal(got, pl.Series([True, False, True]))

    # Nulls in left only
    lhs = pl.Series(["1", "2", None])
    got = lhs.is_in(["1", "3", "8"], nulls_equal=nulls_equal)
    null_fill = False if nulls_equal else None
    assert_series_equal(got, pl.Series([True, False, null_fill]))

    # Nulls in right only
    lhs = pl.Series(["1", "2", "3"])
    got = lhs.is_in(["1", "3", None], nulls_equal=nulls_equal)
    assert_series_equal(got, pl.Series([True, False, True]))

    # Nulls in both
    lhs = pl.Series(["1", "2", None])
    got = lhs.is_in(["1", "3", None], nulls_equal=nulls_equal)
    null_fill = True if nulls_equal else None
    assert_series_equal(got, pl.Series([True, False, null_fill]))
726
727