Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/namespaces/array/test_contains.py
6940 views
1
from __future__ import annotations
2
3
from typing import Any
4
5
import pytest
6
7
import polars as pl
8
from polars.exceptions import SchemaError
9
from polars.testing import assert_series_equal
10
11
12
@pytest.mark.parametrize(
    ("array", "data", "expected", "dtype"),
    [
        ([[1, 2], [3, 4]], [1, 5], [True, False], pl.Int64),
        ([[True, False], [True, True]], [True, False], [True, False], pl.Boolean),
        ([["a", "b"], ["c", "d"]], ["a", "b"], [True, False], pl.String),
        ([[b"a", b"b"], [b"c", b"d"]], [b"a", b"b"], [True, False], pl.Binary),
        (
            [[{"a": 1}, {"a": 2}], [{"b": 1}, {"a": 3}]],
            [{"a": 1}, {"a": 2}],
            [True, False],
            pl.Struct([pl.Field("a", pl.Int64)]),
        ),
    ],
)
def test_array_contains_expr(
    array: list[list[Any]], data: list[Any], expected: list[bool], dtype: pl.DataType
) -> None:
    """Check `arr.contains` when the searched-for value is another column.

    Each case builds a frame with a width-2 `Array` column of `dtype` next to
    a plain `dtype` column, evaluates `arr.contains` with the second column as
    the item expression, and compares the result with `expected`.
    """
    schema = {"array": pl.Array(dtype, 2), "data": dtype}
    frame = pl.DataFrame({"array": array, "data": data}, schema=schema)

    # The item to look for is supplied as a column expression, not a literal.
    item = pl.col("data")
    result = frame.select(contains=pl.col("array").arr.contains(item)).to_series()

    assert_series_equal(result, pl.Series("contains", expected))
43
44
45
@pytest.mark.parametrize(
    ("array", "data", "expected", "dtype"),
    [
        ([[1, 2], [3, 4]], 1, [True, False], pl.Int64),
        ([[True, False], [True, True]], True, [True, True], pl.Boolean),
        ([["a", "b"], ["c", "d"]], "a", [True, False], pl.String),
        ([[b"a", b"b"], [b"c", b"d"]], b"a", [True, False], pl.Binary),
    ],
)
def test_array_contains_literal(
    array: list[list[Any]], data: Any, expected: list[bool], dtype: pl.DataType
) -> None:
    """Check `arr.contains` when the searched-for value is a Python literal.

    Each case builds a single-column frame holding a width-2 `Array` of
    `dtype`, passes the scalar `data` directly to `arr.contains`, and compares
    the result with `expected`.
    """
    schema = {"array": pl.Array(dtype, 2)}
    frame = pl.DataFrame({"array": array}, schema=schema)

    # `data` is handed over as-is; no column expression is involved here.
    result = frame.select(contains=pl.col("array").arr.contains(data)).to_series()

    assert_series_equal(result, pl.Series("contains", expected))
68
69
70
def test_array_contains_invalid_datatype() -> None:
    """Expect a `SchemaError` when `arr.contains` is applied to a `List` column."""
    list_frame = pl.DataFrame(
        {"a": [[1, 2], [3, 4]]},
        schema={"a": pl.List(pl.Int8)},
    )
    contains_two = pl.col("a").arr.contains(2)

    # Only the evaluation against the frame sits inside the raises-context.
    with pytest.raises(SchemaError, match="invalid series dtype: expected `Array`"):
        list_frame.select(contains_two)
74
75