Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py
6940 views
1
from __future__ import annotations
2
3
from datetime import date, datetime, timedelta
4
from typing import TYPE_CHECKING
5
6
import hypothesis.strategies as st
7
import pytest
8
from hypothesis import given
9
10
import polars as pl
11
from polars._utils.convert import parse_as_duration_string
12
from polars.exceptions import ComputeError
13
from polars.testing import assert_series_equal
14
15
if TYPE_CHECKING:
16
from polars._typing import TimeUnit
17
18
19
@given(
    value=st.datetimes(
        min_value=datetime(1000, 1, 1),
        max_value=datetime(3000, 1, 1),
    ),
    n=st.integers(min_value=1, max_value=100),
)
def test_truncate_monthly(value: datetime, n: int) -> None:
    """Check ``dt.truncate(f"{n}mo")`` against a pure-Python reference.

    The reference counts whole months since 1970-01 (negative before 1970),
    rounds that count down to a multiple of ``n`` and converts it back to the
    first day of the resulting month.
    """
    result = pl.Series([value]).dt.truncate(f"{n}mo").item()

    # Manual calculation. Python's ``%`` and ``divmod`` floor towards negative
    # infinity, so pre-1970 month counts are also rounded *down*.
    total_months = (value.year - 1970) * 12 + value.month - 1
    total_months -= total_months % n
    years_since_1970, month_index = divmod(total_months, 12)
    expected = datetime(years_since_1970 + 1970, month_index + 1, 1)
    assert result == expected
35
36
37
def test_truncate_date() -> None:
    """Truncate Date values for each column/scalar/null combination of inputs.

    Each case is labelled "<dates> vs <every>", where ``n`` means a column,
    ``1`` a scalar and ``missing`` a null literal.
    """
    # Every case uses the same input frame. ``select`` returns a new frame, so
    # the fixture is built once rather than once per case.
    df = pl.DataFrame(
        {"a": [date(2020, 1, 1), None, date(2020, 1, 3)], "b": [None, "1mo", "1mo"]}
    )

    # n vs n: a null on either side yields null.
    result = df.select(pl.col("a").dt.truncate(pl.col("b")))["a"]
    expected = pl.Series("a", [None, None, date(2020, 1, 1)])
    assert_series_equal(result, expected)

    # n vs 1
    result = df.select(pl.col("a").dt.truncate("1mo"))["a"]
    expected = pl.Series("a", [date(2020, 1, 1), None, date(2020, 1, 1)])
    assert_series_equal(result, expected)

    # n vs missing: a null ``every`` nulls out the whole column.
    result = df.select(pl.col("a").dt.truncate(pl.lit(None, dtype=pl.String)))["a"]
    expected = pl.Series("a", [None, None, None], dtype=pl.Date)
    assert_series_equal(result, expected)

    # 1 vs n: the scalar date is broadcast against column ``b``.
    result = df.select(a=pl.date(2020, 1, 1).dt.truncate(pl.col("b")))["a"]
    expected = pl.Series("a", [None, date(2020, 1, 1), date(2020, 1, 1)])
    assert_series_equal(result, expected)

    # missing vs n
    result = df.select(a=pl.lit(None, dtype=pl.Date).dt.truncate(pl.col("b")))["a"]
    expected = pl.Series("a", [None, None, None], dtype=pl.Date)
    assert_series_equal(result, expected)
77
78
79
@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_truncate_datetime_simple(time_unit: TimeUnit) -> None:
    """Truncate a one-element Datetime series to the month and to the day."""
    series = pl.Series([datetime(2020, 1, 2, 6)], dtype=pl.Datetime(time_unit))
    cases = {
        "1mo": datetime(2020, 1, 1),
        "1d": datetime(2020, 1, 2),
    }
    for every, want in cases.items():
        assert series.dt.truncate(every).item() == want
86
87
88
@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
def test_truncate_datetime_w_expression(time_unit: TimeUnit) -> None:
    """Truncate Datetime values by a per-row ``every`` column.

    The whole result series is compared, so the output length and the
    ``Datetime(time_unit)`` dtype are checked as well as the two values.
    """
    df = pl.DataFrame(
        {"a": [datetime(2020, 1, 2, 6), datetime(2020, 1, 3, 7)], "b": ["1mo", "1d"]},
        schema_overrides={"a": pl.Datetime(time_unit)},
    )
    result = df.select(pl.col("a").dt.truncate(pl.col("b")))["a"]
    expected = pl.Series(
        "a",
        [datetime(2020, 1, 1), datetime(2020, 1, 3)],
        dtype=pl.Datetime(time_unit),
    )
    assert_series_equal(result, expected)
97
98
99
def test_pre_epoch_truncate_17581() -> None:
    """Daily truncation of a pre-1970 timestamp drops only its time of day."""
    inputs = [datetime(1980, 1, 1), datetime(1969, 1, 1, 1)]
    expected = pl.Series([datetime(1980, 1, 1), datetime(1969, 1, 1)])
    assert_series_equal(pl.Series(inputs).dt.truncate("1d"), expected)
104
105
106
@given(
    datetimes=st.lists(
        st.datetimes(min_value=datetime(1960, 1, 1), max_value=datetime(1980, 1, 1)),
        min_size=1,
        max_size=3,
    ),
    every=st.timedeltas(
        min_value=timedelta(microseconds=1), max_value=timedelta(days=1)
    ).map(parse_as_duration_string),
)
def test_fast_path_vs_slow_path(datetimes: list[datetime], every: str) -> None:
    """A scalar ``every`` must agree with the same value repeated per row."""
    series = pl.Series(datetimes)
    every_per_row = pl.Series([every for _ in datetimes])
    # The scalar string form might use the fast path; the per-row Series form
    # definitely takes the slow path.
    assert_series_equal(series.dt.truncate(every), series.dt.truncate(every_per_row))
123
124
125
@pytest.mark.parametrize("as_date", [False, True])
def test_truncate_unequal_length_22018(as_date: bool) -> None:
    """An ``every`` Series of length 3 against 2 values raises ShapeError.

    Checked for both Datetime input and, via ``as_date``, Date input.
    """
    base = pl.Series([datetime(2088, 8, 8, 8, 8, 8, 8)] * 2)
    values = base.dt.date() if as_date else base
    mismatched_every = pl.Series(["1y"] * 3)
    with pytest.raises(pl.exceptions.ShapeError):
        values.dt.truncate(mismatched_every)
132
133
134
@pytest.mark.parametrize(
    ("multiplier", "unit", "value", "expected"),
    [
        (2, "h", datetime(1970, 1, 2, 3), datetime(1970, 1, 2, 2)),
        (5, "h", datetime(1983, 3, 1, 4), datetime(1983, 3, 1, 2)),
        (3, "d", datetime(1970, 1, 1), datetime(1970, 1, 1)),
        (7, "d", datetime(2001, 1, 4, 5), datetime(2001, 1, 4)),
        (11, "q", datetime(1, 9, 9, 9), datetime(1, 1, 1)),
        (3, "y", datetime(1970, 1, 1), datetime(1970, 1, 1)),
        (19, "y", datetime(9543, 1, 5, 6), datetime(9532, 1, 1)),
        (5, "mo", datetime(1342, 11, 11, 11), datetime(1342, 7, 1)),
    ],
)
@pytest.mark.parametrize("time_zone", ["Asia/Kathmandu", None])
def test_truncate_origin_22590(
    multiplier: int,
    unit: str,
    value: datetime,
    expected: datetime,
    time_zone: str | None,
) -> None:
    """Truncating by ``{multiplier}{unit}`` gives the same wall-clock result
    whether or not a time zone is attached.

    The value is localised to ``time_zone``, truncated, and then made naive
    again so that it can be compared with the naive ``expected``.
    """
    every = f"{multiplier}{unit}"
    localised = pl.Series([value]).dt.replace_time_zone(time_zone)
    truncated = localised.dt.truncate(every)
    result = truncated.dt.replace_time_zone(None).item()
    assert result == expected, result
163
164
165
def test_truncate_invalid() -> None:
    """Truncating a Date series by the compound duration "1d1h" raises
    ComputeError with a message containing "cannot mix"."""
    dates = pl.Series([date(2020, 1, 1)])
    expected_message = "cannot mix"
    with pytest.raises(ComputeError, match=expected_message):
        dates.dt.truncate("1d1h")
169
170