Path: blob/main/py-polars/tests/unit/operations/namespaces/test_binary.py
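"""Tests for the binary (`.bin`) namespace on polars expressions and Series."""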
from __future__ import annotations

import random
import struct
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest
from hypothesis import given
from hypothesis import strategies as st

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import PolarsDataType, SizeUnit, TransferEncoding


def test_binary_conversions() -> None:
    df = pl.DataFrame({"blob": [b"abc", None, b"cde"]}).with_columns(
        pl.col("blob").cast(pl.String).alias("decoded_blob")
    )

    assert df.to_dict(as_series=False) == {
        "blob": [b"abc", None, b"cde"],
        "decoded_blob": ["abc", None, "cde"],
    }
    assert df[0, 0] == b"abc"
    assert df[1, 0] is None
    assert df.dtypes == [pl.Binary, pl.String]


def test_contains() -> None:
    df = pl.DataFrame(
        data=[
            (1, b"some * * text"),
            (2, b"(with) special\n * chars"),
            (3, b"**etc...?$"),
            (4, None),
        ],
        schema=["idx", "bin"],
        orient="row",
    )
    for pattern, expected in (
        (b"e * ", [True, False, False, None]),
        (b"text", [True, False, False, None]),
        (b"special", [False, True, False, None]),
        (b"", [True, True, True, None]),
        (b"qwe", [False, False, False, None]),
    ):
        # series
        assert expected == df["bin"].bin.contains(pattern).to_list()
        # frame select
        assert (
            expected == df.select(pl.col("bin").bin.contains(pattern))["bin"].to_list()
        )
        # frame filter
        assert sum(e for e in expected if e is True) == len(
            df.filter(pl.col("bin").bin.contains(pattern))
        )


def test_contains_with_expr() -> None:
    df = pl.DataFrame(
        {
            "bin": [b"some * * text", b"(with) special\n * chars", b"**etc...?$", None],
            "lit1": [b"e * ", b"", b"qwe", b"None"],
            "lit2": [None, b"special\n", b"?!", None],
        }
    )

    assert df.select(
        pl.col("bin").bin.contains(pl.col("lit1")).alias("contains_1"),
        pl.col("bin").bin.contains(pl.col("lit2")).alias("contains_2"),
        pl.col("bin").bin.contains(pl.lit(None)).alias("contains_3"),
    ).to_dict(as_series=False) == {
        "contains_1": [True, True, False, None],
        "contains_2": [None, True, False, None],
        "contains_3": [None, None, None, None],
    }


def test_starts_ends_with() -> None:
    assert pl.DataFrame(
        {
            "a": [b"hamburger", b"nuts", b"lollypop", None],
            "end": [b"ger", b"tg", None, b"anything"],
            "start": [b"ha", b"nga", None, b"anything"],
        }
    ).select(
        pl.col("a").bin.ends_with(b"pop").alias("end_lit"),
        pl.col("a").bin.ends_with(pl.lit(None)).alias("end_none"),
        pl.col("a").bin.ends_with(pl.col("end")).alias("end_expr"),
        pl.col("a").bin.starts_with(b"ham").alias("start_lit"),
        pl.col("a").bin.starts_with(pl.lit(None)).alias("start_none"),
        pl.col("a").bin.starts_with(pl.col("start")).alias("start_expr"),
    ).to_dict(as_series=False) == {
        "end_lit": [False, False, True, None],
        "end_none": [None, None, None, None],
        "end_expr": [True, False, None, None],
        "start_lit": [True, False, False, None],
        "start_none": [None, None, None, None],
        "start_expr": [True, False, None, None],
    }


def test_base64_encode() -> None:
    df = pl.DataFrame({"data": [b"asd", b"qwe"]})

    assert df["data"].bin.encode("base64").to_list() == ["YXNk", "cXdl"]


def test_base64_decode() -> None:
    df = pl.DataFrame({"data": [b"YXNk", b"cXdl"]})

    assert df["data"].bin.decode("base64").to_list() == [b"asd", b"qwe"]


def test_hex_encode() -> None:
    df = pl.DataFrame({"data": [b"asd", b"qwe"]})

    assert df["data"].bin.encode("hex").to_list() == ["617364", "717765"]
["617364", "717765"]124125126def test_hex_decode() -> None:127df = pl.DataFrame({"data": [b"617364", b"717765"]})128129assert df["data"].bin.decode("hex").to_list() == [b"asd", b"qwe"]130131132@pytest.mark.parametrize(133"encoding",134["hex", "base64"],135)136def test_compare_encode_between_lazy_and_eager_6814(encoding: TransferEncoding) -> None:137df = pl.DataFrame({"x": [b"aa", b"bb", b"cc"]})138expr = pl.col("x").bin.encode(encoding)139140result_eager = df.select(expr)141dtype = result_eager["x"].dtype142143result_lazy = df.lazy().select(expr).select(pl.col(dtype)).collect()144assert_frame_equal(result_eager, result_lazy)145146147@pytest.mark.parametrize(148"encoding",149["hex", "base64"],150)151def test_compare_decode_between_lazy_and_eager_6814(encoding: TransferEncoding) -> None:152df = pl.DataFrame({"x": [b"d3d3", b"abcd", b"1234"]})153expr = pl.col("x").bin.decode(encoding)154155result_eager = df.select(expr)156dtype = result_eager["x"].dtype157158result_lazy = df.lazy().select(expr).select(pl.col(dtype)).collect()159assert_frame_equal(result_eager, result_lazy)160161162@pytest.mark.parametrize(163("sz", "unit", "expected"),164[(128, "b", 128), (512, "kb", 0.5), (131072, "mb", 0.125)],165)166def test_binary_size(sz: int, unit: SizeUnit, expected: int | float) -> None:167df = pl.DataFrame({"data": [b"\x00" * sz]}, schema={"data": pl.Binary})168for sz in (169df.select(sz=pl.col("data").bin.size(unit)).item(), # expr170df["data"].bin.size(unit).item(), # series171):172assert sz == expected173174175@pytest.mark.parametrize(176("dtype", "type_size", "struct_type"),177[178(pl.Int8, 1, "b"),179(pl.UInt8, 1, "B"),180(pl.Int16, 2, "h"),181(pl.UInt16, 2, "H"),182(pl.Int32, 4, "i"),183(pl.UInt32, 4, "I"),184(pl.Int64, 8, "q"),185(pl.UInt64, 8, "Q"),186(pl.Float32, 4, "f"),187(pl.Float64, 8, "d"),188],189)190def test_reinterpret(191dtype: pl.DataType,192type_size: int,193struct_type: str,194) -> None:195# Make test reproducible196random.seed(42)197198byte_arr = [random.randbytes(type_size) for _ in range(3)]199df = pl.DataFrame({"x": byte_arr})200201for endianness in ["little", "big"]:202# So that mypy doesn't complain203struct_endianness = "<" if endianness == "little" else ">"204expected = [205struct.unpack_from(f"{struct_endianness}{struct_type}", elem_bytes)[0]206for elem_bytes in byte_arr207]208expected_df = pl.DataFrame({"x": expected}, schema={"x": dtype})209210result = df.select(211pl.col("x").bin.reinterpret(dtype=dtype, endianness=endianness) # type: ignore[arg-type]212)213214assert_frame_equal(result, expected_df)215216217@pytest.mark.parametrize(218("dtype", "inner_type_size", "struct_type"),219[220(pl.Array(pl.Int8, 3), 1, "b"),221(pl.Array(pl.UInt8, 3), 1, "B"),222(pl.Array(pl.Int16, 3), 2, "h"),223(pl.Array(pl.UInt16, 3), 2, "H"),224(pl.Array(pl.Int32, 3), 4, "i"),225(pl.Array(pl.UInt32, 3), 4, "I"),226(pl.Array(pl.Int64, 3), 8, "q"),227(pl.Array(pl.UInt64, 3), 8, "Q"),228(pl.Array(pl.Float32, 3), 4, "f"),229(pl.Array(pl.Float64, 3), 8, "d"),230],231)232def test_reinterpret_to_array_numeric_types(233dtype: pl.Array,234inner_type_size: int,235struct_type: str,236) -> None:237# Make test reproducible238random.seed(42)239240type_size = inner_type_size241shape = dtype.shape242if isinstance(shape, int):243shape = (shape,)244for dim_size in dtype.shape:245type_size *= dim_size246247byte_arr = [random.randbytes(type_size) for _ in range(3)]248df = pl.DataFrame({"x": byte_arr}, orient="row")249250for endianness in ["little", "big"]:251result = 
df.select(252pl.col("x").bin.reinterpret(dtype=dtype, endianness=endianness) # type: ignore[arg-type]253)254255# So that mypy doesn't complain256struct_endianness = "<" if endianness == "little" else ">"257expected = []258for elem_bytes in byte_arr:259vals = [260struct.unpack_from(261f"{struct_endianness}{struct_type}",262elem_bytes[idx : idx + inner_type_size],263)[0]264for idx in range(0, type_size, inner_type_size)265]266if len(shape) > 1:267vals = np.reshape(vals, shape).tolist()268expected.append(vals)269expected_df = pl.DataFrame({"x": expected}, schema={"x": dtype})270271assert_frame_equal(result, expected_df)272273274@pytest.mark.parametrize(275("dtype", "binary_value", "expected_values"),276[277(pl.Date(), b"\x06\x00\x00\x00", [date(1970, 1, 7)]),278(279pl.Datetime(),280b"\x40\xb6\xfd\xe3\x7c\x00\x00\x00",281[datetime(1970, 1, 7, 5, 0, 1)],282),283(284pl.Duration(),285b"\x03\x00\x00\x00\x00\x00\x00\x00",286[timedelta(microseconds=3)],287),288(289pl.Time(),290b"\x58\x1b\x00\x00\x00\x00\x00\x00",291[time(microsecond=7)],292),293(294pl.Int128(),295b"\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",296[6],297),298(299pl.UInt128(),300b"\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",301[6],302),303],304)305def test_reinterpret_to_additional_types(306dtype: PolarsDataType, binary_value: bytes, expected_values: list[object]307) -> None:308series = pl.Series([binary_value])309310# Direct conversion:311result = series.bin.reinterpret(dtype=dtype, endianness="little")312assert_series_equal(result, pl.Series(expected_values, dtype=dtype))313314# Array conversion:315dtype = pl.Array(dtype, 1)316result = series.bin.reinterpret(dtype=dtype, endianness="little")317assert_series_equal(result, pl.Series([expected_values], dtype=dtype))318319320def test_reinterpret_to_array_resulting_in_nulls() -> None:321series = pl.Series([None, b"short", b"justrite", None, b"waytoolong"])322as_bin = series.bin.reinterpret(dtype=pl.Array(pl.UInt32(), 2), endianness="little")323assert as_bin.to_list() == [None, None, [0x7473756A, 0x65746972], None, None]324as_bin = series.bin.reinterpret(dtype=pl.Array(pl.UInt32(), 2), endianness="big")325assert as_bin.to_list() == [None, None, [0x6A757374, 0x72697465], None, None]326327328def test_reinterpret_to_n_dimensional_array() -> None:329series = pl.Series([b"abcd"])330for endianness in ["big", "little"]:331with pytest.raises(332InvalidOperationError,333match="reinterpret to a linear Array, and then use reshape",334):335series.bin.reinterpret(336dtype=pl.Array(pl.UInt32(), (2, 2)),337endianness=endianness, # type: ignore[arg-type]338)339340341def test_reinterpret_to_zero_length_array() -> None:342arr_dtype = pl.Array(pl.UInt8, 0)343result = pl.Series([b"", b""]).bin.reinterpret(dtype=arr_dtype)344assert_series_equal(result, pl.Series([[], []], dtype=arr_dtype))345346347@given(348value1=st.integers(0, 2**63),349value2=st.binary(min_size=0, max_size=7),350value3=st.integers(0, 2**63),351)352def test_reinterpret_to_array_different_alignment(353value1: int, value2: bytes, value3: int354) -> None:355series = pl.Series([struct.pack("<Q", value1), value2, struct.pack("<Q", value3)])356arr_dtype = pl.Array(pl.UInt64, 1)357as_uint64 = series.bin.reinterpret(dtype=arr_dtype, endianness="little")358assert_series_equal(359pl.Series([[value1], None, [value3]], dtype=arr_dtype), as_uint64360)361362363@pytest.mark.parametrize(364"bad_dtype",365[366pl.Array(pl.Array(pl.UInt8, 1), 1),367pl.String(),368pl.Array(pl.List(pl.UInt8()), 
@pytest.mark.parametrize(
    "bad_dtype",
    [
        pl.Array(pl.Array(pl.UInt8, 1), 1),
        pl.String(),
        pl.Array(pl.List(pl.UInt8()), 1),
        pl.Array(pl.Null(), 1),
        pl.Array(pl.Boolean(), 1),
    ],
)
def test_reinterpret_unsupported(bad_dtype: pl.DataType) -> None:
    series = pl.Series([b"12345678"])
    lazy_df = pl.DataFrame({"s": series}).lazy()
    expected = "cannot reinterpret binary to dtype.*Only numeric or temporal dtype.*"
    for endianness in ["little", "big"]:
        with pytest.raises(InvalidOperationError, match=expected):
            series.bin.reinterpret(dtype=bad_dtype, endianness=endianness)  # type: ignore[arg-type]
        with pytest.raises(InvalidOperationError, match=expected):
            lazy_df.select(
                pl.col("s").bin.reinterpret(dtype=bad_dtype, endianness=endianness)  # type: ignore[arg-type]
            ).collect_schema()


@pytest.mark.parametrize(
    ("dtype", "type_size"),
    [
        (pl.Int128, 16),
    ],
)
def test_reinterpret_int(
    dtype: pl.DataType,
    type_size: int,
) -> None:
    # Function used for testing integers that `struct` or `numpy`
    # doesn't support parsing from bytes.
    # Rather than creating bytes directly, create integer and view it as bytes
    is_signed = dtype.is_signed_integer()

    if is_signed:
        min_val = -(2 ** (type_size - 1))
        max_val = 2 ** (type_size - 1) - 1
    else:
        min_val = 0
        max_val = 2**type_size - 1

    # Make test reproducible
    random.seed(42)

    expected = [random.randint(min_val, max_val) for _ in range(3)]
    expected_df = pl.DataFrame({"x": expected}, schema={"x": dtype})

    for endianness in ["little", "big"]:
        byte_arr = [
            val.to_bytes(type_size, byteorder=endianness, signed=is_signed)  # type: ignore[arg-type]
            for val in expected
        ]
        df = pl.DataFrame({"x": byte_arr})

        result = df.select(
            pl.col("x").bin.reinterpret(dtype=dtype, endianness=endianness)  # type: ignore[arg-type]
        )

        assert_frame_equal(result, expected_df)


def test_reinterpret_invalid() -> None:
    # Fails because buffer has more than 4 bytes
    df = pl.DataFrame({"x": [b"d3d3a"]})
    print(struct.unpack_from("<i", b"d3d3a"))
    assert_frame_equal(
        df.select(pl.col("x").bin.reinterpret(dtype=pl.Int32)),
        pl.DataFrame({"x": [None]}, schema={"x": pl.Int32}),
    )

    # Fails because buffer has less than 4 bytes
    df = pl.DataFrame({"x": [b"d3"]})
    print(df.select(pl.col("x").bin.reinterpret(dtype=pl.Int32)))
    assert_frame_equal(
        df.select(pl.col("x").bin.reinterpret(dtype=pl.Int32)),
        pl.DataFrame({"x": [None]}, schema={"x": pl.Int32}),
    )

    # Fails because dtype is invalid
    with pytest.raises(pl.exceptions.InvalidOperationError):
        df.select(pl.col("x").bin.reinterpret(dtype=pl.String))


@pytest.mark.parametrize("func", ["contains", "starts_with", "ends_with"])
def test_bin_contains_unequal_lengths_22018(func: str) -> None:
    s = pl.Series("a", [b"a", b"xyz"], pl.Binary).bin
    f = getattr(s, func)
    with pytest.raises(pl.exceptions.ShapeError):
        f(pl.Series([b"x", b"y", b"z"]))


def test_binary_compounded_literal_aggstate_24460() -> None:
    df = pl.DataFrame({"g": [10], "n": [1]})
    out = df.group_by("g").agg(
        (pl.lit(1, pl.Int64) + pl.lit(2)).pow(pl.lit(3)).alias("z")
    )
    expected = pl.DataFrame({"g": [10], "z": [27]})
    assert_frame_equal(out, expected)


# parametric tuples: (expr, is_scalar, values with broadcast)
agg_expressions = [
    (pl.lit(7, pl.Int64), True, [7, 7, 7]),  # LiteralScalar
    (pl.col("n"), False, [2, 1, 3]),  # NotAggregated
    (pl.int_range(pl.len()), False, [0, 1, 0]),  # AggregatedList
    (pl.col("n").first(), True, [2, 2, 3]),  # AggregatedScalar
]
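

# Add every pairwise combination of aggregation states (lhs/rhs) inside a
# group_by().agg() binary expression and compare against per-group evaluation.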
@pytest.mark.parametrize("lhs", agg_expressions)
@pytest.mark.parametrize("rhs", agg_expressions)
@pytest.mark.parametrize("n_rows", [0, 1, 2, 3])
@pytest.mark.parametrize("maintain_order", [True, False])
def test_add_aggstates_in_binary_expr_24504(
    lhs: tuple[pl.Expr, bool, list[int]],
    rhs: tuple[pl.Expr, bool, list[int]],
    n_rows: int,
    maintain_order: bool,
) -> None:
    df = pl.DataFrame({"g": [10, 10, 20], "n": [2, 1, 3]})
    lf = df.head(n_rows).lazy()
    expr = pl.Expr.add(lhs[0].alias("lhs"), rhs[0].alias("rhs")).alias("expr")
    q = lf.group_by("g", maintain_order=maintain_order).agg(expr)
    out = q.collect()

    # check schema
    assert q.collect_schema() == out.schema

    # check output against ground truth
    if n_rows in [1, 2, 3]:
        data = df.to_dict(as_series=False)
        result: dict[int, Any] = {}
        for gg, ll, rr in zip(
            data["g"][:n_rows], lhs[2][:n_rows], rhs[2][:n_rows], strict=True
        ):
            result.setdefault(gg, []).append(ll + rr)
        if lhs[1] and rhs[1]:
            # expect scalar result
            result = {k: v[0] for k, v in result.items()}
        expected = pl.DataFrame(
            {"g": list(result.keys()), "expr": list(result.values())}
        )
        assert_frame_equal(out, expected, check_row_order=maintain_order)

    # check output against non_aggregated expression evaluation
    if n_rows in [1, 2, 3]:
        print(f"df\n{df}")
        grouped = df.head(n_rows).group_by("g", maintain_order=maintain_order)
        out_non_agg = pl.DataFrame({})
        for df_group in grouped:
            df = df_group[1]
            print(f"df pre expr:\n{df}", flush=True)
            if lhs[1] and rhs[1]:
                df = df.head(1)
                df = df.select(["g", expr])
            else:
                df = df.select(["g", expr.implode()]).head(1)
            print(f"df post expr:{df}\n")
            out_non_agg = out_non_agg.vstack(df)
        print(f"out_non_agg:\n{out_non_agg}")

        assert_frame_equal(out, out_non_agg, check_row_order=maintain_order)


# parametric tuples: (expr, is_scalar)
agg_expressions_sort = [
    (pl.lit(7, pl.Int64), True),  # LiteralScalar
    (pl.col("n"), False),  # NotAggregated
    (pl.col("n").sort(), False),  # NotAggregated w groups modified
    (pl.int_range(pl.len()), False),  # AggregatedList
    (pl.int_range(pl.len()).reverse(), False),  # AggregatedList w groups modified
    (pl.col("n").first(), True),  # AggregatedScalar
]


@pytest.mark.parametrize("lhs", agg_expressions_sort)
@pytest.mark.parametrize("rhs", agg_expressions_sort)
@pytest.mark.parametrize("maintain_order", [True, False])
def test_add_aggstates_with_sort_in_binary_expr_24504(
    lhs: tuple[pl.Expr, bool],
    rhs: tuple[pl.Expr, bool],
    maintain_order: bool,
) -> None:
    df = pl.DataFrame({"g": [10, 10, 20], "n": [2, 1, 3]})
    lf = df.lazy()
    expr = pl.Expr.add(lhs[0].alias("lhs"), rhs[0].alias("rhs")).alias("expr")
    q = lf.group_by("g", maintain_order=maintain_order).agg(expr)
    out = q.collect()

    # check schema
    assert q.collect_schema() == out.schema

    # check output against non_aggregated expression evaluation
    grouped = df.group_by("g", maintain_order=maintain_order)
    out_non_agg = pl.DataFrame({})
    for df_group in grouped:
        df = df_group[1]
        if lhs[1] and rhs[1]:
            df = df.head(1)
            df = df.select(["g", expr])
        else:
            df = df.select(["g", expr.implode()]).head(1)
        out_non_agg = out_non_agg.vstack(df)

    assert_frame_equal(out, out_non_agg, check_row_order=maintain_order)


@pytest.mark.parametrize("maintain_order", [True, False])
def test_binary_context_nested(maintain_order: bool) -> None:
    df = pl.DataFrame({"groups": [1, 1, 2, 2, 3, 3], "vals": [1, 13, 3, 87, 1, 6]})
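    # Each group aggregates the when/then/otherwise result into a list column.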
    out = (
        df.lazy()
        .group_by(pl.col("groups"), maintain_order=maintain_order)
        .agg(
            [
                pl.when(pl.col("vals").eq(pl.lit(1)))
                .then(pl.col("vals").sum())
                .otherwise(pl.lit(90))
                .alias("vals")
            ]
        )
    ).collect()
    expected = pl.DataFrame(
        {"groups": [1, 2, 3], "vals": [[14, 90], [90, 90], [7, 90]]}
    )
    assert_frame_equal(out, expected, check_row_order=maintain_order)


def test_get() -> None:
    # N binary, scalar index (N to 1).
    df = pl.DataFrame({"a": [b"\x01\x02\x03", b"", b"\x04\x05"]})
    result = df.select(pl.col("a").bin.get(0, null_on_oob=True))
    expected = pl.DataFrame({"a": [1, None, 4]}, schema={"a": pl.UInt8})
    assert_frame_equal(result, expected)

    # Negative index.
    result = df.select(pl.col("a").bin.get(-1, null_on_oob=True))
    expected = pl.DataFrame({"a": [3, None, 5]}, schema={"a": pl.UInt8})
    assert_frame_equal(result, expected)

    # Null index.
    result = df.select(
        pl.col("a").bin.get(pl.lit(None, dtype=pl.Int64), null_on_oob=True)
    )
    expected = pl.DataFrame({"a": [None, None, None]}, schema={"a": pl.UInt8})
    assert_frame_equal(result, expected)

    # N binary, N indices (N to N).
    df = pl.DataFrame(
        {
            "a": [b"\x01\x02\x03", b"\x04\x05", b"\x06"],
            "idx": [2, 0, 0],
        }
    )
    result = df.select(pl.col("a").bin.get(pl.col("idx"), null_on_oob=True))
    expected = pl.DataFrame({"a": [3, 4, 6]}, schema={"a": pl.UInt8})
    assert_frame_equal(result, expected)

    # 1 binary, N indices (1 to N).
    result = pl.select(
        pl.lit(pl.Series("a", [b"\x01\x02\x03"])).bin.get(
            pl.Series("idx", [0, 1, 2]), null_on_oob=True
        )
    )
    expected = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8})
    assert_frame_equal(result, expected)

    # OOB raises error.
    df = pl.DataFrame({"a": [b"\x01\x02"]})
    with pytest.raises(pl.exceptions.ComputeError, match="out of bounds"):
        df.select(pl.col("a").bin.get(5))