Path: blob/main/py-polars/tests/unit/operations/test_clip.py
6939 views
"""Tests for clipping expressions and series against lower/upper bounds."""

from __future__ import annotations

from datetime import datetime
from decimal import Decimal

import pytest

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal


@pytest.fixture
def clip_exprs() -> list[pl.Expr]:
    """Return the three clip variants used by the frame-level tests.

    Each expression clips column ``a`` with the per-row bounds held in the
    ``min`` and ``max`` columns:

    * ``clip``     -- both bounds applied,
    * ``clip_min`` -- only the lower bound applied,
    * ``clip_max`` -- only the upper bound applied.
    """
    value = pl.col("a")
    lower = pl.col("min")
    upper = pl.col("max")
    both = value.clip(lower, upper).alias("clip")
    only_lower = value.clip(lower_bound=lower).alias("clip_min")
    only_upper = value.clip(upper_bound=upper).alias("clip_max")
    return [both, only_lower, only_upper]


def test_clip_int(clip_exprs: list[pl.Expr]) -> None:
    """Clip integers against column bounds.

    A null bound leaves that side unclipped; a null value stays null.
    """
    source = {
        "a": [1, 2, 3, 4, 5, None],
        "min": [0, -1, 4, None, 4, -10],
        "max": [2, 1, 8, 5, None, 10],
    }
    want = {
        "clip": [1, 1, 4, 4, 5, None],
        "clip_min": [1, 2, 4, 4, 5, None],
        "clip_max": [1, 1, 3, 4, 5, None],
    }
    got = pl.LazyFrame(source).select(clip_exprs)
    assert_frame_equal(got, pl.LazyFrame(want))


def test_clip_float(clip_exprs: list[pl.Expr]) -> None:
    """Clip floats against column bounds, including rows with null bounds."""
    source = {
        "a": [1.0, 2.0, 3.0, 4.0, 5.0, None],
        "min": [0.0, -1.0, 4.0, None, 4.0, None],
        "max": [2.0, 1.0, 8.0, 5.0, None, None],
    }
    want = {
        "clip": [1.0, 1.0, 4.0, 4.0, 5.0, None],
        "clip_min": [1.0, 2.0, 4.0, 4.0, 5.0, None],
        "clip_max": [1.0, 1.0, 3.0, 4.0, 5.0, None],
    }
    got = pl.LazyFrame(source).select(clip_exprs)
    assert_frame_equal(got, pl.LazyFrame(want))


def test_clip_datetime(clip_exprs: list[pl.Expr]) -> None:
    """Clip datetimes against per-row datetime bounds, with nulls on either side."""
    values = [
        datetime(1995, 6, 5, 10, 30),
        datetime(1995, 6, 5),
        datetime(2023, 10, 20, 18, 30, 6),
        None,
        datetime(2023, 9, 24),
        datetime(2000, 1, 10),
    ]
    lower = [
        datetime(1995, 6, 5, 10, 29),
        datetime(1996, 6, 5),
        datetime(2020, 9, 24),
        datetime(2020, 1, 1),
        None,
        datetime(2000, 1, 1),
    ]
    upper = [
        datetime(1995, 7, 21, 10, 30),
        datetime(2000, 1, 1),
        datetime(2023, 9, 20, 18, 30, 6),
        datetime(2000, 1, 1),
        datetime(1993, 3, 13),
        None,
    ]
    got = pl.LazyFrame({"a": values, "min": lower, "max": upper}).select(clip_exprs)

    clipped_both = [
        datetime(1995, 6, 5, 10, 30),
        datetime(1996, 6, 5),
        datetime(2023, 9, 20, 18, 30, 6),
        None,
        datetime(1993, 3, 13),
        datetime(2000, 1, 10),
    ]
    clipped_lower = [
        datetime(1995, 6, 5, 10, 30),
        datetime(1996, 6, 5),
        datetime(2023, 10, 20, 18, 30, 6),
        None,
        datetime(2023, 9, 24),
        datetime(2000, 1, 10),
    ]
    clipped_upper = [
        datetime(1995, 6, 5, 10, 30),
        datetime(1995, 6, 5),
        datetime(2023, 9, 20, 18, 30, 6),
        None,
        datetime(1993, 3, 13),
        datetime(2000, 1, 10),
    ]
    want = pl.LazyFrame(
        {
            "clip": clipped_both,
            "clip_min": clipped_lower,
            "clip_max": clipped_upper,
        }
    )
    assert_frame_equal(got, want)


def test_clip_non_numeric_dtype_fails() -> None:
    """Clipping a string series is rejected with InvalidOperationError."""
    strings = pl.Series(["a", "b", "c"])
    expected_message = "`clip` only supports physical numeric types"
    with pytest.raises(InvalidOperationError, match=expected_message):
        strings.clip(pl.lit("b"), pl.lit("z"))


def test_clip_string_input() -> None:
    """A plain string bound is resolved as a column name, not a literal."""
    frame = pl.DataFrame({"a": [0, 1, 2], "min": [1, None, 1]})
    clipped = frame.select(pl.col("a").clip("min"))
    assert_frame_equal(clipped, pl.DataFrame({"a": [1, 1, 2]}))


def test_clip_bound_invalid_for_original_dtype() -> None:
    """A negative bound cannot be cast to UInt32, so clipping raises."""
    unsigned = pl.Series([1, 2, 3, 4], dtype=pl.UInt32)
    expected_message = "conversion from `i32` to `u32` failed"
    with pytest.raises(InvalidOperationError, match=expected_message):
        unsigned.clip(-1, 5)


def test_clip_decimal() -> None:
    """Decimal series clip against Decimal bounds and keep their dtype."""
    dtype = pl.Decimal(21, 1)
    series = pl.Series("a", ["1.1", "2.2", "3.3"], dtype)
    low, high = Decimal("1.5"), Decimal("2.5")

    # (keyword bounds passed to clip, expected values as decimal strings)
    cases: list[tuple[dict[str, Decimal], list[str]]] = [
        ({"lower_bound": low, "upper_bound": high}, ["1.5", "2.2", "2.5"]),
        ({"lower_bound": low}, ["1.5", "2.2", "3.3"]),
        ({"upper_bound": high}, ["1.1", "2.2", "2.5"]),
    ]
    for bounds, values in cases:
        assert_series_equal(series.clip(**bounds), pl.Series("a", values, dtype))


def test_clip_unequal_lengths_22018() -> None:
    """Series bounds whose length differs from the clipped series raise ShapeError."""
    short = pl.Series([1, 2])
    full = pl.Series([1, 2, 3])

    # Every combination below contains at least one length-2 bound.
    bad_bounds: list[dict[str, pl.Series]] = [
        {"lower_bound": short},
        {"upper_bound": short},
        {"lower_bound": short, "upper_bound": full},
        {"lower_bound": full, "upper_bound": short},
    ]
    for bounds in bad_bounds:
        with pytest.raises(pl.exceptions.ShapeError):
            pl.Series([1, 2, 3]).clip(**bounds)