# Path: blob/main/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py
# 6940 views
from __future__ import annotations12import datetime as dt3import json4import math5import re6from datetime import date, datetime7from functools import partial8from math import cosh9from typing import Any, Callable, Literal1011import numpy as np12import pytest1314import polars as pl15from polars._utils.udfs import _BYTECODE_PARSER_CACHE_, _NUMPY_FUNCTIONS, BytecodeParser16from polars._utils.various import in_terminal_that_supports_colour17from polars.exceptions import PolarsInefficientMapWarning18from polars.testing import assert_frame_equal, assert_series_equal1920MY_CONSTANT = 321MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}22MY_LIST = [1, 2, 3]2324# column_name, function, expected_suggestion25TEST_CASES = [26# ---------------------------------------------27# numeric expr: math, comparison, logic ops28# ---------------------------------------------29("a", "lambda x: x + 1 - (2 / 3)", '(pl.col("a") + 1) - 0.6666666666666666', None),30("a", "lambda x: x // 1 % 2", '(pl.col("a") // 1) % 2', None),31("a", "lambda x: x & True", 'pl.col("a") & True', None),32("a", "lambda x: x | False", 'pl.col("a") | False', None),33("a", "lambda x: abs(x) != 3", 'pl.col("a").abs() != 3', None),34("a", "lambda x: int(x) > 1", 'pl.col("a").cast(pl.Int64) > 1', None),35(36"a",37"lambda x: not (x > 1) or x == 2",38'~(pl.col("a") > 1) | (pl.col("a") == 2)',39None,40),41("a", "lambda x: x is None", 'pl.col("a") is None', None),42("a", "lambda x: x is not None", 'pl.col("a") is not None', None),43(44"a",45"lambda x: ((x * -x) ** x) * 1.0",46'((pl.col("a") * -pl.col("a")) ** pl.col("a")) * 1.0',47None,48),49(50"a",51"lambda x: 1.0 * (x * (x**x))",52'1.0 * (pl.col("a") * (pl.col("a") ** pl.col("a")))',53None,54),55(56"a",57"lambda x: (x / x) + ((x * x) - x)",58'(pl.col("a") / pl.col("a")) + ((pl.col("a") * pl.col("a")) - pl.col("a"))',59None,60),61(62"a",63"lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1))))",64'(10 - pl.col("a")) / (((pl.col("a") * 4) - pl.col("a")) // (2 + 
(pl.col("a") * (pl.col("a") - 1))))',65None,66),67("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))', None),68("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))', None),69(70"a",71"lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0",72'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)',73None,74),75("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")', None),76(77"a",78"lambda x: (float(x) * int(x)) // 2",79'(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2',80None,81),82(83"a",84"lambda x: 1 / (1 + np.exp(-x))",85'1 / (1 + (-pl.col("a")).exp())',86None,87),88# ---------------------------------------------89# math module90# ---------------------------------------------91("e", "lambda x: math.asin(x)", 'pl.col("e").arcsin()', None),92("e", "lambda x: math.asinh(x)", 'pl.col("e").arcsinh()', None),93("e", "lambda x: math.atan(x)", 'pl.col("e").arctan()', None),94("e", "lambda x: math.atanh(x)", 'pl.col("e").arctanh()', "self"),95("e", "lambda x: math.cos(x)", 'pl.col("e").cos()', None),96("e", "lambda x: math.degrees(x)", 'pl.col("e").degrees()', None),97("e", "lambda x: math.exp(x)", 'pl.col("e").exp()', None),98("e", "lambda x: math.log(x)", 'pl.col("e").log()', None),99("e", "lambda x: math.log10(x)", 'pl.col("e").log10()', None),100("e", "lambda x: math.log1p(x)", 'pl.col("e").log1p()', None),101("e", "lambda x: math.radians(x)", 'pl.col("e").radians()', None),102("e", "lambda x: math.sin(x)", 'pl.col("e").sin()', None),103("e", "lambda x: math.sinh(x)", 'pl.col("e").sinh()', None),104("e", "lambda x: math.sqrt(x)", 'pl.col("e").sqrt()', None),105("e", "lambda x: math.tan(x)", 'pl.col("e").tan()', None),106("e", "lambda x: math.tanh(x)", 'pl.col("e").tanh()', None),107# ---------------------------------------------108# numpy module109# ---------------------------------------------110("e", "lambda x: np.arccos(x)", 'pl.col("e").arccos()', None),111("e", "lambda x: 
np.arccosh(x)", 'pl.col("e").arccosh()', None),112("e", "lambda x: np.arcsin(x)", 'pl.col("e").arcsin()', None),113("e", "lambda x: np.arcsinh(x)", 'pl.col("e").arcsinh()', None),114("e", "lambda x: np.arctan(x)", 'pl.col("e").arctan()', None),115("e", "lambda x: np.arctanh(x)", 'pl.col("e").arctanh()', "self"),116("a", "lambda x: 0 + np.cbrt(x)", '0 + pl.col("a").cbrt()', None),117("e", "lambda x: np.ceil(x)", 'pl.col("e").ceil()', None),118("e", "lambda x: np.cos(x)", 'pl.col("e").cos()', None),119("e", "lambda x: np.cosh(x)", 'pl.col("e").cosh()', None),120("e", "lambda x: np.degrees(x)", 'pl.col("e").degrees()', None),121("e", "lambda x: np.exp(x)", 'pl.col("e").exp()', None),122("e", "lambda x: np.floor(x)", 'pl.col("e").floor()', None),123("e", "lambda x: np.log(x)", 'pl.col("e").log()', None),124("e", "lambda x: np.log10(x)", 'pl.col("e").log10()', None),125("e", "lambda x: np.log1p(x)", 'pl.col("e").log1p()', None),126("e", "lambda x: np.radians(x)", 'pl.col("e").radians()', None),127("a", "lambda x: np.sign(x)", 'pl.col("a").sign()', None),128("a", "lambda x: np.sin(x) + 1", 'pl.col("a").sin() + 1', None),129(130"a", # note: functions operate on consts131"lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3)",132'(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)',133None,134),135("a", "lambda x: np.sinh(x) + 1", 'pl.col("a").sinh() + 1', None),136("a", "lambda x: np.sqrt(x) + 1", 'pl.col("a").sqrt() + 1', None),137("a", "lambda x: np.tan(x) + 1", 'pl.col("a").tan() + 1', None),138("e", "lambda x: np.tanh(x)", 'pl.col("e").tanh()', None),139# ---------------------------------------------140# logical 'and/or' (validate nesting levels)141# ---------------------------------------------142(143"a",144"lambda x: x > 1 or (x == 1 and x == 2)",145'(pl.col("a") > 1) | ((pl.col("a") == 1) & (pl.col("a") == 2))',146None,147),148(149"a",150"lambda x: (x > 1 or x == 1) and x == 2",151'((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 
2)',152None,153),154(155"a",156"lambda x: x > 2 or x != 3 and x not in (0, 1, 4)",157'(pl.col("a") > 2) | ((pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4)))',158None,159),160(161"a",162"lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3",163'((pl.col("a") > 1) & (pl.col("a") != 2)) | (((pl.col("a") % 2) == 0) & (pl.col("a") < 3))',164None,165),166(167"a",168"lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3",169'(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)',170None,171),172# ---------------------------------------------173# string exprs174# ---------------------------------------------175(176"b",177"lambda x: str(x).title()",178'pl.col("b").cast(pl.String).str.to_titlecase()',179None,180),181(182"b",183'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()',184'(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()',185None,186),187(188"b",189"lambda x: x.strip().startswith('#')",190"""pl.col("b").str.strip_chars().str.starts_with('#')""",191None,192),193(194"b",195"""lambda x: x.rstrip().endswith(('!','#','?','"'))""",196"""pl.col("b").str.strip_chars_end().str.contains(r'(!|\\#|\\?|")$')""",197None,198),199(200"b",201"""lambda x: x.lstrip().startswith(('!','#','?',"'"))""",202"""pl.col("b").str.strip_chars_start().str.contains(r"^(!|\\#|\\?|')")""",203None,204),205(206"b",207"lambda x: x.replace(':','')",208"""pl.col("b").str.replace_all(':','',literal=True)""",209None,210),211(212"b",213"lambda x: x.replace(':','',2)",214"""pl.col("b").str.replace(':','',n=2,literal=True)""",215None,216),217(218"b",219"lambda x: x.removeprefix('A').removesuffix('F')",220"""pl.col("b").str.strip_prefix('A').str.strip_suffix('F')""",221None,222),223(224"b",225"lambda x: x.zfill(8)",226"""pl.col("b").str.zfill(8)""",227None,228),229# ---------------------------------------------230# replace231# ---------------------------------------------232("a", "lambda x: 
MY_DICT[x]", 'pl.col("a").replace_strict(MY_DICT)', None),233(234"a",235"lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]",236'(pl.col("a") - 1).replace_strict(MY_DICT) + (1 + pl.col("a")).replace_strict(MY_DICT)',237None,238),239# ---------------------------------------------240# standard library datetime parsing241# ---------------------------------------------242(243"d",244'lambda x: datetime.strptime(x, "%Y-%m-%d")',245'pl.col("d").str.to_datetime(format="%Y-%m-%d")',246pl.Datetime("us"),247),248(249"d",250'lambda x: dt.datetime.strptime(x, "%Y-%m-%d")',251'pl.col("d").str.to_datetime(format="%Y-%m-%d")',252pl.Datetime("us"),253),254# ---------------------------------------------255# temporal attributes/methods256# ---------------------------------------------257(258"f",259"lambda x: x.isoweekday()",260'pl.col("f").dt.weekday()',261None,262),263(264"f",265"lambda x: x.hour + x.minute + x.second",266'(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()',267None,268),269# ---------------------------------------------270# Bitwise shifts271# ---------------------------------------------272(273"a",274"lambda x: (3 << (30-x)) & 3",275'(3 * 2**(30 - pl.col("a"))).cast(pl.Int64) & 3',276None,277),278(279"a",280"lambda x: (x << 32) & 3",281'(pl.col("a") * 2**32).cast(pl.Int64) & 3',282None,283),284(285"a",286"lambda x: ((32-x) >> (3)) & 3",287'((32 - pl.col("a")) / 2**3).cast(pl.Int64) & 3',288None,289),290(291"a",292"lambda x: (32 >> (3-x)) & 3",293'(32 / 2**(3 - pl.col("a"))).cast(pl.Int64) & 3',294None,295),296]297298NOOP_TEST_CASES = [299"lambda x: x",300"lambda x, y: x + y",301"lambda x: x[0] + 1",302"lambda x: MY_LIST[x]",303"lambda x: MY_DICT[1]",304'lambda x: "first" if x == 1 else "not first"',305'lambda x: np.sign(x, casting="unsafe")',306]307308EVAL_ENVIRONMENT = {309"MY_CONSTANT": MY_CONSTANT,310"MY_DICT": MY_DICT,311"MY_LIST": MY_LIST,312"cosh": cosh,313"datetime": datetime,314"dt": dt,315"math": math,316"np": np,317"pl": 
pl,318}319320321@pytest.mark.parametrize(322"func",323NOOP_TEST_CASES,324)325def test_parse_invalid_function(func: str) -> None:326# functions we don't (yet?) offer suggestions for327parser = BytecodeParser(eval(func), map_target="expr")328assert not parser.can_attempt_rewrite() or not parser.to_expression("x")329330331@pytest.mark.parametrize(332("col", "func", "expr_repr", "dtype"),333TEST_CASES,334)335@pytest.mark.filterwarnings(336"ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",337"ignore:invalid value encountered:RuntimeWarning",338"ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning",339)340@pytest.mark.may_fail_auto_streaming # dtype not set341@pytest.mark.may_fail_cloud # reason: eager - return_dtype must be set342def test_parse_apply_functions(343col: str, func: str, expr_repr: str, dtype: Literal["self"] | pl.DataType | None344) -> None:345return_dtype: pl.DataTypeExpr | None = None346if dtype == "self":347return_dtype = pl.self_dtype()348elif dtype is None:349return_dtype = None350else:351return_dtype = dtype.to_dtype_expr() # type: ignore[union-attr]352with pytest.warns(353PolarsInefficientMapWarning,354match=r"(?s)Expr\.map_elements.*with this one instead",355):356parser = BytecodeParser(eval(func), map_target="expr")357suggested_expression = parser.to_expression(col)358assert suggested_expression == expr_repr359360df = pl.DataFrame(361{362"a": [1, 2, 3],363"b": ["AB", "cd", "eF"],364"c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'],365"d": ["2020-01-01", "2020-01-02", "2020-01-03"],366"e": [0.5, 0.4, 0.1],367"f": [368datetime(1969, 12, 31),369datetime(2024, 5, 6),370datetime(2077, 10, 20),371],372}373)374375result_frame = df.select(376x=col,377y=eval(suggested_expression, EVAL_ENVIRONMENT),378)379expected_frame = df.select(380x=pl.col(col),381y=pl.col(col).map_elements(eval(func), return_dtype=return_dtype),382)383assert_frame_equal(384result_frame,385expected_frame,386check_dtypes=(".dt." 
not in suggested_expression),387)388389390@pytest.mark.filterwarnings(391"ignore:.*:polars.exceptions.MapWithoutReturnDtypeWarning",392"ignore:invalid value encountered:RuntimeWarning",393"ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning",394)395@pytest.mark.may_fail_auto_streaming # dtype is not set396def test_parse_apply_raw_functions() -> None:397lf = pl.LazyFrame({"a": [1.1, 2.0, 3.4]})398399# test bare 'numpy' functions400for func_name in _NUMPY_FUNCTIONS:401func = getattr(np, func_name)402403# note: we can't parse/rewrite raw numpy functions...404parser = BytecodeParser(func, map_target="expr")405assert not parser.can_attempt_rewrite()406407# ...but we ARE still able to warn408with pytest.warns(409PolarsInefficientMapWarning,410match=rf"(?s)Expr\.map_elements.*Replace this expression.*np\.{func_name}",411):412df1 = lf.select(413pl.col("a").map_elements(func, return_dtype=pl.self_dtype())414).collect()415df2 = lf.select(getattr(pl.col("a"), func_name)()).collect()416assert_frame_equal(df1, df2)417418# test bare 'json.loads'419result_frames = []420with pytest.warns(421PolarsInefficientMapWarning,422match=r"(?s)Expr\.map_elements.*with this one instead:.*\.str\.json_decode",423):424for expr in (425pl.col("value").str.json_decode(426pl.Struct(427{428"a": pl.Int64,429"b": pl.Boolean,430"c": pl.String,431}432)433),434pl.col("value").map_elements(435json.loads,436return_dtype=pl.Struct(437{438"a": pl.Int64,439"b": pl.Boolean,440"c": pl.String,441}442),443),444):445result_frames.append( # noqa: PERF401446pl.LazyFrame({"value": ['{"a":1, "b": true, "c": "xx"}', None]})447.select(extracted=expr)448.unnest("extracted")449.collect()450)451452assert_frame_equal(*result_frames)453454# test primitive python casts455for py_cast, pl_dtype in ((str, pl.String), (int, pl.Int64), (float, pl.Float64)):456with pytest.warns(457PolarsInefficientMapWarning,458match=rf'(?s)with this one 
instead.*pl\.col\("a"\)\.cast\(pl\.{pl_dtype.__name__}\)',459):460assert_frame_equal(461lf.select(462pl.col("a").map_elements(py_cast, return_dtype=pl_dtype)463).collect(),464lf.select(pl.col("a").cast(pl_dtype)).collect(),465)466467468def test_parse_apply_miscellaneous() -> None:469# note: can also identify inefficient functions and methods as well as lambdas470class Test:471def x10(self, x: float) -> float:472return x * 10473474def mcosh(self, x: float) -> float:475return cosh(x)476477parser = BytecodeParser(Test().x10, map_target="expr")478suggested_expression = parser.to_expression(col="colx")479assert suggested_expression == 'pl.col("colx") * 10'480481with pytest.warns(482PolarsInefficientMapWarning,483match=r"(?s)Series\.map_elements.*with this one instead.*s\.cosh\(\)",484):485pl.Series("colx", [0.5, 0.25]).map_elements(486function=Test().mcosh,487return_dtype=pl.Float64,488)489490# note: all constants - should not create a warning/suggestion491suggested_expression = BytecodeParser(492lambda x: MY_CONSTANT + 42, map_target="expr"493).to_expression(col="colx")494assert suggested_expression is None495496# literals as method parameters497with pytest.warns(498PolarsInefficientMapWarning,499match=r"(?s)Series\.map_elements.*with this one instead.*\(np\.cos\(3\) \+ s\) - abs\(-1\)",500):501s = pl.Series("srs", [0, 1, 2, 3, 4])502assert_series_equal(503s.map_elements(lambda x: np.cos(3) + x - abs(-1), return_dtype=pl.Float64),504np.cos(3) + s - 1,505)506507# if 's' is already the name of a global variable then the series alias508# used in the user warning will fall back (in priority order) through509# various aliases until it finds one that is available.510s, srs, series = -1, 0, 1 # type: ignore[assignment]511expr1 = BytecodeParser(lambda x: x + s, map_target="series")512expr2 = BytecodeParser(lambda x: srs + x + s, map_target="series")513expr3 = BytecodeParser(lambda x: srs + x + s - x + series, map_target="series")514515assert expr1.to_expression(col="srs") == 
"srs + s"516assert expr2.to_expression(col="srs") == "(srs + series) + s"517assert expr3.to_expression(col="srs") == "(((srs + srs0) + s) - srs0) + series"518519520@pytest.mark.parametrize(521("name", "data", "func", "expr_repr"),522[523(524"srs",525[1, 2, 3],526lambda x: str(x),527"s.cast(pl.String)",528),529(530"s",531[date(2077, 10, 10), date(1999, 12, 31)],532lambda d: d.month,533"s.dt.month()",534),535(536"",537[-20, -12, -5, 0, 5, 12, 20],538lambda x: (abs(x) != 12) and (x > 10 or x < -10 or x == 0),539"(s.abs() != 12) & ((s > 10) | (s < -10) | (s == 0))",540),541],542)543@pytest.mark.filterwarnings(544"ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning"545)546def test_parse_apply_series(547name: str, data: list[Any], func: Callable[[Any], Any], expr_repr: str548) -> None:549# expression/series generate same warning, with 's' as the series placeholder550with pytest.warns(551PolarsInefficientMapWarning,552match=r"(?s)Series\.map_elements.*s\.\w+\(",553):554s = pl.Series(name, data)555556parser = BytecodeParser(func, map_target="series")557suggested_expression = parser.to_expression(s.name)558assert suggested_expression == expr_repr559560expected_series = s.map_elements(func)561result_series = eval(suggested_expression)562assert_series_equal(expected_series, result_series, check_dtypes=False)563564565@pytest.mark.may_fail_auto_streaming566def test_expr_exact_warning_message() -> None:567red, green, end_escape = (568("\x1b[31m", "\x1b[32m", "\x1b[0m")569if in_terminal_that_supports_colour()570else ("", "", "")571)572msg = re.escape(573"\n"574"Expr.map_elements is significantly slower than the native expressions API.\n"575"Only use if you absolutely CANNOT implement your logic otherwise.\n"576"Replace this expression...\n"577f' {red}- pl.col("a").map_elements(lambda x: ...){end_escape}\n'578"with this one instead:\n"579f' {green}+ pl.col("a") + 1{end_escape}\n'580)581582fn = lambda x: x + 1 # noqa: E731583584# check the 
EXACT warning messages - if modifying the message in the future,585# make sure to keep the `^` and `$`, and the assertion on `len(warnings)`586with pytest.warns(PolarsInefficientMapWarning, match=rf"^{msg}$") as warnings:587df = pl.DataFrame({"a": [1, 2, 3]})588for _ in range(3): # << loop a few times to exercise the caching path589df.select(pl.col("a").map_elements(fn, return_dtype=pl.Int64))590591assert len(warnings) == 3592593# confirm that the associated parser/etc was cached594bp = _BYTECODE_PARSER_CACHE_[(fn, "expr")]595assert isinstance(bp, BytecodeParser)596assert bp.to_expression("a") == 'pl.col("a") + 1'597598599def test_omit_implicit_bool() -> None:600parser = BytecodeParser(601function=lambda x: x and x and x.date(),602map_target="expr",603)604suggested_expression = parser.to_expression("d")605assert suggested_expression == 'pl.col("d").dt.date()'606607608def test_partial_functions_13523() -> None:609def plus(value: int, amount: int) -> int:610return value + amount611612data = {"a": [1, 2], "b": [3, 4]}613df = pl.DataFrame(data)614# should not warn615_ = df["a"].map_elements(partial(plus, amount=1))616617618