Path: blob/main/crates/polars-ops/src/series/ops/business.rs
6939 views
#[cfg(feature = "dtype-date")]1use chrono::DateTime;2use polars_core::prelude::arity::{binary_elementwise_values, try_binary_elementwise};3use polars_core::prelude::*;4#[cfg(feature = "dtype-date")]5use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY;6use polars_utils::binary_search::{find_first_ge_index, find_first_gt_index};7#[cfg(feature = "serde")]8use serde::{Deserialize, Serialize};910#[cfg(feature = "timezones")]11use crate::prelude::replace_time_zone;1213#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]14#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]15#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]16pub enum Roll {17Forward,18Backward,19Raise,20}2122/// Count the number of business days between `start` and `end`, excluding `end`.23///24/// # Arguments25/// - `start`: Series holding start dates.26/// - `end`: Series holding end dates.27/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day.28/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of29/// days since the UNIX epoch.30pub fn business_day_count(31start: &Series,32end: &Series,33week_mask: [bool; 7],34holidays: &[i32],35) -> PolarsResult<Series> {36if !week_mask.iter().any(|&x| x) {37polars_bail!(ComputeError:"`week_mask` must have at least one business day");38}3940// Sort now so we can use `binary_search` in the hot for-loop.41let holidays = normalise_holidays(holidays, &week_mask);42let start_dates = start.date()?;43let end_dates = end.date()?;44let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;4546let out = match (start_dates.len(), end_dates.len()) {47(_, 1) => {48if let Some(end_date) = end_dates.physical().get(0) {49start_dates.physical().apply_values(|start_date| {50business_day_count_impl(51start_date,52end_date,53&week_mask,54n_business_days_in_week_mask,55&holidays,56)57})58} else {59Int32Chunked::full_null(start_dates.name().clone(), start_dates.len())60}61},62(1, _) => {63if let Some(start_date) = start_dates.physical().get(0) {64end_dates.physical().apply_values(|end_date| {65business_day_count_impl(66start_date,67end_date,68&week_mask,69n_business_days_in_week_mask,70&holidays,71)72})73} else {74Int32Chunked::full_null(start_dates.name().clone(), end_dates.len())75}76},77_ => {78polars_ensure!(79start_dates.len() == end_dates.len(),80length_mismatch = "business_day_count",81start_dates.len(),82end_dates.len()83);84binary_elementwise_values(85start_dates.physical(),86end_dates.physical(),87|start_date, end_date| {88business_day_count_impl(89start_date,90end_date,91&week_mask,92n_business_days_in_week_mask,93&holidays,94)95},96)97},98};99let out = out.with_name(start_dates.name().clone());100Ok(out.into_series())101}102103/// Ported from:104/// https://github.com/numpy/numpy/blob/e59c074842e3f73483afa5ddef031e856b9fd313/numpy/_core/src/multiarray/datetime_busday.c#L355-L433105fn business_day_count_impl(106mut start_date: i32,107mut end_date: i32,108week_mask: &[bool; 7],109n_business_days_in_week_mask: i32,110holidays: &[i32], // Caller's responsibility to ensure it's sorted.111) -> i32 {112let swapped = start_date > end_date;113if swapped {114(start_date, end_date) = (end_date, start_date);115start_date += 1;116end_date += 1;117}118119let holidays_begin = find_first_ge_index(holidays, start_date);120let holidays_end = find_first_ge_index(&holidays[holidays_begin..], end_date) + holidays_begin;121let mut start_day_of_week = get_day_of_week(start_date);122let diff = end_date - start_date;123let whole_weeks = diff / 7;124let mut count = -((holidays_end - holidays_begin) as i32);125count += whole_weeks * n_business_days_in_week_mask;126start_date += whole_weeks * 7;127while start_date < end_date {128// SAFETY: week_mask is length 7, start_day_of_week is between 0 and 6129if unsafe { *week_mask.get_unchecked(start_day_of_week) } {130count += 1;131}132start_date += 1;133start_day_of_week = increment_day_of_week(start_day_of_week);134}135if swapped { -count } else { count }136}137138/// Add a given number of business days.139///140/// # Arguments141/// - `start`: Series holding start dates.142/// - `n`: Number of business days to add.143/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day.144/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of145/// days since the UNIX epoch.146/// - `roll`: what to do when the start date doesn't land on a business day:147/// - `Roll::Forward`: roll forward to the next business day.148/// - `Roll::Backward`: roll backward to the previous business day.149/// - `Roll::Raise`: raise an error.150pub fn add_business_days(151start: &Series,152n: &Series,153week_mask: [bool; 7],154holidays: &[i32],155roll: Roll,156) -> PolarsResult<Series> {157if !week_mask.iter().any(|&x| x) {158polars_bail!(ComputeError:"`week_mask` must have at least one business day");159}160161match start.dtype() {162DataType::Date => {},163#[cfg(feature = "dtype-datetime")]164DataType::Datetime(time_unit, None) => {165let result_date =166add_business_days(&start.cast(&DataType::Date)?, n, week_mask, holidays, roll)?;167let start_time = start168.cast(&DataType::Time)?169.cast(&DataType::Duration(*time_unit))?;170return std::ops::Add::add(171result_date.cast(&DataType::Datetime(*time_unit, None))?,172start_time,173);174},175#[cfg(feature = "timezones")]176DataType::Datetime(time_unit, Some(time_zone)) => {177let start_naive = replace_time_zone(178start.datetime().unwrap(),179None,180&StringChunked::from_iter(std::iter::once("raise")),181NonExistent::Raise,182)?;183let result_date = add_business_days(184&start_naive.cast(&DataType::Date)?,185n,186week_mask,187holidays,188roll,189)?;190let start_time = start_naive191.cast(&DataType::Time)?192.cast(&DataType::Duration(*time_unit))?;193let result_naive = std::ops::Add::add(194result_date.cast(&DataType::Datetime(*time_unit, None))?,195start_time,196)?;197let result_tz_aware = replace_time_zone(198result_naive.datetime().unwrap(),199Some(time_zone),200&StringChunked::from_iter(std::iter::once("raise")),201NonExistent::Raise,202)?;203return Ok(result_tz_aware.into_series());204},205_ => polars_bail!(InvalidOperation: "expected date or datetime, got {}", start.dtype()),206}207208// Sort now so we can use `binary_search` in the hot for-loop.209let holidays = normalise_holidays(holidays, &week_mask);210let start_dates = start.date()?;211let n = match &n.dtype() {212DataType::Int64 | DataType::UInt64 | DataType::UInt32 => n.cast(&DataType::Int32)?,213DataType::Int32 => n.clone(),214_ => {215polars_bail!(InvalidOperation: "expected Int64, Int32, UInt64, or UInt32, got {}", n.dtype())216},217};218let n = n.i32()?;219let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;220221let out: Int32Chunked = match (start_dates.len(), n.len()) {222(_, 1) => {223if let Some(n) = n.get(0) {224start_dates225.physical()226.try_apply_nonnull_values_generic(|start_date| {227let (start_date, day_of_week) =228roll_start_date(start_date, roll, &week_mask, &holidays)?;229Ok::<i32, PolarsError>(add_business_days_impl(230start_date,231day_of_week,232n,233&week_mask,234n_business_days_in_week_mask,235&holidays,236))237})?238} else {239Int32Chunked::full_null(start_dates.name().clone(), start_dates.len())240}241},242(1, _) => {243if let Some(start_date) = start_dates.physical().get(0) {244let (start_date, day_of_week) =245roll_start_date(start_date, roll, &week_mask, &holidays)?;246n.apply_values(|n| {247add_business_days_impl(248start_date,249day_of_week,250n,251&week_mask,252n_business_days_in_week_mask,253&holidays,254)255})256} else {257Int32Chunked::full_null(start_dates.name().clone(), n.len())258}259},260_ => {261polars_ensure!(262start_dates.len() == n.len(),263length_mismatch = "dt.add_business_days",264start_dates.len(),265n.len()266);267try_binary_elementwise(start_dates.physical(), n, |opt_start_date, opt_n| {268match (opt_start_date, opt_n) {269(Some(start_date), Some(n)) => {270let (start_date, day_of_week) =271roll_start_date(start_date, roll, &week_mask, &holidays)?;272Ok::<Option<i32>, PolarsError>(Some(add_business_days_impl(273start_date,274day_of_week,275n,276&week_mask,277n_business_days_in_week_mask,278&holidays,279)))280},281_ => Ok(None),282}283})?284},285};286Ok(out.into_date().into_series())287}288289/// Ported from:290/// https://github.com/numpy/numpy/blob/e59c074842e3f73483afa5ddef031e856b9fd313/numpy/_core/src/multiarray/datetime_busday.c#L265-L353291fn add_business_days_impl(292mut date: i32,293mut day_of_week: usize,294mut n: i32,295week_mask: &[bool; 7],296n_business_days_in_week_mask: i32,297holidays: &[i32], // Caller's responsibility to ensure it's sorted.298) -> i32 {299if n > 0 {300let holidays_begin = find_first_ge_index(holidays, date);301date += (n / n_business_days_in_week_mask) * 7;302n %= n_business_days_in_week_mask;303let holidays_temp = find_first_gt_index(&holidays[holidays_begin..], date) + holidays_begin;304n += (holidays_temp - holidays_begin) as i32;305let holidays_begin = holidays_temp;306while n > 0 {307date += 1;308day_of_week = increment_day_of_week(day_of_week);309// SAFETY: week_mask is length 7, day_of_week is between 0 and 6310if unsafe {311(*week_mask.get_unchecked(day_of_week))312&& (holidays[holidays_begin..].binary_search(&date).is_err())313} {314n -= 1;315}316}317date318} else {319let holidays_end = find_first_gt_index(holidays, date);320date += (n / n_business_days_in_week_mask) * 7;321n %= n_business_days_in_week_mask;322let holidays_temp = find_first_ge_index(&holidays[..holidays_end], date);323n -= (holidays_end - holidays_temp) as i32;324let holidays_end = holidays_temp;325while n < 0 {326date -= 1;327day_of_week = decrement_day_of_week(day_of_week);328// SAFETY: week_mask is length 7, day_of_week is between 0 and 6329if unsafe {330(*week_mask.get_unchecked(day_of_week))331&& (holidays[..holidays_end].binary_search(&date).is_err())332} {333n += 1;334}335}336date337}338}339340/// Determine if a day lands on a business day.341///342/// # Arguments343/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day.344/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of345/// days since the UNIX epoch.346pub fn is_business_day(347dates: &Series,348week_mask: [bool; 7],349holidays: &[i32],350) -> PolarsResult<Series> {351if !week_mask.iter().any(|&x| x) {352polars_bail!(ComputeError:"`week_mask` must have at least one business day");353}354355match dates.dtype() {356DataType::Date => {},357#[cfg(feature = "dtype-datetime")]358DataType::Datetime(_, None) => {359return is_business_day(&dates.cast(&DataType::Date)?, week_mask, holidays);360},361#[cfg(feature = "timezones")]362DataType::Datetime(_, Some(_)) => {363let dates_local = replace_time_zone(364dates.datetime().unwrap(),365None,366&StringChunked::from_iter(std::iter::once("raise")),367NonExistent::Raise,368)?;369return is_business_day(&dates_local.cast(&DataType::Date)?, week_mask, holidays);370},371_ => polars_bail!(InvalidOperation: "expected date or datetime, got {}", dates.dtype()),372}373374// Sort now so we can use `binary_search` in the hot for-loop.375let holidays = normalise_holidays(holidays, &week_mask);376let dates = dates.date()?;377let out: BooleanChunked =378dates379.physical()380.apply_nonnull_values_generic(DataType::Boolean, |date| {381let day_of_week = get_day_of_week(date);382// SAFETY: week_mask is length 7, day_of_week is between 0 and 6383unsafe {384(*week_mask.get_unchecked(day_of_week))385&& holidays.binary_search(&date).is_err()386}387});388Ok(out.into_series())389}390391fn roll_start_date(392mut date: i32,393roll: Roll,394week_mask: &[bool; 7],395holidays: &[i32], // Caller's responsibility to ensure it's sorted.396) -> PolarsResult<(i32, usize)> {397let mut day_of_week = get_day_of_week(date);398match roll {399Roll::Raise => {400// SAFETY: week_mask is length 7, day_of_week is between 0 and 6401if holidays.binary_search(&date).is_ok()402| unsafe { !*week_mask.get_unchecked(day_of_week) }403{404let date = DateTime::from_timestamp(date as i64 * SECONDS_IN_DAY, 0)405.unwrap()406.format("%Y-%m-%d");407polars_bail!(ComputeError:408"date {} is not a business date; use `roll` to roll forwards (or backwards) to the next (or previous) valid date.", date409)410};411},412Roll::Forward => {413// SAFETY: week_mask is length 7, day_of_week is between 0 and 6414while holidays.binary_search(&date).is_ok()415| unsafe { !*week_mask.get_unchecked(day_of_week) }416{417date += 1;418day_of_week = increment_day_of_week(day_of_week);419}420},421Roll::Backward => {422// SAFETY: week_mask is length 7, day_of_week is between 0 and 6423while holidays.binary_search(&date).is_ok()424| unsafe { !*week_mask.get_unchecked(day_of_week) }425{426date -= 1;427day_of_week = decrement_day_of_week(day_of_week);428}429},430}431Ok((date, day_of_week))432}433434/// Sort and deduplicate holidays and remove holidays that are not business days.435fn normalise_holidays(holidays: &[i32], week_mask: &[bool; 7]) -> Vec<i32> {436let mut holidays: Vec<i32> = holidays.to_vec();437holidays.sort_unstable();438let mut previous_holiday: Option<i32> = None;439holidays.retain(|&x| {440// SAFETY: week_mask is length 7, get_day_of_week result is between 0 and 6441if (Some(x) == previous_holiday) || !unsafe { *week_mask.get_unchecked(get_day_of_week(x)) }442{443return false;444}445previous_holiday = Some(x);446true447});448holidays449}450451fn get_day_of_week(x: i32) -> usize {452// the first modulo might return a negative number, so we add 7 and take453// the modulo again so we're sure we have something between 0 (Monday)454// and 6 (Sunday)455(((x - 4) % 7 + 7) % 7) as usize456}457458fn increment_day_of_week(x: usize) -> usize {459if x == 6 { 0 } else { x + 1 }460}461462fn decrement_day_of_week(x: usize) -> usize {463if x == 0 { 6 } else { x - 1 }464}465466467