Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py
8431 views
1
from __future__ import annotations
2
3
import datetime as dt
4
import json
5
import math
6
import re
7
from datetime import date, datetime
8
from functools import partial
9
from math import cosh
10
from typing import TYPE_CHECKING, Any, Literal
11
12
import numpy as np
13
import pytest
14
15
import polars as pl
16
from polars._utils.udfs import _BYTECODE_PARSER_CACHE_, _NUMPY_FUNCTIONS, BytecodeParser
17
from polars._utils.various import in_terminal_that_supports_colour
18
from polars.exceptions import PolarsInefficientMapWarning
19
from polars.testing import assert_frame_equal, assert_series_equal
20
21
if TYPE_CHECKING:
22
from collections.abc import Callable
23
24
# Module-level names referenced (by name) from the lambdas under test.
MY_CONSTANT = 3
MY_DICT = dict(enumerate("abcde"))  # {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
MY_LIST = list(range(1, 4))  # [1, 2, 3]
27
28
# Each case is a 4-tuple: (column_name, function, expected_suggestion, dtype)
#   column_name         -- column of the test frame the function is mapped over
#   function            -- source text of the lambda (evaluated with `eval`)
#   expected_suggestion -- exact expression string the bytecode parser should emit
#   dtype               -- `return_dtype` marker for `map_elements`, translated in
#                          `test_parse_apply_functions`: None -> no return_dtype,
#                          "self" -> `pl.self_dtype()`, a polars DataType ->
#                          `dtype.to_dtype_expr()`
TEST_CASES = [
    # ---------------------------------------------
    # numeric expr: math, comparison, logic ops
    # ---------------------------------------------
    ("a", "lambda x: x + 1 - (2 / 3)", '(pl.col("a") + 1) - 0.6666666666666666', None),
    ("a", "lambda x: x // 1 % 2", '(pl.col("a") // 1) % 2', None),
    ("a", "lambda x: x & True", 'pl.col("a") & True', None),
    ("a", "lambda x: x | False", 'pl.col("a") | False', None),
    ("a", "lambda x: abs(x) != 3", 'pl.col("a").abs() != 3', None),
    ("a", "lambda x: int(x) > 1", 'pl.col("a").cast(pl.Int64) > 1', None),
    (
        "a",
        "lambda x: not (x > 1) or x == 2",
        '~(pl.col("a") > 1) | (pl.col("a") == 2)',
        None,
    ),
    ("a", "lambda x: x is None", 'pl.col("a") is None', None),
    ("a", "lambda x: x is not None", 'pl.col("a") is not None', None),
    (
        "a",
        "lambda x: ((x * -x) ** x) * 1.0",
        '((pl.col("a") * -pl.col("a")) ** pl.col("a")) * 1.0',
        None,
    ),
    (
        "a",
        "lambda x: 1.0 * (x * (x**x))",
        '1.0 * (pl.col("a") * (pl.col("a") ** pl.col("a")))',
        None,
    ),
    (
        "a",
        "lambda x: (x / x) + ((x * x) - x)",
        '(pl.col("a") / pl.col("a")) + ((pl.col("a") * pl.col("a")) - pl.col("a"))',
        None,
    ),
    (
        "a",
        "lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1))))",
        '(10 - pl.col("a")) / (((pl.col("a") * 4) - pl.col("a")) // (2 + (pl.col("a") * (pl.col("a") - 1))))',
        None,
    ),
    ("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))', None),
    ("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))', None),
    (
        "a",
        "lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0",
        'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)',
        None,
    ),
    ("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")', None),
    (
        "a",
        "lambda x: (float(x) * int(x)) // 2",
        '(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2',
        None,
    ),
    (
        "a",
        "lambda x: 1 / (1 + np.exp(-x))",
        '1 / (1 + (-pl.col("a")).exp())',
        None,
    ),
    # ---------------------------------------------
    # math module
    # ---------------------------------------------
    ("e", "lambda x: math.asin(x)", 'pl.col("e").arcsin()', None),
    ("e", "lambda x: math.asinh(x)", 'pl.col("e").arcsinh()', None),
    ("e", "lambda x: math.atan(x)", 'pl.col("e").arctan()', None),
    ("e", "lambda x: math.atanh(x)", 'pl.col("e").arctanh()', "self"),
    ("e", "lambda x: math.cos(x)", 'pl.col("e").cos()', None),
    ("e", "lambda x: math.degrees(x)", 'pl.col("e").degrees()', None),
    ("e", "lambda x: math.exp(x)", 'pl.col("e").exp()', None),
    ("e", "lambda x: math.log(x)", 'pl.col("e").log()', None),
    ("e", "lambda x: math.log10(x)", 'pl.col("e").log10()', None),
    ("e", "lambda x: math.log1p(x)", 'pl.col("e").log1p()', None),
    ("e", "lambda x: math.radians(x)", 'pl.col("e").radians()', None),
    ("e", "lambda x: math.sin(x)", 'pl.col("e").sin()', None),
    ("e", "lambda x: math.sinh(x)", 'pl.col("e").sinh()', None),
    ("e", "lambda x: math.sqrt(x)", 'pl.col("e").sqrt()', None),
    ("e", "lambda x: math.tan(x)", 'pl.col("e").tan()', None),
    ("e", "lambda x: math.tanh(x)", 'pl.col("e").tanh()', None),
    # ---------------------------------------------
    # numpy module
    # ---------------------------------------------
    ("e", "lambda x: np.arccos(x)", 'pl.col("e").arccos()', None),
    ("e", "lambda x: np.arccosh(x)", 'pl.col("e").arccosh()', None),
    ("e", "lambda x: np.arcsin(x)", 'pl.col("e").arcsin()', None),
    ("e", "lambda x: np.arcsinh(x)", 'pl.col("e").arcsinh()', None),
    ("e", "lambda x: np.arctan(x)", 'pl.col("e").arctan()', None),
    ("e", "lambda x: np.arctanh(x)", 'pl.col("e").arctanh()', "self"),
    ("a", "lambda x: 0 + np.cbrt(x)", '0 + pl.col("a").cbrt()', None),
    ("e", "lambda x: np.ceil(x)", 'pl.col("e").ceil()', None),
    ("e", "lambda x: np.cos(x)", 'pl.col("e").cos()', None),
    ("e", "lambda x: np.cosh(x)", 'pl.col("e").cosh()', None),
    ("e", "lambda x: np.degrees(x)", 'pl.col("e").degrees()', None),
    ("e", "lambda x: np.exp(x)", 'pl.col("e").exp()', None),
    ("e", "lambda x: np.floor(x)", 'pl.col("e").floor()', None),
    ("e", "lambda x: np.log(x)", 'pl.col("e").log()', None),
    ("e", "lambda x: np.log10(x)", 'pl.col("e").log10()', None),
    ("e", "lambda x: np.log1p(x)", 'pl.col("e").log1p()', None),
    ("e", "lambda x: np.radians(x)", 'pl.col("e").radians()', None),
    ("a", "lambda x: np.sign(x)", 'pl.col("a").sign()', None),
    ("a", "lambda x: np.sin(x) + 1", 'pl.col("a").sin() + 1', None),
    (
        "a",  # note: calls whose arguments are all constants are kept verbatim
        "lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3)",
        '(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)',
        None,
    ),
    ("a", "lambda x: np.sinh(x) + 1", 'pl.col("a").sinh() + 1', None),
    ("a", "lambda x: np.sqrt(x) + 1", 'pl.col("a").sqrt() + 1', None),
    ("a", "lambda x: np.tan(x) + 1", 'pl.col("a").tan() + 1', None),
    ("e", "lambda x: np.tanh(x)", 'pl.col("e").tanh()', None),
    # ---------------------------------------------
    # logical 'and/or' (validate nesting levels)
    # ---------------------------------------------
    (
        "a",
        "lambda x: x > 1 or (x == 1 and x == 2)",
        '(pl.col("a") > 1) | ((pl.col("a") == 1) & (pl.col("a") == 2))',
        None,
    ),
    (
        "a",
        "lambda x: (x > 1 or x == 1) and x == 2",
        '((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)',
        None,
    ),
    (
        "a",
        "lambda x: x > 2 or x != 3 and x not in (0, 1, 4)",
        '(pl.col("a") > 2) | ((pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4)))',
        None,
    ),
    (
        "a",
        "lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3",
        '((pl.col("a") > 1) & (pl.col("a") != 2)) | (((pl.col("a") % 2) == 0) & (pl.col("a") < 3))',
        None,
    ),
    (
        "a",
        "lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3",
        '(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)',
        None,
    ),
    # ---------------------------------------------
    # string exprs
    # ---------------------------------------------
    (
        "b",
        "lambda x: str(x).title()",
        'pl.col("b").cast(pl.String).str.to_titlecase()',
        None,
    ),
    (
        "b",
        'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()',
        '(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()',
        None,
    ),
    (
        "b",
        "lambda x: x.strip().startswith('#')",
        """pl.col("b").str.strip_chars().str.starts_with('#')""",
        None,
    ),
    (
        "b",
        """lambda x: x.rstrip().endswith(('!','#','?','"'))""",
        """pl.col("b").str.strip_chars_end().str.contains(r'(!|\\#|\\?|")$')""",
        None,
    ),
    (
        "b",
        """lambda x: x.lstrip().startswith(('!','#','?',"'"))""",
        """pl.col("b").str.strip_chars_start().str.contains(r"^(!|\\#|\\?|')")""",
        None,
    ),
    (
        "b",
        "lambda x: x.replace(':','')",
        """pl.col("b").str.replace_all(':','',literal=True)""",
        None,
    ),
    (
        "b",
        "lambda x: x.replace(':','',2)",
        """pl.col("b").str.replace(':','',n=2,literal=True)""",
        None,
    ),
    (
        "b",
        "lambda x: x.removeprefix('A').removesuffix('F')",
        """pl.col("b").str.strip_prefix('A').str.strip_suffix('F')""",
        None,
    ),
    (
        "b",
        "lambda x: x.zfill(8)",
        """pl.col("b").str.zfill(8)""",
        None,
    ),
    # ---------------------------------------------
    # replace
    # ---------------------------------------------
    ("a", "lambda x: MY_DICT[x]", 'pl.col("a").replace_strict(MY_DICT)', None),
    (
        "a",
        "lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]",
        '(pl.col("a") - 1).replace_strict(MY_DICT) + (1 + pl.col("a")).replace_strict(MY_DICT)',
        None,
    ),
    # ---------------------------------------------
    # standard library datetime parsing
    # ---------------------------------------------
    (
        "d",
        'lambda x: datetime.strptime(x, "%Y-%m-%d")',
        'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
        pl.Datetime("us"),
    ),
    (
        "d",
        'lambda x: dt.datetime.strptime(x, "%Y-%m-%d")',
        'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
        pl.Datetime("us"),
    ),
    # ---------------------------------------------
    # temporal attributes/methods
    # ---------------------------------------------
    (
        "f",
        "lambda x: x.isoweekday()",
        'pl.col("f").dt.weekday()',
        None,
    ),
    (
        "f",
        "lambda x: x.hour + x.minute + x.second",
        '(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()',
        None,
    ),
    # ---------------------------------------------
    # Bitwise shifts
    # ---------------------------------------------
    (
        "a",
        "lambda x: (3 << (30-x)) & 3",
        '(3 * 2**(30 - pl.col("a"))).cast(pl.Int64) & 3',
        None,
    ),
    (
        "a",
        "lambda x: (x << 32) & 3",
        '(pl.col("a") * 2**32).cast(pl.Int64) & 3',
        None,
    ),
    (
        "a",
        "lambda x: ((32-x) >> (3)) & 3",
        '((32 - pl.col("a")) / 2**3).cast(pl.Int64) & 3',
        None,
    ),
    (
        "a",
        "lambda x: (32 >> (3-x)) & 3",
        '(32 / 2**(3 - pl.col("a"))).cast(pl.Int64) & 3',
        None,
    ),
]
]
301
302
# Lambda sources for which no rewrite suggestion is expected
# (exercised by `test_parse_invalid_function`).
NOOP_TEST_CASES = [
    # returns its argument unchanged
    "lambda x: x",
    # takes two arguments
    "lambda x, y: x + y",
    # indexes into the argument itself
    "lambda x: x[0] + 1",
    # indexes a list (not a dict) with the argument
    "lambda x: MY_LIST[x]",
    # dict lookup with a constant key; the argument is unused
    "lambda x: MY_DICT[1]",
    # conditional (ternary) expression
    'lambda x: "first" if x == 1 else "not first"',
    # numpy call with an extra keyword argument
    'lambda x: np.sign(x, casting="unsafe")',
]
311
312
# Globals passed to `eval` when evaluating suggested expression strings, so
# that any name a suggestion mentions (e.g. `pl`, `np`, `MY_DICT`) resolves.
EVAL_ENVIRONMENT = {
    # module-level test constants
    "MY_CONSTANT": MY_CONSTANT,
    "MY_DICT": MY_DICT,
    "MY_LIST": MY_LIST,
    # functions and modules
    "cosh": cosh,
    "datetime": datetime,
    "dt": dt,
    "math": math,
    "np": np,
    "pl": pl,
}
323
324
325
@pytest.mark.parametrize("func", NOOP_TEST_CASES)
def test_parse_invalid_function(func: str) -> None:
    """Functions we don't (yet?) offer suggestions for must yield no rewrite."""
    bytecode_parser = BytecodeParser(eval(func), map_target="expr")
    # either a rewrite is not attempted, or attempting it produces nothing
    rewrite_possible = bytecode_parser.can_attempt_rewrite()
    assert not (rewrite_possible and bytecode_parser.to_expression("x"))
333
334
335
@pytest.mark.parametrize(
    ("col", "func", "expr_repr", "dtype"),
    TEST_CASES,
)
@pytest.mark.filterwarnings(
    "ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
    "ignore:invalid value encountered:RuntimeWarning",
    "ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning",
)
@pytest.mark.may_fail_auto_streaming  # dtype not set
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_parse_apply_functions(
    col: str, func: str, expr_repr: str, dtype: Literal["self"] | pl.DataType | None
) -> None:
    """
    Check the suggested rewrite for each case in TEST_CASES.

    The suggestion must match `expr_repr` exactly, and evaluating it must give
    the same frame as running the original lambda through `map_elements`,
    which must also emit a PolarsInefficientMapWarning.
    """
    # translate the test-case dtype marker into a `return_dtype` value
    return_dtype: pl.DataTypeExpr | None
    if dtype is None:
        return_dtype = None
    elif dtype == "self":
        return_dtype = pl.self_dtype()
    else:
        return_dtype = dtype.to_dtype_expr()  # type: ignore[union-attr]

    suggestion = BytecodeParser(eval(func), map_target="expr").to_expression(col)
    assert suggestion == expr_repr

    frame = pl.DataFrame(
        {
            "a": [1, 2, 3],
            "b": ["AB", "cd", "eF"],
            "c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'],
            "d": ["2020-01-01", "2020-01-02", "2020-01-03"],
            "e": [0.5, 0.4, 0.1],
            "f": [
                datetime(1969, 12, 31),
                datetime(2024, 5, 6),
                datetime(2077, 10, 20),
            ],
        }
    )

    # result of the suggested (native) expression...
    native_frame = frame.select(
        x=col,
        y=eval(suggestion, EVAL_ENVIRONMENT),
    )

    # ...versus running the python function element-wise
    with pytest.warns(
        PolarsInefficientMapWarning,
        match=r"(?s)Expr\.map_elements.*with this one instead",
    ):
        udf_frame = frame.select(
            x=pl.col(col),
            y=pl.col(col).map_elements(eval(func), return_dtype=return_dtype),
        )

    # temporal (".dt.") rewrites may legitimately differ in integer width
    assert_frame_equal(
        native_frame,
        udf_frame,
        check_dtypes=(".dt." not in suggestion),
    )
393
394
395
@pytest.mark.filterwarnings(
    "ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
    "ignore:invalid value encountered:RuntimeWarning",
    "ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning",
)
@pytest.mark.may_fail_auto_streaming  # dtype is not set
def test_parse_apply_raw_functions() -> None:
    """Warn for bare (non-lambda) callables: numpy functions, json.loads, casts."""
    lf = pl.LazyFrame({"a": [1.1, 2.0, 3.4]})

    # bare 'numpy' functions: cannot be parsed/rewritten, but still warn
    for np_name in _NUMPY_FUNCTIONS:
        np_func = getattr(np, np_name)
        assert not BytecodeParser(np_func, map_target="expr").can_attempt_rewrite()

        with pytest.warns(
            PolarsInefficientMapWarning,
            match=rf"(?s)Expr\.map_elements.*Replace this expression.*np\.{np_name}",
        ):
            udf_result = lf.select(
                pl.col("a").map_elements(np_func, return_dtype=pl.self_dtype())
            ).collect()
        native_result = lf.select(getattr(pl.col("a"), np_name)()).collect()
        assert_frame_equal(udf_result, native_result)

    # bare 'json.loads'
    json_dtype = pl.Struct({"a": pl.Int64, "b": pl.Boolean, "c": pl.String})
    expr_native = pl.col("value").str.json_decode(json_dtype)
    with pytest.warns(
        PolarsInefficientMapWarning,
        match=r"(?s)Expr\.map_elements.*with this one instead:.*\.str\.json_decode",
    ):
        expr_pyfunc = pl.col("value").map_elements(json.loads, return_dtype=json_dtype)

    def decode_with(expr: pl.Expr) -> pl.DataFrame:
        # decode a small json column with `expr` and flatten the struct
        return (
            pl.LazyFrame({"value": ['{"a":1, "b": true, "c": "xx"}', None]})
            .select(extracted=expr)
            .unnest("extracted")
            .collect()
        )

    assert_frame_equal(decode_with(expr_native), decode_with(expr_pyfunc))

    # primitive python casts
    casts = ((str, pl.String), (int, pl.Int64), (float, pl.Float64))
    for caster, target_dtype in casts:
        with pytest.warns(
            PolarsInefficientMapWarning,
            match=rf'(?s)with this one instead.*pl\.col\("a"\)\.cast\(pl\.{target_dtype.__name__}\)',
        ):
            assert_frame_equal(
                lf.select(
                    pl.col("a").map_elements(caster, return_dtype=target_dtype)
                ).collect(),
                lf.select(pl.col("a").cast(target_dtype)).collect(),
            )
453
454
455
def test_parse_apply_miscellaneous() -> None:
    """
    Cover parser edge cases.

    Covered: bound methods, constant-only lambdas, literals used as method
    parameters, and the fallback chain for the series alias in warnings.
    """
    # note: can also identify inefficient functions and methods as well as lambdas
    class Test:
        def x10(self, x: float) -> float:
            return x * 10

        def mcosh(self, x: float) -> float:
            return cosh(x)

    parser = BytecodeParser(Test().x10, map_target="expr")
    suggested_expression = parser.to_expression(col="colx")
    assert suggested_expression == 'pl.col("colx") * 10'

    with pytest.warns(
        PolarsInefficientMapWarning,
        match=r"(?s)Series\.map_elements.*with this one instead.*s\.cosh\(\)",
    ):
        pl.Series("colx", [0.5, 0.25]).map_elements(
            function=Test().mcosh,
            return_dtype=pl.Float64,
        )

    # note: all constants - should not create a warning/suggestion
    suggested_expression = BytecodeParser(
        lambda x: MY_CONSTANT + 42, map_target="expr"
    ).to_expression(col="colx")
    assert suggested_expression is None

    # literals as method parameters
    s = pl.Series("srs", [0, 1, 2, 3, 4])
    with pytest.warns(
        PolarsInefficientMapWarning,
        match=r"(?s)Series\.map_elements.*with this one instead.*\(np\.cos\(3\) \+ s\) - abs\(-1\)",
    ):
        assert_series_equal(
            s.map_elements(lambda x: np.cos(3) + x - abs(-1), return_dtype=pl.Float64),
            np.cos(3) + s - 1,
        )

    # If 's' is already a name used by the function, the series alias in the
    # user warning falls back, in priority order, through "s" -> "srs" ->
    # "series" -> "srs0" until it finds one the function does not use. Here the
    # names are locals of this test, captured as closure variables by the
    # lambdas below (they are not module globals).
    # NOTE: the variable names are significant - renaming them changes the
    # expected suggestions.
    s, srs, series = -1, 0, 1  # type: ignore[assignment]
    expr1 = BytecodeParser(lambda x: x + s, map_target="series")
    expr2 = BytecodeParser(lambda x: srs + x + s, map_target="series")
    expr3 = BytecodeParser(lambda x: srs + x + s - x + series, map_target="series")

    assert expr1.to_expression(col="srs") == "srs + s"
    assert expr2.to_expression(col="srs") == "(srs + series) + s"
    assert expr3.to_expression(col="srs") == "(((srs + srs0) + s) - srs0) + series"
506
507
@pytest.mark.parametrize(
    ("name", "data", "func", "expr_repr"),
    [
        (
            "srs",
            [1, 2, 3],
            lambda x: str(x),
            "s.cast(pl.String)",
        ),
        (
            "s",
            [date(2077, 10, 10), date(1999, 12, 31)],
            lambda d: d.month,
            "s.dt.month()",
        ),
        (
            "",
            [-20, -12, -5, 0, 5, 12, 20],
            lambda x: (abs(x) != 12) and (x > 10 or x < -10 or x == 0),
            "(s.abs() != 12) & ((s > 10) | (s < -10) | (s == 0))",
        ),
    ],
)
@pytest.mark.filterwarnings(
    "ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning"
)
def test_parse_apply_series(
    name: str, data: list[Any], func: Callable[[Any], Any], expr_repr: str
) -> None:
    """Series-target suggestions use 's' as the series placeholder."""
    # NOTE: this local must be called `s`; the suggested expression refers to
    # the series by that name and is evaluated against this scope below
    s = pl.Series(name, data)

    suggestion = BytecodeParser(func, map_target="series").to_expression(s.name)
    assert suggestion == expr_repr

    with pytest.warns(
        PolarsInefficientMapWarning,
        match=r"(?s)Series\.map_elements.*s\.\w+\(",
    ):
        mapped = s.map_elements(func)

    assert_series_equal(mapped, eval(suggestion), check_dtypes=False)
551
552
553
@pytest.mark.may_fail_auto_streaming
def test_expr_exact_warning_message() -> None:
    """Pin the exact warning text, and check the parser is cached per function."""
    if in_terminal_that_supports_colour():
        red, green, end_escape = "\x1b[31m", "\x1b[32m", "\x1b[0m"
    else:
        red = green = end_escape = ""

    expected_text = "".join(
        [
            "\n",
            "Expr.map_elements is significantly slower than the native expressions API.\n",
            "Only use if you absolutely CANNOT implement your logic otherwise.\n",
            "Replace this expression...\n",
            f' {red}- pl.col("a").map_elements(lambda x: ...){end_escape}\n',
            "with this one instead:\n",
            f' {green}+ pl.col("a") + 1{end_escape}\n',
        ]
    )
    msg = re.escape(expected_text)

    fn = lambda x: x + 1  # noqa: E731
    frame = pl.DataFrame({"a": [1, 2, 3]})

    # check the EXACT warning messages - if modifying the message in the future,
    # make sure to keep the `^` and `$`, and the assertion on `len(record)`
    with pytest.warns(  # noqa: PT031
        PolarsInefficientMapWarning,
        match=rf"^{msg}$",
    ) as record:
        # repeat a few times to exercise the caching path
        for _ in range(3):
            frame.select(pl.col("a").map_elements(fn, return_dtype=pl.Int64))

    assert len(record) == 3

    # the parser built for `fn` should now be in the cache
    cached_parser = _BYTECODE_PARSER_CACHE_[(fn, "expr")]
    assert isinstance(cached_parser, BytecodeParser)
    assert cached_parser.to_expression("a") == 'pl.col("a") + 1'
588
589
590
def test_omit_implicit_bool() -> None:
    """Leading truthiness guards (`x and x and ...`) are dropped from the rewrite."""
    suggestion = BytecodeParser(
        function=lambda x: x and x and x.date(),
        map_target="expr",
    ).to_expression("d")
    assert suggestion == 'pl.col("d").dt.date()'
597
598
599
def test_partial_functions_13523() -> None:
    """
    Check that `functools.partial` objects passed to `map_elements` do not warn.

    Regression test; the suffix presumably refers to polars GitHub issue #13523.
    """
    import warnings

    def plus(value: int, amount: int) -> int:
        return value + amount

    data = {"a": [1, 2], "b": [3, 4]}
    df = pl.DataFrame(data)

    # should not warn: escalate PolarsInefficientMapWarning to an error so the
    # test actually fails if one is emitted (other warning categories keep
    # whatever filters were already active)
    with warnings.catch_warnings():
        warnings.simplefilter("error", PolarsInefficientMapWarning)
        result = df["a"].map_elements(partial(plus, amount=1))

    assert result.to_list() == [2, 3]
607
608