Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py
6940 views
1
from __future__ import annotations
2
3
import datetime as dt
4
import json
5
import math
6
import re
7
from datetime import date, datetime
8
from functools import partial
9
from math import cosh
10
from typing import Any, Callable, Literal
11
12
import numpy as np
13
import pytest
14
15
import polars as pl
16
from polars._utils.udfs import _BYTECODE_PARSER_CACHE_, _NUMPY_FUNCTIONS, BytecodeParser
17
from polars._utils.various import in_terminal_that_supports_colour
18
from polars.exceptions import PolarsInefficientMapWarning
19
from polars.testing import assert_frame_equal, assert_series_equal
20
21
# Module-level names referenced (by name) from inside the lambdas under test;
# the same objects are exposed to `eval` through EVAL_ENVIRONMENT.
MY_CONSTANT = 3
MY_DICT = dict(enumerate("abcde"))  # {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
MY_LIST = list(range(1, 4))  # [1, 2, 3]
24
25
# Each case is a 4-tuple consumed by `test_parse_apply_functions`:
#   (column_name, function, expected_suggestion, return_dtype)
# where `return_dtype` is None (no explicit dtype passed to map_elements),
# "self" (use `pl.self_dtype()`), or an explicit polars DataType.
TEST_CASES = [
    # ---------------------------------------------
    # numeric expr: math, comparison, logic ops
    # ---------------------------------------------
    ("a", "lambda x: x + 1 - (2 / 3)", '(pl.col("a") + 1) - 0.6666666666666666', None),
    ("a", "lambda x: x // 1 % 2", '(pl.col("a") // 1) % 2', None),
    ("a", "lambda x: x & True", 'pl.col("a") & True', None),
    ("a", "lambda x: x | False", 'pl.col("a") | False', None),
    ("a", "lambda x: abs(x) != 3", 'pl.col("a").abs() != 3', None),
    ("a", "lambda x: int(x) > 1", 'pl.col("a").cast(pl.Int64) > 1', None),
    (
        "a",
        "lambda x: not (x > 1) or x == 2",
        '~(pl.col("a") > 1) | (pl.col("a") == 2)',
        None,
    ),
    ("a", "lambda x: x is None", 'pl.col("a") is None', None),
    ("a", "lambda x: x is not None", 'pl.col("a") is not None', None),
    (
        "a",
        "lambda x: ((x * -x) ** x) * 1.0",
        '((pl.col("a") * -pl.col("a")) ** pl.col("a")) * 1.0',
        None,
    ),
    (
        "a",
        "lambda x: 1.0 * (x * (x**x))",
        '1.0 * (pl.col("a") * (pl.col("a") ** pl.col("a")))',
        None,
    ),
    (
        "a",
        "lambda x: (x / x) + ((x * x) - x)",
        '(pl.col("a") / pl.col("a")) + ((pl.col("a") * pl.col("a")) - pl.col("a"))',
        None,
    ),
    (
        "a",
        "lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1))))",
        '(10 - pl.col("a")) / (((pl.col("a") * 4) - pl.col("a")) // (2 + (pl.col("a") * (pl.col("a") - 1))))',
        None,
    ),
    ("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))', None),
    ("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))', None),
    (
        "a",
        "lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0",
        'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)',
        None,
    ),
    ("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")', None),
    (
        "a",
        "lambda x: (float(x) * int(x)) // 2",
        '(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2',
        None,
    ),
    (
        "a",
        "lambda x: 1 / (1 + np.exp(-x))",
        '1 / (1 + (-pl.col("a")).exp())',
        None,
    ),
    # ---------------------------------------------
    # math module
    # ---------------------------------------------
    ("e", "lambda x: math.asin(x)", 'pl.col("e").arcsin()', None),
    ("e", "lambda x: math.asinh(x)", 'pl.col("e").arcsinh()', None),
    ("e", "lambda x: math.atan(x)", 'pl.col("e").arctan()', None),
    ("e", "lambda x: math.atanh(x)", 'pl.col("e").arctanh()', "self"),
    ("e", "lambda x: math.cos(x)", 'pl.col("e").cos()', None),
    ("e", "lambda x: math.degrees(x)", 'pl.col("e").degrees()', None),
    ("e", "lambda x: math.exp(x)", 'pl.col("e").exp()', None),
    ("e", "lambda x: math.log(x)", 'pl.col("e").log()', None),
    ("e", "lambda x: math.log10(x)", 'pl.col("e").log10()', None),
    ("e", "lambda x: math.log1p(x)", 'pl.col("e").log1p()', None),
    ("e", "lambda x: math.radians(x)", 'pl.col("e").radians()', None),
    ("e", "lambda x: math.sin(x)", 'pl.col("e").sin()', None),
    ("e", "lambda x: math.sinh(x)", 'pl.col("e").sinh()', None),
    ("e", "lambda x: math.sqrt(x)", 'pl.col("e").sqrt()', None),
    ("e", "lambda x: math.tan(x)", 'pl.col("e").tan()', None),
    ("e", "lambda x: math.tanh(x)", 'pl.col("e").tanh()', None),
    # ---------------------------------------------
    # numpy module
    # ---------------------------------------------
    ("e", "lambda x: np.arccos(x)", 'pl.col("e").arccos()', None),
    ("e", "lambda x: np.arccosh(x)", 'pl.col("e").arccosh()', None),
    ("e", "lambda x: np.arcsin(x)", 'pl.col("e").arcsin()', None),
    ("e", "lambda x: np.arcsinh(x)", 'pl.col("e").arcsinh()', None),
    ("e", "lambda x: np.arctan(x)", 'pl.col("e").arctan()', None),
    ("e", "lambda x: np.arctanh(x)", 'pl.col("e").arctanh()', "self"),
    ("a", "lambda x: 0 + np.cbrt(x)", '0 + pl.col("a").cbrt()', None),
    ("e", "lambda x: np.ceil(x)", 'pl.col("e").ceil()', None),
    ("e", "lambda x: np.cos(x)", 'pl.col("e").cos()', None),
    ("e", "lambda x: np.cosh(x)", 'pl.col("e").cosh()', None),
    ("e", "lambda x: np.degrees(x)", 'pl.col("e").degrees()', None),
    ("e", "lambda x: np.exp(x)", 'pl.col("e").exp()', None),
    ("e", "lambda x: np.floor(x)", 'pl.col("e").floor()', None),
    ("e", "lambda x: np.log(x)", 'pl.col("e").log()', None),
    ("e", "lambda x: np.log10(x)", 'pl.col("e").log10()', None),
    ("e", "lambda x: np.log1p(x)", 'pl.col("e").log1p()', None),
    ("e", "lambda x: np.radians(x)", 'pl.col("e").radians()', None),
    ("a", "lambda x: np.sign(x)", 'pl.col("a").sign()', None),
    ("a", "lambda x: np.sin(x) + 1", 'pl.col("a").sin() + 1', None),
    (
        "a",  # note: functions operate on consts
        "lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3)",
        '(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)',
        None,
    ),
    ("a", "lambda x: np.sinh(x) + 1", 'pl.col("a").sinh() + 1', None),
    ("a", "lambda x: np.sqrt(x) + 1", 'pl.col("a").sqrt() + 1', None),
    ("a", "lambda x: np.tan(x) + 1", 'pl.col("a").tan() + 1', None),
    ("e", "lambda x: np.tanh(x)", 'pl.col("e").tanh()', None),
    # ---------------------------------------------
    # logical 'and/or' (validate nesting levels)
    # ---------------------------------------------
    (
        "a",
        "lambda x: x > 1 or (x == 1 and x == 2)",
        '(pl.col("a") > 1) | ((pl.col("a") == 1) & (pl.col("a") == 2))',
        None,
    ),
    (
        "a",
        "lambda x: (x > 1 or x == 1) and x == 2",
        '((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)',
        None,
    ),
    (
        "a",
        "lambda x: x > 2 or x != 3 and x not in (0, 1, 4)",
        '(pl.col("a") > 2) | ((pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4)))',
        None,
    ),
    (
        "a",
        "lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3",
        '((pl.col("a") > 1) & (pl.col("a") != 2)) | (((pl.col("a") % 2) == 0) & (pl.col("a") < 3))',
        None,
    ),
    (
        "a",
        "lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3",
        '(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)',
        None,
    ),
    # ---------------------------------------------
    # string exprs
    # ---------------------------------------------
    (
        "b",
        "lambda x: str(x).title()",
        'pl.col("b").cast(pl.String).str.to_titlecase()',
        None,
    ),
    (
        "b",
        'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()',
        '(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()',
        None,
    ),
    (
        "b",
        "lambda x: x.strip().startswith('#')",
        """pl.col("b").str.strip_chars().str.starts_with('#')""",
        None,
    ),
    (
        "b",
        """lambda x: x.rstrip().endswith(('!','#','?','"'))""",
        """pl.col("b").str.strip_chars_end().str.contains(r'(!|\\#|\\?|")$')""",
        None,
    ),
    (
        "b",
        """lambda x: x.lstrip().startswith(('!','#','?',"'"))""",
        """pl.col("b").str.strip_chars_start().str.contains(r"^(!|\\#|\\?|')")""",
        None,
    ),
    (
        "b",
        "lambda x: x.replace(':','')",
        """pl.col("b").str.replace_all(':','',literal=True)""",
        None,
    ),
    (
        "b",
        "lambda x: x.replace(':','',2)",
        """pl.col("b").str.replace(':','',n=2,literal=True)""",
        None,
    ),
    (
        "b",
        "lambda x: x.removeprefix('A').removesuffix('F')",
        """pl.col("b").str.strip_prefix('A').str.strip_suffix('F')""",
        None,
    ),
    (
        "b",
        "lambda x: x.zfill(8)",
        """pl.col("b").str.zfill(8)""",
        None,
    ),
    # ---------------------------------------------
    # replace
    # ---------------------------------------------
    ("a", "lambda x: MY_DICT[x]", 'pl.col("a").replace_strict(MY_DICT)', None),
    (
        "a",
        "lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]",
        '(pl.col("a") - 1).replace_strict(MY_DICT) + (1 + pl.col("a")).replace_strict(MY_DICT)',
        None,
    ),
    # ---------------------------------------------
    # standard library datetime parsing
    # ---------------------------------------------
    (
        "d",
        'lambda x: datetime.strptime(x, "%Y-%m-%d")',
        'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
        pl.Datetime("us"),
    ),
    (
        "d",
        'lambda x: dt.datetime.strptime(x, "%Y-%m-%d")',
        'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
        pl.Datetime("us"),
    ),
    # ---------------------------------------------
    # temporal attributes/methods
    # ---------------------------------------------
    (
        "f",
        "lambda x: x.isoweekday()",
        'pl.col("f").dt.weekday()',
        None,
    ),
    (
        "f",
        "lambda x: x.hour + x.minute + x.second",
        '(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()',
        None,
    ),
    # ---------------------------------------------
    # Bitwise shifts
    # ---------------------------------------------
    (
        "a",
        "lambda x: (3 << (30-x)) & 3",
        '(3 * 2**(30 - pl.col("a"))).cast(pl.Int64) & 3',
        None,
    ),
    (
        "a",
        "lambda x: (x << 32) & 3",
        '(pl.col("a") * 2**32).cast(pl.Int64) & 3',
        None,
    ),
    (
        "a",
        "lambda x: ((32-x) >> (3)) & 3",
        '((32 - pl.col("a")) / 2**3).cast(pl.Int64) & 3',
        None,
    ),
    (
        "a",
        "lambda x: (32 >> (3-x)) & 3",
        '(32 / 2**(3 - pl.col("a"))).cast(pl.Int64) & 3',
        None,
    ),
]
298
299
# Lambdas (as source strings) for which no rewrite suggestion is expected:
# the parser either declines up-front or yields no expression.
NOOP_TEST_CASES = [
    # returns the element unchanged
    "lambda x: x",
    # takes two arguments
    "lambda x, y: x + y",
    # indexes into the element itself
    "lambda x: x[0] + 1",
    # indexes a global list with the element
    "lambda x: MY_LIST[x]",
    # constant lookup, independent of the element
    "lambda x: MY_DICT[1]",
    # conditional expression
    'lambda x: "first" if x == 1 else "not first"',
    # numpy function called with an extra keyword argument
    'lambda x: np.sign(x, casting="unsafe")',
]
308
309
# Globals handed to `eval` when evaluating a suggested-expression string;
# any name a suggestion may mention has to be resolvable from here.
EVAL_ENVIRONMENT = dict(
    MY_CONSTANT=MY_CONSTANT,
    MY_DICT=MY_DICT,
    MY_LIST=MY_LIST,
    cosh=cosh,
    datetime=datetime,
    dt=dt,
    math=math,
    np=np,
    pl=pl,
)
320
321
322
@pytest.mark.parametrize("func", NOOP_TEST_CASES)
def test_parse_invalid_function(func: str) -> None:
    """Functions we don't (yet?) offer suggestions for must not be rewritten."""
    parser = BytecodeParser(eval(func), map_target="expr")
    # Acceptable outcomes: the parser refuses to attempt a rewrite, or it
    # attempts one but produces no expression (`to_expression` is only
    # consulted when a rewrite is attempted).
    assert not (parser.can_attempt_rewrite() and parser.to_expression("x"))
330
331
332
@pytest.mark.parametrize(
    ("col", "func", "expr_repr", "dtype"),
    TEST_CASES,
)
@pytest.mark.filterwarnings(
    # ".*" already matches every MapWithoutReturnDtypeWarning message, so no
    # narrower filter for that category is needed
    "ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
    "ignore:invalid value encountered:RuntimeWarning",
)
@pytest.mark.may_fail_auto_streaming  # dtype not set
@pytest.mark.may_fail_cloud  # reason: eager - return_dtype must be set
def test_parse_apply_functions(
    col: str, func: str, expr_repr: str, dtype: Literal["self"] | pl.DataType | None
) -> None:
    """Check a lambda's suggested rewrite and that it matches `map_elements`.

    The suggestion produced by `BytecodeParser` must equal `expr_repr`, and
    evaluating it must produce the same frame as running the original lambda
    through `Expr.map_elements` (which must emit the inefficient-map warning).

    Parameters
    ----------
    col
        Name of the column the lambda is applied to.
    func
        Source of the lambda under test.
    expr_repr
        Expected suggested expression, as a string.
    dtype
        Return dtype for `map_elements`: None (not set), "self"
        (`pl.self_dtype()`), or an explicit polars DataType.
    """
    return_dtype: pl.DataTypeExpr | None
    if dtype == "self":
        return_dtype = pl.self_dtype()
    elif dtype is None:
        return_dtype = None
    else:
        return_dtype = dtype.to_dtype_expr()  # type: ignore[union-attr]

    with pytest.warns(
        PolarsInefficientMapWarning,
        match=r"(?s)Expr\.map_elements.*with this one instead",
    ):
        parser = BytecodeParser(eval(func), map_target="expr")
        suggested_expression = parser.to_expression(col)
        assert suggested_expression == expr_repr

        df = pl.DataFrame(
            {
                "a": [1, 2, 3],
                "b": ["AB", "cd", "eF"],
                "c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'],
                "d": ["2020-01-01", "2020-01-02", "2020-01-03"],
                "e": [0.5, 0.4, 0.1],
                "f": [
                    datetime(1969, 12, 31),
                    datetime(2024, 5, 6),
                    datetime(2077, 10, 20),
                ],
            }
        )

        result_frame = df.select(
            x=pl.col(col),
            y=eval(suggested_expression, EVAL_ENVIRONMENT),
        )
        expected_frame = df.select(
            x=pl.col(col),
            # the map_elements call is what emits the warning checked above
            y=pl.col(col).map_elements(eval(func), return_dtype=return_dtype),
        )
        # temporal (`.dt.`) suggestions are compared by value only; presumably
        # their result dtypes differ from what map_elements infers
        assert_frame_equal(
            result_frame,
            expected_frame,
            check_dtypes=(".dt." not in suggested_expression),
        )
389
390
391
@pytest.mark.filterwarnings(
    # ".*" already matches every MapWithoutReturnDtypeWarning message
    "ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",
    "ignore:invalid value encountered:RuntimeWarning",
)
@pytest.mark.may_fail_auto_streaming  # dtype is not set
def test_parse_apply_raw_functions() -> None:
    """Bare (non-lambda) callables given to `map_elements` still warn.

    Covers raw numpy functions (not parseable, but still detected),
    `json.loads`, and the builtin casts `str`/`int`/`float`; in each case the
    result must match the suggested native expression.
    """
    lf = pl.LazyFrame({"a": [1.1, 2.0, 3.4]})

    # test bare 'numpy' functions
    for func_name in _NUMPY_FUNCTIONS:
        func = getattr(np, func_name)

        # note: we can't parse/rewrite raw numpy functions...
        parser = BytecodeParser(func, map_target="expr")
        assert not parser.can_attempt_rewrite()

        # ...but we ARE still able to warn
        with pytest.warns(
            PolarsInefficientMapWarning,
            match=rf"(?s)Expr\.map_elements.*Replace this expression.*np\.{func_name}",
        ):
            df1 = lf.select(
                pl.col("a").map_elements(func, return_dtype=pl.self_dtype())
            ).collect()
            df2 = lf.select(getattr(pl.col("a"), func_name)()).collect()
            assert_frame_equal(df1, df2)

    # test bare 'json.loads' (decoded with one shared struct schema)
    json_dtype = pl.Struct(
        {
            "a": pl.Int64,
            "b": pl.Boolean,
            "c": pl.String,
        }
    )
    json_lf = pl.LazyFrame({"value": ['{"a":1, "b": true, "c": "xx"}', None]})
    with pytest.warns(
        PolarsInefficientMapWarning,
        match=r"(?s)Expr\.map_elements.*with this one instead:.*\.str\.json_decode",
    ):
        # native decode first, then the map_elements equivalent (which warns)
        result_frames = [
            json_lf.select(extracted=expr).unnest("extracted").collect()
            for expr in (
                pl.col("value").str.json_decode(json_dtype),
                pl.col("value").map_elements(json.loads, return_dtype=json_dtype),
            )
        ]
    assert_frame_equal(*result_frames)

    # test primitive python casts
    for py_cast, pl_dtype in ((str, pl.String), (int, pl.Int64), (float, pl.Float64)):
        with pytest.warns(
            PolarsInefficientMapWarning,
            match=rf'(?s)with this one instead.*pl\.col\("a"\)\.cast\(pl\.{pl_dtype.__name__}\)',
        ):
            assert_frame_equal(
                lf.select(
                    pl.col("a").map_elements(py_cast, return_dtype=pl_dtype)
                ).collect(),
                lf.select(pl.col("a").cast(pl_dtype)).collect(),
            )
467
468
469
def test_parse_apply_miscellaneous() -> None:
    """Cover methods, constant-only lambdas, literal args and alias clashes."""

    # note: can also identify inefficient functions and methods as well as lambdas
    class Test:
        def x10(self, x: float) -> float:
            return x * 10

        def mcosh(self, x: float) -> float:
            return cosh(x)

    helper = Test()

    method_parser = BytecodeParser(helper.x10, map_target="expr")
    assert method_parser.to_expression(col="colx") == 'pl.col("colx") * 10'

    with pytest.warns(
        PolarsInefficientMapWarning,
        match=r"(?s)Series\.map_elements.*with this one instead.*s\.cosh\(\)",
    ):
        pl.Series("colx", [0.5, 0.25]).map_elements(
            function=helper.mcosh,
            return_dtype=pl.Float64,
        )

    # note: all constants - should not create a warning/suggestion
    constant_parser = BytecodeParser(lambda x: MY_CONSTANT + 42, map_target="expr")
    assert constant_parser.to_expression(col="colx") is None

    # literals as method parameters
    with pytest.warns(
        PolarsInefficientMapWarning,
        match=r"(?s)Series\.map_elements.*with this one instead.*\(np\.cos\(3\) \+ s\) - abs\(-1\)",
    ):
        s = pl.Series("srs", [0, 1, 2, 3, 4])
        expected = np.cos(3) + s - 1
        assert_series_equal(
            s.map_elements(lambda x: np.cos(3) + x - abs(-1), return_dtype=pl.Float64),
            expected,
        )

    # When 's' is already taken by another variable, the series alias in the
    # suggestion falls back (in priority order) through other aliases until it
    # finds a free one. The local names `s`, `srs` and `series` are
    # deliberately captured by the lambdas below - do not rename them.
    s, srs, series = -1, 0, 1  # type: ignore[assignment]
    alias_cases = [
        (lambda x: x + s, "srs + s"),
        (lambda x: srs + x + s, "(srs + series) + s"),
        (lambda x: srs + x + s - x + series, "(((srs + srs0) + s) - srs0) + series"),
    ]
    for fn, expected_repr in alias_cases:
        alias_parser = BytecodeParser(fn, map_target="series")
        assert alias_parser.to_expression(col="srs") == expected_repr
519
520
521
@pytest.mark.parametrize(
    ("name", "data", "func", "expr_repr"),
    [
        (
            "srs",
            [1, 2, 3],
            lambda x: str(x),
            "s.cast(pl.String)",
        ),
        (
            "s",
            [date(2077, 10, 10), date(1999, 12, 31)],
            lambda d: d.month,
            "s.dt.month()",
        ),
        (
            "",
            [-20, -12, -5, 0, 5, 12, 20],
            lambda x: (abs(x) != 12) and (x > 10 or x < -10 or x == 0),
            "(s.abs() != 12) & ((s > 10) | (s < -10) | (s == 0))",
        ),
    ],
)
@pytest.mark.filterwarnings(
    "ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning"
)
def test_parse_apply_series(
    name: str, data: list[Any], func: Callable[[Any], Any], expr_repr: str
) -> None:
    """Series-level suggestions use `s` as the series placeholder.

    The suggestion is the same as for expressions, regardless of the series'
    own name, and evaluating it must reproduce `Series.map_elements`.
    """
    with pytest.warns(
        PolarsInefficientMapWarning,
        match=r"(?s)Series\.map_elements.*s\.\w+\(",
    ):
        # keep this local named `s`: `eval` below resolves the placeholder
        # against this function's local scope
        s = pl.Series(name, data)

        suggested_expression = BytecodeParser(
            func, map_target="series"
        ).to_expression(s.name)
        assert suggested_expression == expr_repr

        result_series = eval(suggested_expression)
        expected_series = s.map_elements(func)
        assert_series_equal(expected_series, result_series, check_dtypes=False)
564
565
566
@pytest.mark.may_fail_auto_streaming
def test_expr_exact_warning_message() -> None:
    """Pin the complete text of the expression-level inefficient-map warning."""
    if in_terminal_that_supports_colour():
        red, green, end_escape = "\x1b[31m", "\x1b[32m", "\x1b[0m"
    else:
        red = green = end_escape = ""

    msg = re.escape(
        "\n"
        "Expr.map_elements is significantly slower than the native expressions API.\n"
        "Only use if you absolutely CANNOT implement your logic otherwise.\n"
        "Replace this expression...\n"
        f' {red}- pl.col("a").map_elements(lambda x: ...){end_escape}\n'
        "with this one instead:\n"
        f' {green}+ pl.col("a") + 1{end_escape}\n'
    )

    fn = lambda x: x + 1  # noqa: E731
    df = pl.DataFrame({"a": [1, 2, 3]})

    # check the EXACT warning messages - if modifying the message in the future,
    # make sure to keep the `^` and `$`, and the assertion on `len(record)`
    with pytest.warns(PolarsInefficientMapWarning, match=rf"^{msg}$") as record:
        # repeat the same call to exercise the caching path
        for _ in range(3):
            df.select(pl.col("a").map_elements(fn, return_dtype=pl.Int64))

    assert len(record) == 3

    # confirm that the associated parser/etc was cached, keyed by (fn, target)
    cached_parser = _BYTECODE_PARSER_CACHE_[(fn, "expr")]
    assert isinstance(cached_parser, BytecodeParser)
    assert cached_parser.to_expression("a") == 'pl.col("a") + 1'
598
599
600
def test_omit_implicit_bool() -> None:
    """Leading truthiness checks (`x and x and ...`) are left out of the suggestion."""
    expected = 'pl.col("d").dt.date()'
    parser = BytecodeParser(
        function=lambda x: x and x and x.date(),
        map_target="expr",
    )
    assert parser.to_expression("d") == expected
607
608
609
def test_partial_functions_13523() -> None:
    """A `functools.partial` callable must not trigger an inefficient-map warning.

    Regression test for issue 13523.
    """
    import warnings

    def plus(value: int, amount: int) -> int:
        return value + amount

    data = {"a": [1, 2], "b": [3, 4]}
    df = pl.DataFrame(data)

    # should not warn: escalate PolarsInefficientMapWarning (only) to an error
    # so the check is enforced here, independent of the global warning config;
    # other warning categories keep whatever filters are already active
    with warnings.catch_warnings():
        warnings.simplefilter("error", PolarsInefficientMapWarning)
        result = df["a"].map_elements(partial(plus, amount=1))

    assert result.to_list() == [2, 3]
617
618