Path: blob/main/py-polars/tests/unit/operations/namespaces/array/test_contains.py
6940 views
from __future__ import annotations

from typing import Any

import pytest

import polars as pl
from polars.exceptions import SchemaError
from polars.testing import assert_series_equal


@pytest.mark.parametrize(
    ("array", "data", "expected", "dtype"),
    [
        ([[1, 2], [3, 4]], [1, 5], [True, False], pl.Int64),
        ([[True, False], [True, True]], [True, False], [True, False], pl.Boolean),
        ([["a", "b"], ["c", "d"]], ["a", "b"], [True, False], pl.String),
        ([[b"a", b"b"], [b"c", b"d"]], [b"a", b"b"], [True, False], pl.Binary),
        (
            # The second row's {"b": 1} does not match the declared struct
            # schema (only field "a"), so it cannot equal {"a": 2}.
            [[{"a": 1}, {"a": 2}], [{"b": 1}, {"a": 3}]],
            [{"a": 1}, {"a": 2}],
            [True, False],
            pl.Struct([pl.Field("a", pl.Int64)]),
        ),
    ],
)
def test_array_contains_expr(
    array: list[list[Any]], data: list[Any], expected: list[bool], dtype: pl.DataType
) -> None:
    """Check `arr.contains` when the needle is a column expression.

    Each row of the fixed-width (width 2) ``array`` column is tested for
    membership of the value in the same row of the ``data`` column.

    Parameters
    ----------
    array
        Row values for the ``Array(dtype, 2)`` column.
    data
        One needle per row, stored in a column of type ``dtype``.
    expected
        Per-row membership results.
    dtype
        Inner dtype of the array column and dtype of the needle column.
    """
    df = pl.DataFrame(
        {
            "array": array,
            "data": data,
        },
        schema={
            "array": pl.Array(dtype, 2),
            "data": dtype,
        },
    )
    out = df.select(contains=pl.col("array").arr.contains(pl.col("data"))).to_series()
    expected_series = pl.Series("contains", expected)
    assert_series_equal(out, expected_series)


@pytest.mark.parametrize(
    ("array", "data", "expected", "dtype"),
    [
        ([[1, 2], [3, 4]], 1, [True, False], pl.Int64),
        ([[True, False], [True, True]], True, [True, True], pl.Boolean),
        ([["a", "b"], ["c", "d"]], "a", [True, False], pl.String),
        ([[b"a", b"b"], [b"c", b"d"]], b"a", [True, False], pl.Binary),
    ],
)
def test_array_contains_literal(
    array: list[list[Any]], data: Any, expected: list[bool], dtype: pl.DataType
) -> None:
    """Check `arr.contains` when the needle is a single Python scalar.

    The same scalar ``data`` is tested against every row of the
    ``Array(dtype, 2)`` column, producing one boolean per row.
    """
    df = pl.DataFrame(
        {
            "array": array,
        },
        schema={
            "array": pl.Array(dtype, 2),
        },
    )
    out = df.select(contains=pl.col("array").arr.contains(data)).to_series()
    expected_series = pl.Series("contains", expected)
    assert_series_equal(out, expected_series)


def test_array_contains_invalid_datatype() -> None:
    """Using the `arr` namespace on a `List` column raises `SchemaError`."""
    df = pl.DataFrame({"a": [[1, 2], [3, 4]]}, schema={"a": pl.List(pl.Int8)})
    # `match` is a regex searched within the message; none of the characters
    # in this pattern are regex metacharacters.
    with pytest.raises(SchemaError, match="invalid series dtype: expected `Array`"):
        df.select(pl.col("a").arr.contains(2))