Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/business.rs
6939 views
1
#[cfg(feature = "dtype-date")]
2
use chrono::DateTime;
3
use polars_core::prelude::arity::{binary_elementwise_values, try_binary_elementwise};
4
use polars_core::prelude::*;
5
#[cfg(feature = "dtype-date")]
6
use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY;
7
use polars_utils::binary_search::{find_first_ge_index, find_first_gt_index};
8
#[cfg(feature = "serde")]
9
use serde::{Deserialize, Serialize};
10
11
#[cfg(feature = "timezones")]
12
use crate::prelude::replace_time_zone;
13
14
// How to treat a start date that does not fall on a business day.
// (Plain `//` comments on purpose: under `dsl-schema`, schemars would turn `///` docs into
// schema descriptions and change the generated schema.)
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum Roll {
    // Move to the next business day.
    Forward,
    // Move to the previous business day.
    Backward,
    // Report an error instead of moving.
    Raise,
}
22
23
/// Count the number of business days between `start` and `end`, excluding `end`.
24
///
25
/// # Arguments
26
/// - `start`: Series holding start dates.
27
/// - `end`: Series holding end dates.
28
/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day.
29
/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of
30
/// days since the UNIX epoch.
31
pub fn business_day_count(
32
start: &Series,
33
end: &Series,
34
week_mask: [bool; 7],
35
holidays: &[i32],
36
) -> PolarsResult<Series> {
37
if !week_mask.iter().any(|&x| x) {
38
polars_bail!(ComputeError:"`week_mask` must have at least one business day");
39
}
40
41
// Sort now so we can use `binary_search` in the hot for-loop.
42
let holidays = normalise_holidays(holidays, &week_mask);
43
let start_dates = start.date()?;
44
let end_dates = end.date()?;
45
let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;
46
47
let out = match (start_dates.len(), end_dates.len()) {
48
(_, 1) => {
49
if let Some(end_date) = end_dates.physical().get(0) {
50
start_dates.physical().apply_values(|start_date| {
51
business_day_count_impl(
52
start_date,
53
end_date,
54
&week_mask,
55
n_business_days_in_week_mask,
56
&holidays,
57
)
58
})
59
} else {
60
Int32Chunked::full_null(start_dates.name().clone(), start_dates.len())
61
}
62
},
63
(1, _) => {
64
if let Some(start_date) = start_dates.physical().get(0) {
65
end_dates.physical().apply_values(|end_date| {
66
business_day_count_impl(
67
start_date,
68
end_date,
69
&week_mask,
70
n_business_days_in_week_mask,
71
&holidays,
72
)
73
})
74
} else {
75
Int32Chunked::full_null(start_dates.name().clone(), end_dates.len())
76
}
77
},
78
_ => {
79
polars_ensure!(
80
start_dates.len() == end_dates.len(),
81
length_mismatch = "business_day_count",
82
start_dates.len(),
83
end_dates.len()
84
);
85
binary_elementwise_values(
86
start_dates.physical(),
87
end_dates.physical(),
88
|start_date, end_date| {
89
business_day_count_impl(
90
start_date,
91
end_date,
92
&week_mask,
93
n_business_days_in_week_mask,
94
&holidays,
95
)
96
},
97
)
98
},
99
};
100
let out = out.with_name(start_dates.name().clone());
101
Ok(out.into_series())
102
}
103
104
/// Ported from:
105
/// https://github.com/numpy/numpy/blob/e59c074842e3f73483afa5ddef031e856b9fd313/numpy/_core/src/multiarray/datetime_busday.c#L355-L433
106
fn business_day_count_impl(
107
mut start_date: i32,
108
mut end_date: i32,
109
week_mask: &[bool; 7],
110
n_business_days_in_week_mask: i32,
111
holidays: &[i32], // Caller's responsibility to ensure it's sorted.
112
) -> i32 {
113
let swapped = start_date > end_date;
114
if swapped {
115
(start_date, end_date) = (end_date, start_date);
116
start_date += 1;
117
end_date += 1;
118
}
119
120
let holidays_begin = find_first_ge_index(holidays, start_date);
121
let holidays_end = find_first_ge_index(&holidays[holidays_begin..], end_date) + holidays_begin;
122
let mut start_day_of_week = get_day_of_week(start_date);
123
let diff = end_date - start_date;
124
let whole_weeks = diff / 7;
125
let mut count = -((holidays_end - holidays_begin) as i32);
126
count += whole_weeks * n_business_days_in_week_mask;
127
start_date += whole_weeks * 7;
128
while start_date < end_date {
129
// SAFETY: week_mask is length 7, start_day_of_week is between 0 and 6
130
if unsafe { *week_mask.get_unchecked(start_day_of_week) } {
131
count += 1;
132
}
133
start_date += 1;
134
start_day_of_week = increment_day_of_week(start_day_of_week);
135
}
136
if swapped { -count } else { count }
137
}
138
139
/// Add a given number of business days.
140
///
141
/// # Arguments
142
/// - `start`: Series holding start dates.
143
/// - `n`: Number of business days to add.
144
/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day.
145
/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of
146
/// days since the UNIX epoch.
147
/// - `roll`: what to do when the start date doesn't land on a business day:
148
/// - `Roll::Forward`: roll forward to the next business day.
149
/// - `Roll::Backward`: roll backward to the previous business day.
150
/// - `Roll::Raise`: raise an error.
151
pub fn add_business_days(
152
start: &Series,
153
n: &Series,
154
week_mask: [bool; 7],
155
holidays: &[i32],
156
roll: Roll,
157
) -> PolarsResult<Series> {
158
if !week_mask.iter().any(|&x| x) {
159
polars_bail!(ComputeError:"`week_mask` must have at least one business day");
160
}
161
162
match start.dtype() {
163
DataType::Date => {},
164
#[cfg(feature = "dtype-datetime")]
165
DataType::Datetime(time_unit, None) => {
166
let result_date =
167
add_business_days(&start.cast(&DataType::Date)?, n, week_mask, holidays, roll)?;
168
let start_time = start
169
.cast(&DataType::Time)?
170
.cast(&DataType::Duration(*time_unit))?;
171
return std::ops::Add::add(
172
result_date.cast(&DataType::Datetime(*time_unit, None))?,
173
start_time,
174
);
175
},
176
#[cfg(feature = "timezones")]
177
DataType::Datetime(time_unit, Some(time_zone)) => {
178
let start_naive = replace_time_zone(
179
start.datetime().unwrap(),
180
None,
181
&StringChunked::from_iter(std::iter::once("raise")),
182
NonExistent::Raise,
183
)?;
184
let result_date = add_business_days(
185
&start_naive.cast(&DataType::Date)?,
186
n,
187
week_mask,
188
holidays,
189
roll,
190
)?;
191
let start_time = start_naive
192
.cast(&DataType::Time)?
193
.cast(&DataType::Duration(*time_unit))?;
194
let result_naive = std::ops::Add::add(
195
result_date.cast(&DataType::Datetime(*time_unit, None))?,
196
start_time,
197
)?;
198
let result_tz_aware = replace_time_zone(
199
result_naive.datetime().unwrap(),
200
Some(time_zone),
201
&StringChunked::from_iter(std::iter::once("raise")),
202
NonExistent::Raise,
203
)?;
204
return Ok(result_tz_aware.into_series());
205
},
206
_ => polars_bail!(InvalidOperation: "expected date or datetime, got {}", start.dtype()),
207
}
208
209
// Sort now so we can use `binary_search` in the hot for-loop.
210
let holidays = normalise_holidays(holidays, &week_mask);
211
let start_dates = start.date()?;
212
let n = match &n.dtype() {
213
DataType::Int64 | DataType::UInt64 | DataType::UInt32 => n.cast(&DataType::Int32)?,
214
DataType::Int32 => n.clone(),
215
_ => {
216
polars_bail!(InvalidOperation: "expected Int64, Int32, UInt64, or UInt32, got {}", n.dtype())
217
},
218
};
219
let n = n.i32()?;
220
let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;
221
222
let out: Int32Chunked = match (start_dates.len(), n.len()) {
223
(_, 1) => {
224
if let Some(n) = n.get(0) {
225
start_dates
226
.physical()
227
.try_apply_nonnull_values_generic(|start_date| {
228
let (start_date, day_of_week) =
229
roll_start_date(start_date, roll, &week_mask, &holidays)?;
230
Ok::<i32, PolarsError>(add_business_days_impl(
231
start_date,
232
day_of_week,
233
n,
234
&week_mask,
235
n_business_days_in_week_mask,
236
&holidays,
237
))
238
})?
239
} else {
240
Int32Chunked::full_null(start_dates.name().clone(), start_dates.len())
241
}
242
},
243
(1, _) => {
244
if let Some(start_date) = start_dates.physical().get(0) {
245
let (start_date, day_of_week) =
246
roll_start_date(start_date, roll, &week_mask, &holidays)?;
247
n.apply_values(|n| {
248
add_business_days_impl(
249
start_date,
250
day_of_week,
251
n,
252
&week_mask,
253
n_business_days_in_week_mask,
254
&holidays,
255
)
256
})
257
} else {
258
Int32Chunked::full_null(start_dates.name().clone(), n.len())
259
}
260
},
261
_ => {
262
polars_ensure!(
263
start_dates.len() == n.len(),
264
length_mismatch = "dt.add_business_days",
265
start_dates.len(),
266
n.len()
267
);
268
try_binary_elementwise(start_dates.physical(), n, |opt_start_date, opt_n| {
269
match (opt_start_date, opt_n) {
270
(Some(start_date), Some(n)) => {
271
let (start_date, day_of_week) =
272
roll_start_date(start_date, roll, &week_mask, &holidays)?;
273
Ok::<Option<i32>, PolarsError>(Some(add_business_days_impl(
274
start_date,
275
day_of_week,
276
n,
277
&week_mask,
278
n_business_days_in_week_mask,
279
&holidays,
280
)))
281
},
282
_ => Ok(None),
283
}
284
})?
285
},
286
};
287
Ok(out.into_date().into_series())
288
}
289
290
/// Ported from:
291
/// https://github.com/numpy/numpy/blob/e59c074842e3f73483afa5ddef031e856b9fd313/numpy/_core/src/multiarray/datetime_busday.c#L265-L353
292
fn add_business_days_impl(
293
mut date: i32,
294
mut day_of_week: usize,
295
mut n: i32,
296
week_mask: &[bool; 7],
297
n_business_days_in_week_mask: i32,
298
holidays: &[i32], // Caller's responsibility to ensure it's sorted.
299
) -> i32 {
300
if n > 0 {
301
let holidays_begin = find_first_ge_index(holidays, date);
302
date += (n / n_business_days_in_week_mask) * 7;
303
n %= n_business_days_in_week_mask;
304
let holidays_temp = find_first_gt_index(&holidays[holidays_begin..], date) + holidays_begin;
305
n += (holidays_temp - holidays_begin) as i32;
306
let holidays_begin = holidays_temp;
307
while n > 0 {
308
date += 1;
309
day_of_week = increment_day_of_week(day_of_week);
310
// SAFETY: week_mask is length 7, day_of_week is between 0 and 6
311
if unsafe {
312
(*week_mask.get_unchecked(day_of_week))
313
&& (holidays[holidays_begin..].binary_search(&date).is_err())
314
} {
315
n -= 1;
316
}
317
}
318
date
319
} else {
320
let holidays_end = find_first_gt_index(holidays, date);
321
date += (n / n_business_days_in_week_mask) * 7;
322
n %= n_business_days_in_week_mask;
323
let holidays_temp = find_first_ge_index(&holidays[..holidays_end], date);
324
n -= (holidays_end - holidays_temp) as i32;
325
let holidays_end = holidays_temp;
326
while n < 0 {
327
date -= 1;
328
day_of_week = decrement_day_of_week(day_of_week);
329
// SAFETY: week_mask is length 7, day_of_week is between 0 and 6
330
if unsafe {
331
(*week_mask.get_unchecked(day_of_week))
332
&& (holidays[..holidays_end].binary_search(&date).is_err())
333
} {
334
n += 1;
335
}
336
}
337
date
338
}
339
}
340
341
/// Determine if a day lands on a business day.
342
///
343
/// # Arguments
344
/// - `week_mask`: A boolean array of length 7, where `true` indicates that the day is a business day.
345
/// - `holidays`: timestamps that are holidays. Must be provided as i32, i.e. the number of
346
/// days since the UNIX epoch.
347
pub fn is_business_day(
348
dates: &Series,
349
week_mask: [bool; 7],
350
holidays: &[i32],
351
) -> PolarsResult<Series> {
352
if !week_mask.iter().any(|&x| x) {
353
polars_bail!(ComputeError:"`week_mask` must have at least one business day");
354
}
355
356
match dates.dtype() {
357
DataType::Date => {},
358
#[cfg(feature = "dtype-datetime")]
359
DataType::Datetime(_, None) => {
360
return is_business_day(&dates.cast(&DataType::Date)?, week_mask, holidays);
361
},
362
#[cfg(feature = "timezones")]
363
DataType::Datetime(_, Some(_)) => {
364
let dates_local = replace_time_zone(
365
dates.datetime().unwrap(),
366
None,
367
&StringChunked::from_iter(std::iter::once("raise")),
368
NonExistent::Raise,
369
)?;
370
return is_business_day(&dates_local.cast(&DataType::Date)?, week_mask, holidays);
371
},
372
_ => polars_bail!(InvalidOperation: "expected date or datetime, got {}", dates.dtype()),
373
}
374
375
// Sort now so we can use `binary_search` in the hot for-loop.
376
let holidays = normalise_holidays(holidays, &week_mask);
377
let dates = dates.date()?;
378
let out: BooleanChunked =
379
dates
380
.physical()
381
.apply_nonnull_values_generic(DataType::Boolean, |date| {
382
let day_of_week = get_day_of_week(date);
383
// SAFETY: week_mask is length 7, day_of_week is between 0 and 6
384
unsafe {
385
(*week_mask.get_unchecked(day_of_week))
386
&& holidays.binary_search(&date).is_err()
387
}
388
});
389
Ok(out.into_series())
390
}
391
392
fn roll_start_date(
393
mut date: i32,
394
roll: Roll,
395
week_mask: &[bool; 7],
396
holidays: &[i32], // Caller's responsibility to ensure it's sorted.
397
) -> PolarsResult<(i32, usize)> {
398
let mut day_of_week = get_day_of_week(date);
399
match roll {
400
Roll::Raise => {
401
// SAFETY: week_mask is length 7, day_of_week is between 0 and 6
402
if holidays.binary_search(&date).is_ok()
403
| unsafe { !*week_mask.get_unchecked(day_of_week) }
404
{
405
let date = DateTime::from_timestamp(date as i64 * SECONDS_IN_DAY, 0)
406
.unwrap()
407
.format("%Y-%m-%d");
408
polars_bail!(ComputeError:
409
"date {} is not a business date; use `roll` to roll forwards (or backwards) to the next (or previous) valid date.", date
410
)
411
};
412
},
413
Roll::Forward => {
414
// SAFETY: week_mask is length 7, day_of_week is between 0 and 6
415
while holidays.binary_search(&date).is_ok()
416
| unsafe { !*week_mask.get_unchecked(day_of_week) }
417
{
418
date += 1;
419
day_of_week = increment_day_of_week(day_of_week);
420
}
421
},
422
Roll::Backward => {
423
// SAFETY: week_mask is length 7, day_of_week is between 0 and 6
424
while holidays.binary_search(&date).is_ok()
425
| unsafe { !*week_mask.get_unchecked(day_of_week) }
426
{
427
date -= 1;
428
day_of_week = decrement_day_of_week(day_of_week);
429
}
430
},
431
}
432
Ok((date, day_of_week))
433
}
434
435
/// Sort and deduplicate holidays and remove holidays that are not business days.
436
fn normalise_holidays(holidays: &[i32], week_mask: &[bool; 7]) -> Vec<i32> {
437
let mut holidays: Vec<i32> = holidays.to_vec();
438
holidays.sort_unstable();
439
let mut previous_holiday: Option<i32> = None;
440
holidays.retain(|&x| {
441
// SAFETY: week_mask is length 7, get_day_of_week result is between 0 and 6
442
if (Some(x) == previous_holiday) || !unsafe { *week_mask.get_unchecked(get_day_of_week(x)) }
443
{
444
return false;
445
}
446
previous_holiday = Some(x);
447
true
448
});
449
holidays
450
}
451
452
/// Day of the week for `x` days since the UNIX epoch, from 0 (Monday) to 6 (Sunday).
///
/// 1970-01-01 (day 0) was a Thursday, i.e. 3. `rem_euclid` always yields a value in `0..7`,
/// even for negative inputs. Unlike subtracting an offset before taking the remainder, it
/// cannot overflow for days close to `i32::MIN`.
fn get_day_of_week(x: i32) -> usize {
    ((x.rem_euclid(7) + 3) % 7) as usize
}
458
459
/// Advance a day-of-week index (0 = Monday .. 6 = Sunday) by one, wrapping Sunday to Monday.
fn increment_day_of_week(x: usize) -> usize {
    match x {
        6 => 0,
        other => other + 1,
    }
}
462
463
/// Step a day-of-week index (0 = Monday .. 6 = Sunday) back by one, wrapping Monday to Sunday.
fn decrement_day_of_week(x: usize) -> usize {
    match x {
        0 => 6,
        other => other - 1,
    }
}
466
467