GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/namespaces/test_binary.py
from __future__ import annotations

import random
import struct
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest
from hypothesis import given
from hypothesis import strategies as st

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
    from polars._typing import PolarsDataType, SizeUnit, TransferEncoding
def test_binary_conversions() -> None:
    df = pl.DataFrame({"blob": [b"abc", None, b"cde"]}).with_columns(
        pl.col("blob").cast(pl.String).alias("decoded_blob")
    )

    assert df.to_dict(as_series=False) == {
        "blob": [b"abc", None, b"cde"],
        "decoded_blob": ["abc", None, "cde"],
    }
    assert df[0, 0] == b"abc"
    assert df[1, 0] is None
    assert df.dtypes == [pl.Binary, pl.String]


def test_contains() -> None:
    df = pl.DataFrame(
        data=[
            (1, b"some * * text"),
            (2, b"(with) special\n * chars"),
            (3, b"**etc...?$"),
            (4, None),
        ],
        schema=["idx", "bin"],
        orient="row",
    )
    for pattern, expected in (
        (b"e * ", [True, False, False, None]),
        (b"text", [True, False, False, None]),
        (b"special", [False, True, False, None]),
        (b"", [True, True, True, None]),
        (b"qwe", [False, False, False, None]),
    ):
        # series
        assert expected == df["bin"].bin.contains(pattern).to_list()
        # frame select
        assert (
            expected == df.select(pl.col("bin").bin.contains(pattern))["bin"].to_list()
        )
        # frame filter
        assert sum(e for e in expected if e is True) == len(
            df.filter(pl.col("bin").bin.contains(pattern))
        )


def test_contains_with_expr() -> None:
    df = pl.DataFrame(
        {
            "bin": [b"some * * text", b"(with) special\n * chars", b"**etc...?$", None],
            "lit1": [b"e * ", b"", b"qwe", b"None"],
            "lit2": [None, b"special\n", b"?!", None],
        }
    )

    assert df.select(
        pl.col("bin").bin.contains(pl.col("lit1")).alias("contains_1"),
        pl.col("bin").bin.contains(pl.col("lit2")).alias("contains_2"),
        pl.col("bin").bin.contains(pl.lit(None)).alias("contains_3"),
    ).to_dict(as_series=False) == {
        "contains_1": [True, True, False, None],
        "contains_2": [None, True, False, None],
        "contains_3": [None, None, None, None],
    }
def test_starts_ends_with() -> None:
    assert pl.DataFrame(
        {
            "a": [b"hamburger", b"nuts", b"lollypop", None],
            "end": [b"ger", b"tg", None, b"anything"],
            "start": [b"ha", b"nga", None, b"anything"],
        }
    ).select(
        pl.col("a").bin.ends_with(b"pop").alias("end_lit"),
        pl.col("a").bin.ends_with(pl.lit(None)).alias("end_none"),
        pl.col("a").bin.ends_with(pl.col("end")).alias("end_expr"),
        pl.col("a").bin.starts_with(b"ham").alias("start_lit"),
        pl.col("a").bin.starts_with(pl.lit(None)).alias("start_none"),
        pl.col("a").bin.starts_with(pl.col("start")).alias("start_expr"),
    ).to_dict(as_series=False) == {
        "end_lit": [False, False, True, None],
        "end_none": [None, None, None, None],
        "end_expr": [True, False, None, None],
        "start_lit": [True, False, False, None],
        "start_none": [None, None, None, None],
        "start_expr": [True, False, None, None],
    }
def test_base64_encode() -> None:
    df = pl.DataFrame({"data": [b"asd", b"qwe"]})

    assert df["data"].bin.encode("base64").to_list() == ["YXNk", "cXdl"]


def test_base64_decode() -> None:
    df = pl.DataFrame({"data": [b"YXNk", b"cXdl"]})

    assert df["data"].bin.decode("base64").to_list() == [b"asd", b"qwe"]


def test_hex_encode() -> None:
    df = pl.DataFrame({"data": [b"asd", b"qwe"]})

    assert df["data"].bin.encode("hex").to_list() == ["617364", "717765"]


def test_hex_decode() -> None:
    df = pl.DataFrame({"data": [b"617364", b"717765"]})

    assert df["data"].bin.decode("hex").to_list() == [b"asd", b"qwe"]
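# The expected strings in the four tests above can be cross-checked against the
# Python standard library; a minimal reference sketch (not from the upstream
# file, using only stdlib `base64` and `bytes.hex`):
def _stdlib_encoding_reference() -> None:
    import base64

    assert base64.b64encode(b"asd") == b"YXNk"
    assert base64.b64encode(b"qwe") == b"cXdl"
    assert b"asd".hex() == "617364"
    assert b"qwe".hex() == "717765"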
@pytest.mark.parametrize(
    "encoding",
    ["hex", "base64"],
)
def test_compare_encode_between_lazy_and_eager_6814(
    encoding: TransferEncoding,
) -> None:
    df = pl.DataFrame({"x": [b"aa", b"bb", b"cc"]})
    expr = pl.col("x").bin.encode(encoding)

    result_eager = df.select(expr)
    dtype = result_eager["x"].dtype

    result_lazy = df.lazy().select(expr).select(pl.col(dtype)).collect()
    assert_frame_equal(result_eager, result_lazy)


@pytest.mark.parametrize(
    "encoding",
    ["hex", "base64"],
)
def test_compare_decode_between_lazy_and_eager_6814(
    encoding: TransferEncoding,
) -> None:
    df = pl.DataFrame({"x": [b"d3d3", b"abcd", b"1234"]})
    expr = pl.col("x").bin.decode(encoding)

    result_eager = df.select(expr)
    dtype = result_eager["x"].dtype

    result_lazy = df.lazy().select(expr).select(pl.col(dtype)).collect()
    assert_frame_equal(result_eager, result_lazy)


@pytest.mark.parametrize(
    ("sz", "unit", "expected"),
    [(128, "b", 128), (512, "kb", 0.5), (131072, "mb", 0.125)],
)
def test_binary_size(sz: int, unit: SizeUnit, expected: int | float) -> None:
    df = pl.DataFrame({"data": [b"\x00" * sz]}, schema={"data": pl.Binary})
    for sz in (
        df.select(sz=pl.col("data").bin.size(unit)).item(),  # expr
        df["data"].bin.size(unit).item(),  # series
    ):
        assert sz == expected
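# The parametrized expectations above follow from binary (1024-based) unit
# conversion; a quick arithmetic sketch (assumes each step divides by 1024,
# which is what the expected values imply):
def _size_unit_reference() -> None:
    assert 512 / 1024 == 0.5  # 512 bytes -> "kb"
    assert 131072 / 1024**2 == 0.125  # 131072 bytes -> "mb"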
@pytest.mark.parametrize(
    ("dtype", "type_size", "struct_type"),
    [
        (pl.Int8, 1, "b"),
        (pl.UInt8, 1, "B"),
        (pl.Int16, 2, "h"),
        (pl.UInt16, 2, "H"),
        (pl.Int32, 4, "i"),
        (pl.UInt32, 4, "I"),
        (pl.Int64, 8, "q"),
        (pl.UInt64, 8, "Q"),
        (pl.Float32, 4, "f"),
        (pl.Float64, 8, "d"),
    ],
)
def test_reinterpret(
    dtype: pl.DataType,
    type_size: int,
    struct_type: str,
) -> None:
    # Make test reproducible
    random.seed(42)

    byte_arr = [random.randbytes(type_size) for _ in range(3)]
    df = pl.DataFrame({"x": byte_arr})

    for endianness in ["little", "big"]:
        # So that mypy doesn't complain
        struct_endianness = "<" if endianness == "little" else ">"
        expected = [
            struct.unpack_from(f"{struct_endianness}{struct_type}", elem_bytes)[0]
            for elem_bytes in byte_arr
        ]
        expected_df = pl.DataFrame({"x": expected}, schema={"x": dtype})

        result = df.select(
            pl.col("x").bin.reinterpret(dtype=dtype, endianness=endianness)  # type: ignore[arg-type]
        )

        assert_frame_equal(result, expected_df)
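# How the `struct`-based expectation above maps onto `bin.reinterpret` for a
# single value; a minimal sketch (Int32, little-endian) that mirrors the loop
# above rather than adding new coverage:
def _reinterpret_single_value_reference() -> None:
    raw = (258).to_bytes(4, "little")  # b"\x02\x01\x00\x00"
    assert struct.unpack_from("<i", raw)[0] == 258
    s = pl.Series([raw]).bin.reinterpret(dtype=pl.Int32, endianness="little")
    assert s.item() == 258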
@pytest.mark.parametrize(
    ("dtype", "inner_type_size", "struct_type"),
    [
        (pl.Array(pl.Int8, 3), 1, "b"),
        (pl.Array(pl.UInt8, 3), 1, "B"),
        (pl.Array(pl.Int16, 3), 2, "h"),
        (pl.Array(pl.UInt16, 3), 2, "H"),
        (pl.Array(pl.Int32, 3), 4, "i"),
        (pl.Array(pl.UInt32, 3), 4, "I"),
        (pl.Array(pl.Int64, 3), 8, "q"),
        (pl.Array(pl.UInt64, 3), 8, "Q"),
        (pl.Array(pl.Float32, 3), 4, "f"),
        (pl.Array(pl.Float64, 3), 8, "d"),
    ],
)
def test_reinterpret_to_array_numeric_types(
    dtype: pl.Array,
    inner_type_size: int,
    struct_type: str,
) -> None:
    # Make test reproducible
    random.seed(42)

    type_size = inner_type_size
    shape = dtype.shape
    if isinstance(shape, int):
        shape = (shape,)
    for dim_size in dtype.shape:
        type_size *= dim_size

    byte_arr = [random.randbytes(type_size) for _ in range(3)]
    df = pl.DataFrame({"x": byte_arr}, orient="row")

    for endianness in ["little", "big"]:
        result = df.select(
            pl.col("x").bin.reinterpret(dtype=dtype, endianness=endianness)  # type: ignore[arg-type]
        )

        # So that mypy doesn't complain
        struct_endianness = "<" if endianness == "little" else ">"
        expected = []
        for elem_bytes in byte_arr:
            vals = [
                struct.unpack_from(
                    f"{struct_endianness}{struct_type}",
                    elem_bytes[idx : idx + inner_type_size],
                )[0]
                for idx in range(0, type_size, inner_type_size)
            ]
            if len(shape) > 1:
                vals = np.reshape(vals, shape).tolist()
            expected.append(vals)
        expected_df = pl.DataFrame({"x": expected}, schema={"x": dtype})

        assert_frame_equal(result, expected_df)


@pytest.mark.parametrize(
    ("dtype", "binary_value", "expected_values"),
    [
        (pl.Date(), b"\x06\x00\x00\x00", [date(1970, 1, 7)]),
        (
            pl.Datetime(),
            b"\x40\xb6\xfd\xe3\x7c\x00\x00\x00",
            [datetime(1970, 1, 7, 5, 0, 1)],
        ),
        (
            pl.Duration(),
            b"\x03\x00\x00\x00\x00\x00\x00\x00",
            [timedelta(microseconds=3)],
        ),
        (
            pl.Time(),
            b"\x58\x1b\x00\x00\x00\x00\x00\x00",
            [time(microsecond=7)],
        ),
        (
            pl.Int128(),
            b"\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
            [6],
        ),
        (
            pl.UInt128(),
            b"\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
            [6],
        ),
    ],
)
def test_reinterpret_to_additional_types(
    dtype: PolarsDataType, binary_value: bytes, expected_values: list[object]
) -> None:
    series = pl.Series([binary_value])

    # Direct conversion:
    result = series.bin.reinterpret(dtype=dtype, endianness="little")
    assert_series_equal(result, pl.Series(expected_values, dtype=dtype))

    # Array conversion:
    dtype = pl.Array(dtype, 1)
    result = series.bin.reinterpret(dtype=dtype, endianness="little")
    assert_series_equal(result, pl.Series([expected_values], dtype=dtype))
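# Worked example for the Datetime case above: the 8 little-endian bytes decode
# to microseconds since the epoch (a sketch of the arithmetic only, assuming
# the default Datetime unit of "us"):
def _datetime_bytes_reference() -> None:
    raw = b"\x40\xb6\xfd\xe3\x7c\x00\x00\x00"
    micros = int.from_bytes(raw, "little")
    assert micros == 536_401_000_000
    assert micros == (6 * 86_400 + 5 * 3_600 + 1) * 1_000_000  # 1970-01-07 05:00:01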
def test_reinterpret_to_array_resulting_in_nulls() -> None:
    series = pl.Series([None, b"short", b"justrite", None, b"waytoolong"])
    as_bin = series.bin.reinterpret(dtype=pl.Array(pl.UInt32(), 2), endianness="little")
    assert as_bin.to_list() == [None, None, [0x7473756A, 0x65746972], None, None]
    as_bin = series.bin.reinterpret(dtype=pl.Array(pl.UInt32(), 2), endianness="big")
    assert as_bin.to_list() == [None, None, [0x6A757374, 0x72697465], None, None]
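# Where the hex constants above come from: each 4-byte half of b"justrite"
# viewed as a UInt32 in the given byte order (a quick stdlib cross-check, not
# extra coverage):
def _array_reinterpret_reference() -> None:
    assert struct.unpack("<I", b"just")[0] == 0x7473756A
    assert struct.unpack("<I", b"rite")[0] == 0x65746972
    assert struct.unpack(">I", b"just")[0] == 0x6A757374
    assert struct.unpack(">I", b"rite")[0] == 0x72697465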
def test_reinterpret_to_n_dimensional_array() -> None:
    series = pl.Series([b"abcd"])
    for endianness in ["big", "little"]:
        with pytest.raises(
            InvalidOperationError,
            match="reinterpret to a linear Array, and then use reshape",
        ):
            series.bin.reinterpret(
                dtype=pl.Array(pl.UInt32(), (2, 2)),
                endianness=endianness,  # type: ignore[arg-type]
            )


def test_reinterpret_to_zero_length_array() -> None:
    arr_dtype = pl.Array(pl.UInt8, 0)
    result = pl.Series([b"", b""]).bin.reinterpret(dtype=arr_dtype)
    assert_series_equal(result, pl.Series([[], []], dtype=arr_dtype))


@given(
    value1=st.integers(0, 2**63),
    value2=st.binary(min_size=0, max_size=7),
    value3=st.integers(0, 2**63),
)
def test_reinterpret_to_array_different_alignment(
    value1: int, value2: bytes, value3: int
) -> None:
    series = pl.Series([struct.pack("<Q", value1), value2, struct.pack("<Q", value3)])
    arr_dtype = pl.Array(pl.UInt64, 1)
    as_uint64 = series.bin.reinterpret(dtype=arr_dtype, endianness="little")
    assert_series_equal(
        pl.Series([[value1], None, [value3]], dtype=arr_dtype), as_uint64
    )


@pytest.mark.parametrize(
    "bad_dtype",
    [
        pl.Array(pl.Array(pl.UInt8, 1), 1),
        pl.String(),
        pl.Array(pl.List(pl.UInt8()), 1),
        pl.Array(pl.Null(), 1),
        pl.Array(pl.Boolean(), 1),
    ],
)
def test_reinterpret_unsupported(bad_dtype: pl.DataType) -> None:
    series = pl.Series([b"12345678"])
    lazy_df = pl.DataFrame({"s": series}).lazy()
    expected = "cannot reinterpret binary to dtype.*Only numeric or temporal dtype.*"
    for endianness in ["little", "big"]:
        with pytest.raises(InvalidOperationError, match=expected):
            series.bin.reinterpret(dtype=bad_dtype, endianness=endianness)  # type: ignore[arg-type]
        with pytest.raises(InvalidOperationError, match=expected):
            lazy_df.select(
                pl.col("s").bin.reinterpret(dtype=bad_dtype, endianness=endianness)  # type: ignore[arg-type]
            ).collect_schema()


@pytest.mark.parametrize(
    ("dtype", "type_size"),
    [
        (pl.Int128, 16),
    ],
)
def test_reinterpret_int(
    dtype: pl.DataType,
    type_size: int,
) -> None:
    # Used for integer types that neither `struct` nor `numpy` can parse from
    # bytes. Rather than creating the bytes directly, create an integer and
    # view it as bytes.
    is_signed = dtype.is_signed_integer()

    if is_signed:
        min_val = -(2 ** (type_size - 1))
        max_val = 2 ** (type_size - 1) - 1
    else:
        min_val = 0
        max_val = 2**type_size - 1

    # Make test reproducible
    random.seed(42)

    expected = [random.randint(min_val, max_val) for _ in range(3)]
    expected_df = pl.DataFrame({"x": expected}, schema={"x": dtype})

    for endianness in ["little", "big"]:
        byte_arr = [
            val.to_bytes(type_size, byteorder=endianness, signed=is_signed)  # type: ignore[arg-type]
            for val in expected
        ]
        df = pl.DataFrame({"x": byte_arr})

        result = df.select(
            pl.col("x").bin.reinterpret(dtype=dtype, endianness=endianness)  # type: ignore[arg-type]
        )

        assert_frame_equal(result, expected_df)
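# The 16-byte round trip used above, in isolation: `int.to_bytes` and
# `int.from_bytes` handle widths that `struct` cannot (a minimal sketch):
def _int128_bytes_reference() -> None:
    val = -(2**15)
    raw = val.to_bytes(16, byteorder="little", signed=True)
    assert int.from_bytes(raw, byteorder="little", signed=True) == val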
def test_reinterpret_invalid() -> None:
    # Yields null because the buffer has more than 4 bytes
    df = pl.DataFrame({"x": [b"d3d3a"]})
    print(struct.unpack_from("<i", b"d3d3a"))
    assert_frame_equal(
        df.select(pl.col("x").bin.reinterpret(dtype=pl.Int32)),
        pl.DataFrame({"x": [None]}, schema={"x": pl.Int32}),
    )

    # Yields null because the buffer has less than 4 bytes
    df = pl.DataFrame({"x": [b"d3"]})
    print(df.select(pl.col("x").bin.reinterpret(dtype=pl.Int32)))
    assert_frame_equal(
        df.select(pl.col("x").bin.reinterpret(dtype=pl.Int32)),
        pl.DataFrame({"x": [None]}, schema={"x": pl.Int32}),
    )

    # Fails because the dtype is invalid
    with pytest.raises(pl.exceptions.InvalidOperationError):
        df.select(pl.col("x").bin.reinterpret(dtype=pl.String))


@pytest.mark.parametrize("func", ["contains", "starts_with", "ends_with"])
def test_bin_contains_unequal_lengths_22018(func: str) -> None:
    s = pl.Series("a", [b"a", b"xyz"], pl.Binary).bin
    f = getattr(s, func)
    with pytest.raises(pl.exceptions.ShapeError):
        f(pl.Series([b"x", b"y", b"z"]))


def test_binary_compounded_literal_aggstate_24460() -> None:
    df = pl.DataFrame({"g": [10], "n": [1]})
    out = df.group_by("g").agg(
        (pl.lit(1, pl.Int64) + pl.lit(2)).pow(pl.lit(3)).alias("z")
    )
    expected = pl.DataFrame({"g": [10], "z": [27]})
    assert_frame_equal(out, expected)


# parametric tuples: (expr, is_scalar, values with broadcast)
agg_expressions = [
    (pl.lit(7, pl.Int64), True, [7, 7, 7]),  # LiteralScalar
    (pl.col("n"), False, [2, 1, 3]),  # NotAggregated
    (pl.int_range(pl.len()), False, [0, 1, 0]),  # AggregatedList
    (pl.col("n").first(), True, [2, 2, 3]),  # AggregatedScalar
]
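# Where the broadcast values above come from, given the fixture used in the
# test below (g=[10, 10, 20], n=[2, 1, 3]): a scalar result per group is
# broadcast back over that group's rows. A minimal sketch of the fourth entry,
# not part of the parametrization itself:
def _aggregated_scalar_broadcast_reference() -> None:
    df = pl.DataFrame({"g": [10, 10, 20], "n": [2, 1, 3]})
    out = df.group_by("g", maintain_order=True).agg(pl.col("n").first())
    assert out["n"].to_list() == [2, 3]  # broadcast row-wise this is [2, 2, 3]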
@pytest.mark.parametrize("lhs", agg_expressions)
@pytest.mark.parametrize("rhs", agg_expressions)
@pytest.mark.parametrize("n_rows", [0, 1, 2, 3])
@pytest.mark.parametrize("maintain_order", [True, False])
def test_add_aggstates_in_binary_expr_24504(
    lhs: tuple[pl.Expr, bool, list[int]],
    rhs: tuple[pl.Expr, bool, list[int]],
    n_rows: int,
    maintain_order: bool,
) -> None:
    df = pl.DataFrame({"g": [10, 10, 20], "n": [2, 1, 3]})
    lf = df.head(n_rows).lazy()
    expr = pl.Expr.add(lhs[0].alias("lhs"), rhs[0].alias("rhs")).alias("expr")
    q = lf.group_by("g", maintain_order=maintain_order).agg(expr)
    out = q.collect()

    # check schema
    assert q.collect_schema() == out.schema

    # check output against ground truth
    if n_rows in [1, 2, 3]:
        data = df.to_dict(as_series=False)
        result: dict[int, Any] = {}
        for gg, ll, rr in zip(
            data["g"][:n_rows], lhs[2][:n_rows], rhs[2][:n_rows], strict=True
        ):
            result.setdefault(gg, []).append(ll + rr)
        if lhs[1] and rhs[1]:
            # expect scalar result
            result = {k: v[0] for k, v in result.items()}
        expected = pl.DataFrame(
            {"g": list(result.keys()), "expr": list(result.values())}
        )
        assert_frame_equal(out, expected, check_row_order=maintain_order)

    # check output against non_aggregated expression evaluation
    if n_rows in [1, 2, 3]:
        print(f"df\n{df}")
        grouped = df.head(n_rows).group_by("g", maintain_order=maintain_order)
        out_non_agg = pl.DataFrame({})
        for df_group in grouped:
            df = df_group[1]
            print(f"df pre expr:\n{df}", flush=True)
            if lhs[1] and rhs[1]:
                df = df.head(1)
                df = df.select(["g", expr])
            else:
                df = df.select(["g", expr.implode()]).head(1)
            print(f"df post expr:{df}\n")
            out_non_agg = out_non_agg.vstack(df)
        print(f"out_non_agg:\n{out_non_agg}")

        assert_frame_equal(out, out_non_agg, check_row_order=maintain_order)


# parametric tuples: (expr, is_scalar)
agg_expressions_sort = [
    (pl.lit(7, pl.Int64), True),  # LiteralScalar
    (pl.col("n"), False),  # NotAggregated
    (pl.col("n").sort(), False),  # NotAggregated w groups modified
    (pl.int_range(pl.len()), False),  # AggregatedList
    (pl.int_range(pl.len()).reverse(), False),  # AggregatedList w groups modified
    (pl.col("n").first(), True),  # AggregatedScalar
]


@pytest.mark.parametrize("lhs", agg_expressions_sort)
@pytest.mark.parametrize("rhs", agg_expressions_sort)
@pytest.mark.parametrize("maintain_order", [True, False])
def test_add_aggstates_with_sort_in_binary_expr_24504(
    lhs: tuple[pl.Expr, bool],
    rhs: tuple[pl.Expr, bool],
    maintain_order: bool,
) -> None:
    df = pl.DataFrame({"g": [10, 10, 20], "n": [2, 1, 3]})
    lf = df.lazy()
    expr = pl.Expr.add(lhs[0].alias("lhs"), rhs[0].alias("rhs")).alias("expr")
    q = lf.group_by("g", maintain_order=maintain_order).agg(expr)
    out = q.collect()

    # check schema
    assert q.collect_schema() == out.schema

    # check output against non_aggregated expression evaluation
    grouped = df.group_by("g", maintain_order=maintain_order)
    out_non_agg = pl.DataFrame({})
    for df_group in grouped:
        df = df_group[1]
        if lhs[1] and rhs[1]:
            df = df.head(1)
            df = df.select(["g", expr])
        else:
            df = df.select(["g", expr.implode()]).head(1)
        out_non_agg = out_non_agg.vstack(df)

    assert_frame_equal(out, out_non_agg, check_row_order=maintain_order)


@pytest.mark.parametrize("maintain_order", [True, False])
def test_binary_context_nested(maintain_order: bool) -> None:
    df = pl.DataFrame({"groups": [1, 1, 2, 2, 3, 3], "vals": [1, 13, 3, 87, 1, 6]})
    out = (
        df.lazy()
        .group_by(pl.col("groups"), maintain_order=maintain_order)
        .agg(
            [
                pl.when(pl.col("vals").eq(pl.lit(1)))
                .then(pl.col("vals").sum())
                .otherwise(pl.lit(90))
                .alias("vals")
            ]
        )
    ).collect()
    expected = pl.DataFrame(
        {"groups": [1, 2, 3], "vals": [[14, 90], [90, 90], [7, 90]]}
    )
    assert_frame_equal(out, expected, check_row_order=maintain_order)


def test_get() -> None:
    # N binary, scalar index (N to 1).
    df = pl.DataFrame({"a": [b"\x01\x02\x03", b"", b"\x04\x05"]})
    result = df.select(pl.col("a").bin.get(0, null_on_oob=True))
    expected = pl.DataFrame({"a": [1, None, 4]}, schema={"a": pl.UInt8})
    assert_frame_equal(result, expected)

    # Negative index.
    result = df.select(pl.col("a").bin.get(-1, null_on_oob=True))
    expected = pl.DataFrame({"a": [3, None, 5]}, schema={"a": pl.UInt8})
    assert_frame_equal(result, expected)

    # Null index.
    result = df.select(
        pl.col("a").bin.get(pl.lit(None, dtype=pl.Int64), null_on_oob=True)
    )
    expected = pl.DataFrame({"a": [None, None, None]}, schema={"a": pl.UInt8})
    assert_frame_equal(result, expected)

    # N binary, N indices (N to N).
    df = pl.DataFrame(
        {
            "a": [b"\x01\x02\x03", b"\x04\x05", b"\x06"],
            "idx": [2, 0, 0],
        }
    )
    result = df.select(pl.col("a").bin.get(pl.col("idx"), null_on_oob=True))
    expected = pl.DataFrame({"a": [3, 4, 6]}, schema={"a": pl.UInt8})
    assert_frame_equal(result, expected)

    # 1 binary, N indices (1 to N).
    result = pl.select(
        pl.lit(pl.Series("a", [b"\x01\x02\x03"])).bin.get(
            pl.Series("idx", [0, 1, 2]), null_on_oob=True
        )
    )
    expected = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8})
    assert_frame_equal(result, expected)

    # OOB raises error.
    df = pl.DataFrame({"a": [b"\x01\x02"]})
    with pytest.raises(pl.exceptions.ComputeError, match="out of bounds"):
        df.select(pl.col("a").bin.get(5))
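# `bin.get` above mirrors Python's own bytes indexing, including negative
# indices; a minimal stdlib comparison sketch for the scalar cases covered in
# test_get (reusing the same expression form as the test, not new coverage):
def _bin_get_reference() -> None:
    data = b"\x01\x02\x03"
    assert data[0] == 1 and data[-1] == 3
    s = pl.Series("a", [data])
    assert pl.select(pl.lit(s).bin.get(0, null_on_oob=True)).item() == 1
    assert pl.select(pl.lit(s).bin.get(-1, null_on_oob=True)).item() == 3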