# Path: blob/main/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py
# 6940 views
from __future__ import annotations12from datetime import date, datetime, timedelta3from typing import TYPE_CHECKING45import hypothesis.strategies as st6import pytest7from hypothesis import given89import polars as pl10from polars._utils.convert import parse_as_duration_string11from polars.exceptions import ComputeError12from polars.testing import assert_series_equal1314if TYPE_CHECKING:15from polars._typing import TimeUnit161718@given(19value=st.datetimes(20min_value=datetime(1000, 1, 1),21max_value=datetime(3000, 1, 1),22),23n=st.integers(min_value=1, max_value=100),24)25def test_truncate_monthly(value: date, n: int) -> None:26result = pl.Series([value]).dt.truncate(f"{n}mo").item()27# manual calculation28total = (value.year - 1970) * 12 + value.month - 129remainder = total % n30total -= remainder31year, month = (total // 12) + 1970, ((total % 12) + 1)32expected = datetime(year, month, 1)33assert result == expected343536def test_truncate_date() -> None:37# n vs n38df = pl.DataFrame(39{"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]}40)41result = df.select(pl.col("a").dt.truncate(pl.col("b")))["a"]42expected = pl.Series("a", [None, None, date(2020, 1, 1)])43assert_series_equal(result, expected)4445# n vs 146df = pl.DataFrame(47{"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]}48)49result = df.select(pl.col("a").dt.truncate("1mo"))["a"]50expected = pl.Series("a", [date(2020, 1, 1), None, date(2020, 1, 1)])51assert_series_equal(result, expected)5253# n vs missing54df = pl.DataFrame(55{"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]}56)57result = df.select(pl.col("a").dt.truncate(pl.lit(None, dtype=pl.String)))["a"]58expected = pl.Series("a", [None, None, None], dtype=pl.Date)59assert_series_equal(result, expected)6061# 1 vs n62df = pl.DataFrame(63{"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]}64)65result = df.select(a=pl.date(2020, 1, 
1).dt.truncate(pl.col("b")))["a"]66expected = pl.Series("a", [None, date(2020, 1, 1), date(2020, 1, 1)])67assert_series_equal(result, expected)6869# missing vs n70df = pl.DataFrame(71{"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]}72)73result = df.select(a=pl.lit(None, dtype=pl.Date).dt.truncate(pl.col("b")))["a"]74expected = pl.Series("a", [None, None, None], dtype=pl.Date)75assert_series_equal(result, expected)767778@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])79def test_truncate_datetime_simple(time_unit: TimeUnit) -> None:80s = pl.Series([datetime(2020, 1, 2, 6)], dtype=pl.Datetime(time_unit))81result = s.dt.truncate("1mo").item()82assert result == datetime(2020, 1, 1)83result = s.dt.truncate("1d").item()84assert result == datetime(2020, 1, 2)858687@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])88def test_truncate_datetime_w_expression(time_unit: TimeUnit) -> None:89df = pl.DataFrame(90{"a": [datetime(2020, 1, 2, 6), datetime(2020, 1, 3, 7)], "b": ["1mo", "1d"]},91schema_overrides={"a": pl.Datetime(time_unit)},92)93result = df.select(pl.col("a").dt.truncate(pl.col("b")))["a"]94assert result[0] == datetime(2020, 1, 1)95assert result[1] == datetime(2020, 1, 3)969798def test_pre_epoch_truncate_17581() -> None:99s = pl.Series([datetime(1980, 1, 1), datetime(1969, 1, 1, 1)])100result = s.dt.truncate("1d")101expected = pl.Series([datetime(1980, 1, 1), datetime(1969, 1, 1)])102assert_series_equal(result, expected)103104105@given(106datetimes=st.lists(107st.datetimes(min_value=datetime(1960, 1, 1), max_value=datetime(1980, 1, 1)),108min_size=1,109max_size=3,110),111every=st.timedeltas(112min_value=timedelta(microseconds=1), max_value=timedelta(days=1)113).map(parse_as_duration_string),114)115def test_fast_path_vs_slow_path(datetimes: list[datetime], every: str) -> None:116s = pl.Series(datetimes)117# Might use fastpath:118result = s.dt.truncate(every)119# Definitely uses slowpath:120expected = 
s.dt.truncate(pl.Series([every] * len(datetimes)))121assert_series_equal(result, expected)122123124@pytest.mark.parametrize("as_date", [False, True])125def test_truncate_unequal_length_22018(as_date: bool) -> None:126s = pl.Series([datetime(2088, 8, 8, 8, 8, 8, 8)] * 2)127if as_date:128s = s.dt.date()129with pytest.raises(pl.exceptions.ShapeError):130s.dt.truncate(pl.Series(["1y"] * 3))131132133@pytest.mark.parametrize(134("multiplier", "unit", "value", "expected"),135[136(2, "h", datetime(1970, 1, 2, 3), datetime(1970, 1, 2, 2)),137(5, "h", datetime(1983, 3, 1, 4), datetime(1983, 3, 1, 2)),138(3, "d", datetime(1970, 1, 1), datetime(1970, 1, 1)),139(7, "d", datetime(2001, 1, 4, 5), datetime(2001, 1, 4)),140(11, "q", datetime(1, 9, 9, 9), datetime(1, 1, 1)),141(3, "y", datetime(1970, 1, 1), datetime(1970, 1, 1)),142(19, "y", datetime(9543, 1, 5, 6), datetime(9532, 1, 1)),143(5, "mo", datetime(1342, 11, 11, 11), datetime(1342, 7, 1)),144],145)146@pytest.mark.parametrize("time_zone", ["Asia/Kathmandu", None])147def test_truncate_origin_22590(148multiplier: int,149unit: str,150value: datetime,151expected: datetime,152time_zone: str | None,153) -> None:154result = (155pl.Series([value])156.dt.replace_time_zone(time_zone)157.dt.truncate(f"{multiplier}{unit}")158.dt.replace_time_zone(None)159.item()160)161assert result == expected, result162163164def test_truncate_invalid() -> None:165s = pl.Series([date(2020, 1, 1)])166with pytest.raises(ComputeError, match="cannot mix"):167s.dt.truncate("1d1h")168169170