# Path: blob/main/py-polars/tests/unit/functions/test_business_day_count.py
# 6939 views
from __future__ import annotations12import datetime as dt3from datetime import date45import hypothesis.strategies as st6import numpy as np7import pytest8from hypothesis import assume, given, reject910import polars as pl11from polars._utils.various import parse_version12from polars.exceptions import ComputeError13from polars.testing import assert_series_equal141516def test_business_day_count() -> None:17# (Expression, expression)18df = pl.DataFrame(19{20"start": [date(2020, 1, 1), date(2020, 1, 2)],21"end": [date(2020, 1, 2), date(2020, 1, 10)],22}23)24result = df.select(25business_day_count=pl.business_day_count("start", "end"),26)["business_day_count"]27expected = pl.Series("business_day_count", [1, 6], pl.Int32)28assert_series_equal(result, expected)2930# (Expression, scalar)31result = df.select(32business_day_count=pl.business_day_count("start", date(2020, 1, 10)),33)["business_day_count"]34expected = pl.Series("business_day_count", [7, 6], pl.Int32)35assert_series_equal(result, expected)3637result = df.select(38business_day_count=pl.business_day_count("start", pl.lit(None, dtype=pl.Date)),39)["business_day_count"]40expected = pl.Series("business_day_count", [None, None], pl.Int32)41assert_series_equal(result, expected)4243# (Scalar, expression)44result = df.select(45business_day_count=pl.business_day_count(date(2020, 1, 1), "end"),46)["business_day_count"]47expected = pl.Series("business_day_count", [1, 7], pl.Int32)48assert_series_equal(result, expected)49# see GH issue #2366350assert df.lazy().select(51pl.business_day_count(date(2020, 1, 1), "end")52).collect_schema() == pl.Schema({"literal": pl.Int32})5354result = df.select(55business_day_count=pl.business_day_count(pl.lit(None, dtype=pl.Date), "end"),56)["business_day_count"]57expected = pl.Series("business_day_count", [None, None], pl.Int32)58assert_series_equal(result, expected)5960# (Scalar, scalar)61result = df.select(62business_day_count=pl.business_day_count(date(2020, 1, 1), date(2020, 1, 
10)),63)["business_day_count"]64expected = pl.Series("business_day_count", [7], pl.Int32)65assert_series_equal(result, expected)666768def test_business_day_count_w_week_mask() -> None:69df = pl.DataFrame(70{71"start": [date(2020, 1, 1), date(2020, 1, 2)],72"end": [date(2020, 1, 2), date(2020, 1, 10)],73}74)75result = df.select(76business_day_count=pl.business_day_count(77"start", "end", week_mask=(True, True, True, True, True, True, False)78),79)["business_day_count"]80expected = pl.Series("business_day_count", [1, 7], pl.Int32)81assert_series_equal(result, expected)8283result = df.select(84business_day_count=pl.business_day_count(85"start", "end", week_mask=(True, True, True, False, False, False, True)86),87)["business_day_count"]88expected = pl.Series("business_day_count", [1, 4], pl.Int32)89assert_series_equal(result, expected)909192def test_business_day_count_w_week_mask_invalid() -> None:93with pytest.raises(ValueError, match=r"expected a sequence of length 7 \(got 2\)"):94pl.business_day_count("start", "end", week_mask=(False, 0)) # type: ignore[arg-type]95df = pl.DataFrame(96{97"start": [date(2020, 1, 1), date(2020, 1, 2)],98"end": [date(2020, 1, 2), date(2020, 1, 10)],99}100)101with pytest.raises(102ComputeError, match="`week_mask` must have at least one business day"103):104df.select(pl.business_day_count("start", "end", week_mask=[False] * 7))105106107def test_business_day_count_schema() -> None:108lf = pl.LazyFrame(109{110"start": [date(2020, 1, 1), date(2020, 1, 2)],111"end": [date(2020, 1, 2), date(2020, 1, 10)],112}113)114result = lf.select(115business_day_count=pl.business_day_count("start", "end"),116)117assert result.collect_schema()["business_day_count"] == pl.Int32118assert result.collect().schema["business_day_count"] == pl.Int32119assert 'col("start").business_day_count([col("end")])' in result.explain()120121122def test_business_day_count_w_holidays() -> None:123df = pl.DataFrame(124{125"start": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 
1, 2)],126"end": [date(2020, 1, 2), date(2020, 1, 10), date(2020, 1, 9)],127}128)129result = df.select(130business_day_count=pl.business_day_count(131"start", "end", holidays=[date(2020, 1, 1), date(2020, 1, 9)]132),133)["business_day_count"]134expected = pl.Series("business_day_count", [0, 5, 5], pl.Int32)135assert_series_equal(result, expected)136137138@given(139start=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)),140end=st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)),141week_mask=st.lists(142st.sampled_from([True, False]),143min_size=7,144max_size=7,145),146holidays=st.lists(147st.dates(min_value=dt.date(1969, 1, 1), max_value=dt.date(1970, 12, 31)),148min_size=0,149max_size=100,150),151)152def test_against_np_busday_count(153start: dt.date, end: dt.date, week_mask: tuple[bool, ...], holidays: list[dt.date]154) -> None:155assume(any(week_mask))156result = (157pl.DataFrame({"start": [start], "end": [end]})158.select(159n=pl.business_day_count(160"start", "end", week_mask=week_mask, holidays=holidays161)162)["n"]163.item()164)165expected = np.busday_count(start, end, weekmask=week_mask, holidays=holidays)166if start > end and parse_version(np.__version__) < (1, 25):167# Bug in old versions of numpy168reject()169assert result == expected170171172def test_unequal_length_22018() -> None:173with pytest.raises(pl.exceptions.ShapeError):174pl.select(175pl.business_day_count(176pl.Series([date(2020, 1, 1)] * 2),177pl.Series([date(2020, 1, 1)] * 3),178)179)180181182