# Path: blob/main/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py
# (8431 views)
from __future__ import annotations12import datetime as dt3import json4import math5import re6from datetime import date, datetime7from functools import partial8from math import cosh9from typing import TYPE_CHECKING, Any, Literal1011import numpy as np12import pytest1314import polars as pl15from polars._utils.udfs import _BYTECODE_PARSER_CACHE_, _NUMPY_FUNCTIONS, BytecodeParser16from polars._utils.various import in_terminal_that_supports_colour17from polars.exceptions import PolarsInefficientMapWarning18from polars.testing import assert_frame_equal, assert_series_equal1920if TYPE_CHECKING:21from collections.abc import Callable2223MY_CONSTANT = 324MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}25MY_LIST = [1, 2, 3]2627# column_name, function, expected_suggestion28TEST_CASES = [29# ---------------------------------------------30# numeric expr: math, comparison, logic ops31# ---------------------------------------------32("a", "lambda x: x + 1 - (2 / 3)", '(pl.col("a") + 1) - 0.6666666666666666', None),33("a", "lambda x: x // 1 % 2", '(pl.col("a") // 1) % 2', None),34("a", "lambda x: x & True", 'pl.col("a") & True', None),35("a", "lambda x: x | False", 'pl.col("a") | False', None),36("a", "lambda x: abs(x) != 3", 'pl.col("a").abs() != 3', None),37("a", "lambda x: int(x) > 1", 'pl.col("a").cast(pl.Int64) > 1', None),38(39"a",40"lambda x: not (x > 1) or x == 2",41'~(pl.col("a") > 1) | (pl.col("a") == 2)',42None,43),44("a", "lambda x: x is None", 'pl.col("a") is None', None),45("a", "lambda x: x is not None", 'pl.col("a") is not None', None),46(47"a",48"lambda x: ((x * -x) ** x) * 1.0",49'((pl.col("a") * -pl.col("a")) ** pl.col("a")) * 1.0',50None,51),52(53"a",54"lambda x: 1.0 * (x * (x**x))",55'1.0 * (pl.col("a") * (pl.col("a") ** pl.col("a")))',56None,57),58(59"a",60"lambda x: (x / x) + ((x * x) - x)",61'(pl.col("a") / pl.col("a")) + ((pl.col("a") * pl.col("a")) - pl.col("a"))',62None,63),64(65"a",66"lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1))))",67'(10 - 
pl.col("a")) / (((pl.col("a") * 4) - pl.col("a")) // (2 + (pl.col("a") * (pl.col("a") - 1))))',68None,69),70("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))', None),71("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))', None),72(73"a",74"lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0",75'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)',76None,77),78("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")', None),79(80"a",81"lambda x: (float(x) * int(x)) // 2",82'(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2',83None,84),85(86"a",87"lambda x: 1 / (1 + np.exp(-x))",88'1 / (1 + (-pl.col("a")).exp())',89None,90),91# ---------------------------------------------92# math module93# ---------------------------------------------94("e", "lambda x: math.asin(x)", 'pl.col("e").arcsin()', None),95("e", "lambda x: math.asinh(x)", 'pl.col("e").arcsinh()', None),96("e", "lambda x: math.atan(x)", 'pl.col("e").arctan()', None),97("e", "lambda x: math.atanh(x)", 'pl.col("e").arctanh()', "self"),98("e", "lambda x: math.cos(x)", 'pl.col("e").cos()', None),99("e", "lambda x: math.degrees(x)", 'pl.col("e").degrees()', None),100("e", "lambda x: math.exp(x)", 'pl.col("e").exp()', None),101("e", "lambda x: math.log(x)", 'pl.col("e").log()', None),102("e", "lambda x: math.log10(x)", 'pl.col("e").log10()', None),103("e", "lambda x: math.log1p(x)", 'pl.col("e").log1p()', None),104("e", "lambda x: math.radians(x)", 'pl.col("e").radians()', None),105("e", "lambda x: math.sin(x)", 'pl.col("e").sin()', None),106("e", "lambda x: math.sinh(x)", 'pl.col("e").sinh()', None),107("e", "lambda x: math.sqrt(x)", 'pl.col("e").sqrt()', None),108("e", "lambda x: math.tan(x)", 'pl.col("e").tan()', None),109("e", "lambda x: math.tanh(x)", 'pl.col("e").tanh()', None),110# ---------------------------------------------111# numpy module112# ---------------------------------------------113("e", "lambda x: 
np.arccos(x)", 'pl.col("e").arccos()', None),114("e", "lambda x: np.arccosh(x)", 'pl.col("e").arccosh()', None),115("e", "lambda x: np.arcsin(x)", 'pl.col("e").arcsin()', None),116("e", "lambda x: np.arcsinh(x)", 'pl.col("e").arcsinh()', None),117("e", "lambda x: np.arctan(x)", 'pl.col("e").arctan()', None),118("e", "lambda x: np.arctanh(x)", 'pl.col("e").arctanh()', "self"),119("a", "lambda x: 0 + np.cbrt(x)", '0 + pl.col("a").cbrt()', None),120("e", "lambda x: np.ceil(x)", 'pl.col("e").ceil()', None),121("e", "lambda x: np.cos(x)", 'pl.col("e").cos()', None),122("e", "lambda x: np.cosh(x)", 'pl.col("e").cosh()', None),123("e", "lambda x: np.degrees(x)", 'pl.col("e").degrees()', None),124("e", "lambda x: np.exp(x)", 'pl.col("e").exp()', None),125("e", "lambda x: np.floor(x)", 'pl.col("e").floor()', None),126("e", "lambda x: np.log(x)", 'pl.col("e").log()', None),127("e", "lambda x: np.log10(x)", 'pl.col("e").log10()', None),128("e", "lambda x: np.log1p(x)", 'pl.col("e").log1p()', None),129("e", "lambda x: np.radians(x)", 'pl.col("e").radians()', None),130("a", "lambda x: np.sign(x)", 'pl.col("a").sign()', None),131("a", "lambda x: np.sin(x) + 1", 'pl.col("a").sin() + 1', None),132(133"a", # note: functions operate on consts134"lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3)",135'(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)',136None,137),138("a", "lambda x: np.sinh(x) + 1", 'pl.col("a").sinh() + 1', None),139("a", "lambda x: np.sqrt(x) + 1", 'pl.col("a").sqrt() + 1', None),140("a", "lambda x: np.tan(x) + 1", 'pl.col("a").tan() + 1', None),141("e", "lambda x: np.tanh(x)", 'pl.col("e").tanh()', None),142# ---------------------------------------------143# logical 'and/or' (validate nesting levels)144# ---------------------------------------------145(146"a",147"lambda x: x > 1 or (x == 1 and x == 2)",148'(pl.col("a") > 1) | ((pl.col("a") == 1) & (pl.col("a") == 2))',149None,150),151(152"a",153"lambda x: (x > 1 or x == 1) and x == 2",154'((pl.col("a") 
> 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)',155None,156),157(158"a",159"lambda x: x > 2 or x != 3 and x not in (0, 1, 4)",160'(pl.col("a") > 2) | ((pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4)))',161None,162),163(164"a",165"lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3",166'((pl.col("a") > 1) & (pl.col("a") != 2)) | (((pl.col("a") % 2) == 0) & (pl.col("a") < 3))',167None,168),169(170"a",171"lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3",172'(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)',173None,174),175# ---------------------------------------------176# string exprs177# ---------------------------------------------178(179"b",180"lambda x: str(x).title()",181'pl.col("b").cast(pl.String).str.to_titlecase()',182None,183),184(185"b",186'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()',187'(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()',188None,189),190(191"b",192"lambda x: x.strip().startswith('#')",193"""pl.col("b").str.strip_chars().str.starts_with('#')""",194None,195),196(197"b",198"""lambda x: x.rstrip().endswith(('!','#','?','"'))""",199"""pl.col("b").str.strip_chars_end().str.contains(r'(!|\\#|\\?|")$')""",200None,201),202(203"b",204"""lambda x: x.lstrip().startswith(('!','#','?',"'"))""",205"""pl.col("b").str.strip_chars_start().str.contains(r"^(!|\\#|\\?|')")""",206None,207),208(209"b",210"lambda x: x.replace(':','')",211"""pl.col("b").str.replace_all(':','',literal=True)""",212None,213),214(215"b",216"lambda x: x.replace(':','',2)",217"""pl.col("b").str.replace(':','',n=2,literal=True)""",218None,219),220(221"b",222"lambda x: x.removeprefix('A').removesuffix('F')",223"""pl.col("b").str.strip_prefix('A').str.strip_suffix('F')""",224None,225),226(227"b",228"lambda x: x.zfill(8)",229"""pl.col("b").str.zfill(8)""",230None,231),232# ---------------------------------------------233# replace234# 
---------------------------------------------235("a", "lambda x: MY_DICT[x]", 'pl.col("a").replace_strict(MY_DICT)', None),236(237"a",238"lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]",239'(pl.col("a") - 1).replace_strict(MY_DICT) + (1 + pl.col("a")).replace_strict(MY_DICT)',240None,241),242# ---------------------------------------------243# standard library datetime parsing244# ---------------------------------------------245(246"d",247'lambda x: datetime.strptime(x, "%Y-%m-%d")',248'pl.col("d").str.to_datetime(format="%Y-%m-%d")',249pl.Datetime("us"),250),251(252"d",253'lambda x: dt.datetime.strptime(x, "%Y-%m-%d")',254'pl.col("d").str.to_datetime(format="%Y-%m-%d")',255pl.Datetime("us"),256),257# ---------------------------------------------258# temporal attributes/methods259# ---------------------------------------------260(261"f",262"lambda x: x.isoweekday()",263'pl.col("f").dt.weekday()',264None,265),266(267"f",268"lambda x: x.hour + x.minute + x.second",269'(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()',270None,271),272# ---------------------------------------------273# Bitwise shifts274# ---------------------------------------------275(276"a",277"lambda x: (3 << (30-x)) & 3",278'(3 * 2**(30 - pl.col("a"))).cast(pl.Int64) & 3',279None,280),281(282"a",283"lambda x: (x << 32) & 3",284'(pl.col("a") * 2**32).cast(pl.Int64) & 3',285None,286),287(288"a",289"lambda x: ((32-x) >> (3)) & 3",290'((32 - pl.col("a")) / 2**3).cast(pl.Int64) & 3',291None,292),293(294"a",295"lambda x: (32 >> (3-x)) & 3",296'(32 / 2**(3 - pl.col("a"))).cast(pl.Int64) & 3',297None,298),299]300301NOOP_TEST_CASES = [302"lambda x: x",303"lambda x, y: x + y",304"lambda x: x[0] + 1",305"lambda x: MY_LIST[x]",306"lambda x: MY_DICT[1]",307'lambda x: "first" if x == 1 else "not first"',308'lambda x: np.sign(x, casting="unsafe")',309]310311EVAL_ENVIRONMENT = {312"MY_CONSTANT": MY_CONSTANT,313"MY_DICT": MY_DICT,314"MY_LIST": MY_LIST,315"cosh": cosh,316"datetime": 
datetime,317"dt": dt,318"math": math,319"np": np,320"pl": pl,321}322323324@pytest.mark.parametrize(325"func",326NOOP_TEST_CASES,327)328def test_parse_invalid_function(func: str) -> None:329# functions we don't (yet?) offer suggestions for330parser = BytecodeParser(eval(func), map_target="expr")331assert not parser.can_attempt_rewrite() or not parser.to_expression("x")332333334@pytest.mark.parametrize(335("col", "func", "expr_repr", "dtype"),336TEST_CASES,337)338@pytest.mark.filterwarnings(339"ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",340"ignore:invalid value encountered:RuntimeWarning",341"ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning",342)343@pytest.mark.may_fail_auto_streaming # dtype not set344@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set345def test_parse_apply_functions(346col: str, func: str, expr_repr: str, dtype: Literal["self"] | pl.DataType | None347) -> None:348return_dtype: pl.DataTypeExpr | None = None349if dtype == "self":350return_dtype = pl.self_dtype()351elif dtype is None:352return_dtype = None353else:354return_dtype = dtype.to_dtype_expr() # type: ignore[union-attr]355356parser = BytecodeParser(eval(func), map_target="expr")357suggested_expression = parser.to_expression(col)358assert suggested_expression == expr_repr359360df = pl.DataFrame(361{362"a": [1, 2, 3],363"b": ["AB", "cd", "eF"],364"c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'],365"d": ["2020-01-01", "2020-01-02", "2020-01-03"],366"e": [0.5, 0.4, 0.1],367"f": [368datetime(1969, 12, 31),369datetime(2024, 5, 6),370datetime(2077, 10, 20),371],372}373)374375result_frame = df.select(376x=col,377y=eval(suggested_expression, EVAL_ENVIRONMENT),378)379with pytest.warns(380PolarsInefficientMapWarning,381match=r"(?s)Expr\.map_elements.*with this one instead",382):383expected_frame = df.select(384x=pl.col(col),385y=pl.col(col).map_elements(eval(func), 
return_dtype=return_dtype),386)387assert_frame_equal(388result_frame,389expected_frame,390check_dtypes=(".dt." not in suggested_expression),391)392393394@pytest.mark.filterwarnings(395"ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",396"ignore:invalid value encountered:RuntimeWarning",397"ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning",398)399@pytest.mark.may_fail_auto_streaming # dtype is not set400def test_parse_apply_raw_functions() -> None:401lf = pl.LazyFrame({"a": [1.1, 2.0, 3.4]})402403# test bare 'numpy' functions404for func_name in _NUMPY_FUNCTIONS:405func = getattr(np, func_name)406407# note: we can't parse/rewrite raw numpy functions...408parser = BytecodeParser(func, map_target="expr")409assert not parser.can_attempt_rewrite()410411# ...but we ARE still able to warn412with pytest.warns(413PolarsInefficientMapWarning,414match=rf"(?s)Expr\.map_elements.*Replace this expression.*np\.{func_name}",415):416df1 = lf.select(417pl.col("a").map_elements(func, return_dtype=pl.self_dtype())418).collect()419df2 = lf.select(getattr(pl.col("a"), func_name)()).collect()420assert_frame_equal(df1, df2)421422# test bare 'json.loads'423json_dtype = pl.Struct({"a": pl.Int64, "b": pl.Boolean, "c": pl.String})424expr_native = pl.col("value").str.json_decode(json_dtype)425with pytest.warns(426PolarsInefficientMapWarning,427match=r"(?s)Expr\.map_elements.*with this one instead:.*\.str\.json_decode",428):429expr_pyfunc = pl.col("value").map_elements(json.loads, return_dtype=json_dtype)430431result_frames = [432pl.LazyFrame({"value": ['{"a":1, "b": true, "c": "xx"}', None]})433.select(extracted=expr)434.unnest("extracted")435.collect()436for expr in (expr_native, expr_pyfunc)437]438assert_frame_equal(*result_frames)439440# test primitive python casts441for py_cast, pl_dtype in ((str, pl.String), (int, pl.Int64), (float, pl.Float64)):442with pytest.warns(443PolarsInefficientMapWarning,444match=rf'(?s)with this one 
instead.*pl\.col\("a"\)\.cast\(pl\.{pl_dtype.__name__}\)',445):446assert_frame_equal(447lf.select(448pl.col("a").map_elements(py_cast, return_dtype=pl_dtype)449).collect(),450lf.select(pl.col("a").cast(pl_dtype)).collect(),451)452453454def test_parse_apply_miscellaneous() -> None:455# note: can also identify inefficient functions and methods as well as lambdas456class Test:457def x10(self, x: float) -> float:458return x * 10459460def mcosh(self, x: float) -> float:461return cosh(x)462463parser = BytecodeParser(Test().x10, map_target="expr")464suggested_expression = parser.to_expression(col="colx")465assert suggested_expression == 'pl.col("colx") * 10'466467with pytest.warns(468PolarsInefficientMapWarning,469match=r"(?s)Series\.map_elements.*with this one instead.*s\.cosh\(\)",470):471pl.Series("colx", [0.5, 0.25]).map_elements(472function=Test().mcosh,473return_dtype=pl.Float64,474)475476# note: all constants - should not create a warning/suggestion477suggested_expression = BytecodeParser(478lambda x: MY_CONSTANT + 42, map_target="expr"479).to_expression(col="colx")480assert suggested_expression is None481482# literals as method parameters483s = pl.Series("srs", [0, 1, 2, 3, 4])484with pytest.warns(485PolarsInefficientMapWarning,486match=r"(?s)Series\.map_elements.*with this one instead.*\(np\.cos\(3\) \+ s\) - abs\(-1\)",487):488assert_series_equal(489s.map_elements(lambda x: np.cos(3) + x - abs(-1), return_dtype=pl.Float64),490np.cos(3) + s - 1,491)492493# if 's' is already the name of a global variable then the series alias494# used in the user warning will fall back (in priority order) through495# various aliases until it finds one that is available.496s, srs, series = -1, 0, 1 # type: ignore[assignment]497expr1 = BytecodeParser(lambda x: x + s, map_target="series")498expr2 = BytecodeParser(lambda x: srs + x + s, map_target="series")499expr3 = BytecodeParser(lambda x: srs + x + s - x + series, map_target="series")500501assert expr1.to_expression(col="srs") == 
"srs + s"502assert expr2.to_expression(col="srs") == "(srs + series) + s"503assert expr3.to_expression(col="srs") == "(((srs + srs0) + s) - srs0) + series"504505506@pytest.mark.parametrize(507("name", "data", "func", "expr_repr"),508[509(510"srs",511[1, 2, 3],512lambda x: str(x),513"s.cast(pl.String)",514),515(516"s",517[date(2077, 10, 10), date(1999, 12, 31)],518lambda d: d.month,519"s.dt.month()",520),521(522"",523[-20, -12, -5, 0, 5, 12, 20],524lambda x: (abs(x) != 12) and (x > 10 or x < -10 or x == 0),525"(s.abs() != 12) & ((s > 10) | (s < -10) | (s == 0))",526),527],528)529@pytest.mark.filterwarnings(530"ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning"531)532def test_parse_apply_series(533name: str, data: list[Any], func: Callable[[Any], Any], expr_repr: str534) -> None:535# expression/series generate same warning, with 's' as the series placeholder536s = pl.Series(name, data)537538parser = BytecodeParser(func, map_target="series")539suggested_expression = parser.to_expression(s.name)540assert suggested_expression == expr_repr541542with pytest.warns(543PolarsInefficientMapWarning,544match=r"(?s)Series\.map_elements.*s\.\w+\(",545):546expected_series = s.map_elements(func)547548result_series = eval(suggested_expression)549assert_series_equal(expected_series, result_series, check_dtypes=False)550551552@pytest.mark.may_fail_auto_streaming553def test_expr_exact_warning_message() -> None:554red, green, end_escape = (555("\x1b[31m", "\x1b[32m", "\x1b[0m")556if in_terminal_that_supports_colour()557else ("", "", "")558)559msg = re.escape(560"\n"561"Expr.map_elements is significantly slower than the native expressions API.\n"562"Only use if you absolutely CANNOT implement your logic otherwise.\n"563"Replace this expression...\n"564f' {red}- pl.col("a").map_elements(lambda x: ...){end_escape}\n'565"with this one instead:\n"566f' {green}+ pl.col("a") + 1{end_escape}\n'567)568569fn = lambda x: x + 1 # noqa: E731570df = 
pl.DataFrame({"a": [1, 2, 3]})571572# check the EXACT warning messages - if modifying the message in the future,573# make sure to keep the `^` and `$`, and the assertion on `len(warnings)`574with pytest.warns( # noqa: PT031575PolarsInefficientMapWarning,576match=rf"^{msg}$",577) as warnings:578for _ in range(3): # << loop a few times to exercise the caching path579df.select(pl.col("a").map_elements(fn, return_dtype=pl.Int64))580581assert len(warnings) == 3582583# confirm that the associated parser/etc was cached584bp = _BYTECODE_PARSER_CACHE_[(fn, "expr")]585assert isinstance(bp, BytecodeParser)586assert bp.to_expression("a") == 'pl.col("a") + 1'587588589def test_omit_implicit_bool() -> None:590parser = BytecodeParser(591function=lambda x: x and x and x.date(),592map_target="expr",593)594suggested_expression = parser.to_expression("d")595assert suggested_expression == 'pl.col("d").dt.date()'596597598def test_partial_functions_13523() -> None:599def plus(value: int, amount: int) -> int:600return value + amount601602data = {"a": [1, 2], "b": [3, 4]}603df = pl.DataFrame(data)604# should not warn605_ = df["a"].map_elements(partial(plus, amount=1))606607608