Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-testing/src/asserts/utils.rs
6940 views
1
use std::ops::Not;
2
3
use polars_core::datatypes::unpack_dtypes;
4
use polars_core::prelude::*;
5
use polars_ops::series::is_close;
6
7
/// Configuration options for comparing Series equality.
///
/// Controls the behavior of Series equality comparisons by specifying
/// which aspects to check and the tolerance for floating point comparisons.
///
/// The type derives `Debug`, `Clone` and `PartialEq` so an option set can be
/// printed in diagnostics, reused for several assertions and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct SeriesEqualOptions {
    /// Whether to check that the data types match.
    pub check_dtypes: bool,
    /// Whether to check that the Series names match.
    pub check_names: bool,
    /// Whether to check that elements appear in the same order.
    pub check_order: bool,
    /// Whether to check for exact equality (true) or approximate equality (false) for floating point values.
    pub check_exact: bool,
    /// Relative tolerance for approximate equality of floating point values.
    pub rel_tol: f64,
    /// Absolute tolerance for approximate equality of floating point values.
    pub abs_tol: f64,
    /// Whether to compare categorical values as strings.
    pub categorical_as_str: bool,
}
27
28
impl Default for SeriesEqualOptions {
29
/// Creates a new `SeriesEqualOptions` with default settings.
30
///
31
/// Default configuration:
32
/// - Checks data types, names, and order
33
/// - Uses exact equality comparisons
34
/// - Sets relative tolerance to 1e-5 and absolute tolerance to 1e-8 for floating point comparisons
35
/// - Does not convert categorical values to strings for comparison
36
fn default() -> Self {
37
Self {
38
check_dtypes: true,
39
check_names: true,
40
check_order: true,
41
check_exact: true,
42
rel_tol: 1e-5,
43
abs_tol: 1e-8,
44
categorical_as_str: false,
45
}
46
}
47
}
48
49
impl SeriesEqualOptions {
50
/// Creates a new `SeriesEqualOptions` with default settings.
51
pub fn new() -> Self {
52
Self::default()
53
}
54
55
/// Sets whether to check that data types match.
56
pub fn with_check_dtypes(mut self, value: bool) -> Self {
57
self.check_dtypes = value;
58
self
59
}
60
61
/// Sets whether to check that Series names match.
62
pub fn with_check_names(mut self, value: bool) -> Self {
63
self.check_names = value;
64
self
65
}
66
67
/// Sets whether to check that elements appear in the same order.
68
pub fn with_check_order(mut self, value: bool) -> Self {
69
self.check_order = value;
70
self
71
}
72
73
/// Sets whether to check for exact equality (true) or approximate equality (false) for floating point values.
74
pub fn with_check_exact(mut self, value: bool) -> Self {
75
self.check_exact = value;
76
self
77
}
78
79
/// Sets the relative tolerance for approximate equality of floating point values.
80
pub fn with_rel_tol(mut self, value: f64) -> Self {
81
self.rel_tol = value;
82
self
83
}
84
85
/// Sets the absolute tolerance for approximate equality of floating point values.
86
pub fn with_abs_tol(mut self, value: f64) -> Self {
87
self.abs_tol = value;
88
self
89
}
90
91
/// Sets whether to compare categorical values as strings.
92
pub fn with_categorical_as_str(mut self, value: bool) -> Self {
93
self.categorical_as_str = value;
94
self
95
}
96
}
97
98
/// Change a (possibly nested) Categorical data type to a String data type.
99
fn categorical_dtype_to_string_dtype(dtype: &DataType) -> DataType {
100
match dtype {
101
DataType::Categorical(..) => DataType::String,
102
DataType::List(inner) => {
103
let inner_cast = categorical_dtype_to_string_dtype(inner);
104
DataType::List(Box::new(inner_cast))
105
},
106
DataType::Array(inner, size) => {
107
let inner_cast = categorical_dtype_to_string_dtype(inner);
108
DataType::Array(Box::new(inner_cast), *size)
109
},
110
DataType::Struct(fields) => {
111
let transformed_fields = fields
112
.iter()
113
.map(|field| {
114
Field::new(
115
field.name().clone(),
116
categorical_dtype_to_string_dtype(field.dtype()),
117
)
118
})
119
.collect::<Vec<Field>>();
120
121
DataType::Struct(transformed_fields)
122
},
123
_ => dtype.clone(),
124
}
125
}
126
127
/// Cast a (possibly nested) Categorical Series to a String Series.
128
fn categorical_series_to_string(s: &Series) -> PolarsResult<Series> {
129
let dtype = s.dtype();
130
let noncat_dtype = categorical_dtype_to_string_dtype(dtype);
131
132
if *dtype != noncat_dtype {
133
Ok(s.cast(&noncat_dtype)?)
134
} else {
135
Ok(s.clone())
136
}
137
}
138
139
/// Returns true if both DataTypes are floating point types.
140
fn are_both_floats(left: &DataType, right: &DataType) -> bool {
141
left.is_float() && right.is_float()
142
}
143
144
/// Returns true if both DataTypes are list-like (either List or Array types).
145
fn are_both_lists(left: &DataType, right: &DataType) -> bool {
146
matches!(left, DataType::List(_) | DataType::Array(_, _))
147
&& matches!(right, DataType::List(_) | DataType::Array(_, _))
148
}
149
150
/// Returns true if both DataTypes are struct types.
151
fn are_both_structs(left: &DataType, right: &DataType) -> bool {
152
left.is_struct() && right.is_struct()
153
}
154
155
/// Returns true if both DataTypes are nested types (lists or structs) that contain floating point types within them.
156
/// First checks if both types are either lists or structs, then unpacks their nested DataTypes to determine if
157
/// at least one floating point type exists in each of the nested structures.
158
fn comparing_nested_floats(left: &DataType, right: &DataType) -> bool {
159
if !are_both_lists(left, right) && !are_both_structs(left, right) {
160
return false;
161
}
162
163
let left_dtypes = unpack_dtypes(left, false);
164
let right_dtypes = unpack_dtypes(right, false);
165
166
let left_has_floats = left_dtypes.iter().any(|dt| dt.is_float());
167
let right_has_floats = right_dtypes.iter().any(|dt| dt.is_float());
168
169
left_has_floats && right_has_floats
170
}
171
172
/// Ensures that null values in two Series match exactly and returns an error if any mismatches are found.
173
fn assert_series_null_values_match(left: &Series, right: &Series) -> PolarsResult<()> {
174
let null_value_mismatch = left.is_null().not_equal(&right.is_null());
175
176
if null_value_mismatch.any() {
177
return Err(polars_err!(
178
assertion_error = "Series",
179
"null value mismatch",
180
left.null_count(),
181
right.null_count()
182
));
183
}
184
185
Ok(())
186
}
187
188
/// Validates that NaN patterns are identical between two float Series, returning error if any mismatches are found.
189
fn assert_series_nan_values_match(left: &Series, right: &Series) -> PolarsResult<()> {
190
if !are_both_floats(left.dtype(), right.dtype()) {
191
return Ok(());
192
}
193
let left_nan = left.is_nan()?;
194
let right_nan = right.is_nan()?;
195
196
let nan_value_mismatch = left_nan.not_equal(&right_nan);
197
198
let left_nan_count = left_nan.sum().unwrap_or(0);
199
let right_nan_count = right_nan.sum().unwrap_or(0);
200
201
if nan_value_mismatch.any() {
202
return Err(polars_err!(
203
assertion_error = "Series",
204
"nan value mismatch",
205
left_nan_count,
206
right_nan_count
207
));
208
}
209
210
Ok(())
211
}
212
213
/// Verifies that two Series have values within a specified tolerance.
214
///
215
/// This function checks if the values in `left` and `right` Series that are marked as unequal
216
/// in the `unequal` boolean array are within the specified relative and absolute tolerances.
217
///
218
/// # Arguments
219
///
220
/// * `left` - The first Series to compare
221
/// * `right` - The second Series to compare
222
/// * `unequal` - Boolean ChunkedArray indicating which elements to check (true = check this element)
223
/// * `rel_tol` - Relative tolerance (relative to the maximum absolute value of the two Series)
224
/// * `abs_tol` - Absolute tolerance added to the relative tolerance
225
///
226
/// # Returns
227
///
228
/// * `Ok(())` if all values are within tolerance
229
/// * `Err` with details about problematic values if any values exceed the tolerance
230
///
231
/// # Formula
232
///
233
/// Values are considered within tolerance if:
234
/// `|left - right| <= max(rel_tol * max(abs(left), abs(right)), abs_tol)` OR values are exactly equal
235
///
236
fn assert_series_values_within_tolerance(
237
left: &Series,
238
right: &Series,
239
unequal: &ChunkedArray<BooleanType>,
240
rel_tol: f64,
241
abs_tol: f64,
242
) -> PolarsResult<()> {
243
let left_unequal = left.filter(unequal)?;
244
let right_unequal = right.filter(unequal)?;
245
246
let within_tolerance = is_close(&left_unequal, &right_unequal, abs_tol, rel_tol, false)?;
247
if within_tolerance.all() {
248
Ok(())
249
} else {
250
let exceeded_indices = within_tolerance.not();
251
let problematic_left = left_unequal.filter(&exceeded_indices)?;
252
let problematic_right = right_unequal.filter(&exceeded_indices)?;
253
254
Err(polars_err!(
255
assertion_error = "Series",
256
"values not within tolerance",
257
problematic_left,
258
problematic_right
259
))
260
}
261
}
262
263
/// Compares two Series for equality with configurable options for ordering, exact matching, and tolerance.
264
///
265
/// This function verifies that the values in `left` and `right` Series are equal according to
266
/// the specified comparison criteria. It handles different types including floats and nested types
267
/// with appropriate equality checks.
268
///
269
/// # Arguments
270
///
271
/// * `left` - The first Series to compare
272
/// * `right` - The second Series to compare
273
/// * `check_order` - If true, elements must be in the same order; if false, Series will be sorted before comparison
274
/// * `check_exact` - If true, requires exact equality; if false, allows approximate equality for floats within tolerance
275
/// * `rel_tol` - Relative tolerance for float comparison (used when `check_exact` is false)
276
/// * `abs_tol` - Absolute tolerance for float comparison (used when `check_exact` is false)
277
/// * `categorical_as_str` - If true, converts categorical Series to strings before comparison
278
///
279
/// # Returns
280
///
281
/// * `Ok(())` if Series match according to specified criteria
282
/// * `Err` with details about mismatches if Series differ
283
///
284
/// # Behavior
285
///
286
/// 1. Handles categorical Series based on `categorical_as_str` flag
287
/// 2. Sorts Series if `check_order` is false
288
/// 3. For nested float types, delegates to `assert_series_nested_values_equal`
289
/// 4. For non-float types or when `check_exact` is true, requires exact match
290
/// 5. For float types with approximate matching:
291
/// - Verifies null values match using `assert_series_null_values_match`
292
/// - Verifies NaN values match using `assert_series_nan_values_match`
293
/// - Verifies float values are within tolerance using `assert_series_values_within_tolerance`
294
///
295
#[allow(clippy::too_many_arguments)]
296
fn assert_series_values_equal(
297
left: &Series,
298
right: &Series,
299
check_order: bool,
300
check_exact: bool,
301
check_dtypes: bool,
302
rel_tol: f64,
303
abs_tol: f64,
304
categorical_as_str: bool,
305
) -> PolarsResult<()> {
306
let (left, right) = if categorical_as_str {
307
(
308
categorical_series_to_string(left)?,
309
categorical_series_to_string(right)?,
310
)
311
} else {
312
(left.clone(), right.clone())
313
};
314
315
let (left, right) = if !check_order {
316
(
317
left.sort(SortOptions::default())?,
318
right.sort(SortOptions::default())?,
319
)
320
} else {
321
(left, right)
322
};
323
324
// When `check_dtypes` is `false` and both series are entirely null,
325
// consider them equal regardless of their underlying data types
326
if !check_dtypes && left.dtype() != right.dtype() {
327
if left.null_count() == left.len() && right.null_count() == right.len() {
328
return Ok(());
329
}
330
}
331
332
let unequal = match left.not_equal_missing(&right) {
333
Ok(result) => result,
334
Err(_) => {
335
return Err(polars_err!(
336
assertion_error = "Series",
337
"incompatible data types",
338
left.dtype(),
339
right.dtype()
340
));
341
},
342
};
343
344
if comparing_nested_floats(left.dtype(), right.dtype()) {
345
let filtered_left = left.filter(&unequal)?;
346
let filtered_right = right.filter(&unequal)?;
347
348
match assert_series_nested_values_equal(
349
&filtered_left,
350
&filtered_right,
351
check_exact,
352
check_dtypes,
353
rel_tol,
354
abs_tol,
355
categorical_as_str,
356
) {
357
Ok(_) => return Ok(()),
358
Err(_) => {
359
return Err(polars_err!(
360
assertion_error = "Series",
361
"nested value mismatch",
362
left,
363
right
364
));
365
},
366
}
367
}
368
369
if !unequal.any() {
370
return Ok(());
371
}
372
373
if check_exact || !left.dtype().is_float() || !right.dtype().is_float() {
374
return Err(polars_err!(
375
assertion_error = "Series",
376
"exact value mismatch",
377
left,
378
right
379
));
380
}
381
382
assert_series_null_values_match(&left, &right)?;
383
assert_series_nan_values_match(&left, &right)?;
384
assert_series_values_within_tolerance(&left, &right, &unequal, rel_tol, abs_tol)?;
385
386
Ok(())
387
}
388
389
/// Recursively compares nested Series structures (lists or structs) for equality.
390
///
391
/// This function handles the comparison of complex nested data structures by recursively
392
/// applying appropriate equality checks based on the nested data type.
393
///
394
/// # Arguments
395
///
396
/// * `left` - The first nested Series to compare
397
/// * `right` - The second nested Series to compare
398
/// * `check_exact` - If true, requires exact equality; if false, allows approximate equality for floats
399
/// * `rel_tol` - Relative tolerance for float comparison (used when `check_exact` is false)
400
/// * `abs_tol` - Absolute tolerance for float comparison (used when `check_exact` is false)
401
/// * `categorical_as_str` - If true, converts categorical Series to strings before comparison
402
///
403
/// # Returns
404
///
405
/// * `Ok(())` if nested Series match according to specified criteria
406
/// * `Err` with details about mismatches if Series differ
407
///
408
/// # Behavior
409
///
410
/// For List types:
411
/// 1. Iterates through corresponding elements in both Series
412
/// 2. Returns error if null values are encountered
413
/// 3. Creates single-element Series for each value and explodes them
414
/// 4. Recursively calls `assert_series_values_equal` on the exploded Series
415
///
416
/// For Struct types:
417
/// 1. Unnests both struct Series to access their columns
418
/// 2. Iterates through corresponding columns
419
/// 3. Recursively calls `assert_series_values_equal` on each column pair
420
///
421
fn assert_series_nested_values_equal(
422
left: &Series,
423
right: &Series,
424
check_exact: bool,
425
check_dtypes: bool,
426
rel_tol: f64,
427
abs_tol: f64,
428
categorical_as_str: bool,
429
) -> PolarsResult<()> {
430
if are_both_lists(left.dtype(), right.dtype()) {
431
let left_rechunked = left.rechunk();
432
let right_rechunked = right.rechunk();
433
434
let zipped = left_rechunked.iter().zip(right_rechunked.iter());
435
436
for (s1, s2) in zipped {
437
if s1.is_null() || s2.is_null() {
438
return Err(polars_err!(
439
assertion_error = "Series",
440
"nested value mismatch",
441
s1,
442
s2
443
));
444
} else {
445
let s1_series = Series::new("".into(), std::slice::from_ref(&s1));
446
let s2_series = Series::new("".into(), std::slice::from_ref(&s2));
447
448
match assert_series_values_equal(
449
&s1_series.explode(false)?,
450
&s2_series.explode(false)?,
451
true,
452
check_exact,
453
check_dtypes,
454
rel_tol,
455
abs_tol,
456
categorical_as_str,
457
) {
458
Ok(_) => continue,
459
Err(e) => return Err(e),
460
}
461
}
462
}
463
} else {
464
let ls = left.struct_()?.clone().unnest();
465
let rs = right.struct_()?.clone().unnest();
466
467
for col_name in ls.get_column_names() {
468
let s1_column = ls.column(col_name)?;
469
let s2_column = rs.column(col_name)?;
470
471
let s1_series = s1_column.as_materialized_series();
472
let s2_series = s2_column.as_materialized_series();
473
474
match assert_series_values_equal(
475
s1_series,
476
s2_series,
477
true,
478
check_exact,
479
check_dtypes,
480
rel_tol,
481
abs_tol,
482
categorical_as_str,
483
) {
484
Ok(_) => continue,
485
Err(e) => return Err(e),
486
}
487
}
488
}
489
490
Ok(())
491
}
492
493
/// Verifies that two Series are equal according to a set of configurable criteria.
494
///
495
/// This function serves as the main entry point for comparing Series, checking various
496
/// metadata properties before comparing the actual values.
497
///
498
/// # Arguments
499
///
500
/// * `left` - The first Series to compare
501
/// * `right` - The second Series to compare
502
/// * `options` - A `SeriesEqualOptions` struct containing configuration parameters:
503
/// * `check_names` - If true, verifies Series names match
504
/// * `check_dtypes` - If true, verifies data types match
505
/// * `check_order` - If true, elements must be in the same order
506
/// * `check_exact` - If true, requires exact equality for float values
507
/// * `rel_tol` - Relative tolerance for float comparison
508
/// * `abs_tol` - Absolute tolerance for float comparison
509
/// * `categorical_as_str` - If true, converts categorical Series to strings before comparison
510
///
511
/// # Returns
512
///
513
/// * `Ok(())` if Series match according to all specified criteria
514
/// * `Err` with details about the first mismatch encountered:
515
/// * Length mismatch
516
/// * Name mismatch (if checking names)
517
/// * Data type mismatch (if checking dtypes)
518
/// * Value mismatches (via `assert_series_values_equal`)
519
///
520
/// # Order of Checks
521
///
522
/// 1. Series length
523
/// 2. Series names (if `check_names` is true)
524
/// 3. Data types (if `check_dtypes` is true)
525
/// 4. Series values (delegated to `assert_series_values_equal`)
526
///
527
pub fn assert_series_equal(
528
left: &Series,
529
right: &Series,
530
options: SeriesEqualOptions,
531
) -> PolarsResult<()> {
532
// Short-circuit if they're the same series object
533
if std::ptr::eq(left, right) {
534
return Ok(());
535
}
536
537
if left.len() != right.len() {
538
return Err(polars_err!(
539
assertion_error = "Series",
540
"length mismatch",
541
left.len(),
542
right.len()
543
));
544
}
545
546
if options.check_names && left.name() != right.name() {
547
return Err(polars_err!(
548
assertion_error = "Series",
549
"name mismatch",
550
left.name(),
551
right.name()
552
));
553
}
554
555
if options.check_dtypes && left.dtype() != right.dtype() {
556
return Err(polars_err!(
557
assertion_error = "Series",
558
"dtype mismatch",
559
left.dtype(),
560
right.dtype()
561
));
562
}
563
564
assert_series_values_equal(
565
left,
566
right,
567
options.check_order,
568
options.check_exact,
569
options.check_dtypes,
570
options.rel_tol,
571
options.abs_tol,
572
options.categorical_as_str,
573
)
574
}
575
576
/// Configuration options for comparing DataFrame equality.
///
/// Controls the behavior of DataFrame equality comparisons by specifying
/// which aspects to check and the tolerance for floating point comparisons.
///
/// The type derives `Debug`, `Clone` and `PartialEq` so an option set can be
/// printed in diagnostics, reused for several assertions and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct DataFrameEqualOptions {
    /// Whether to check that rows appear in the same order.
    pub check_row_order: bool,
    /// Whether to check that columns appear in the same order.
    pub check_column_order: bool,
    /// Whether to check that the data types match for corresponding columns.
    pub check_dtypes: bool,
    /// Whether to check for exact equality (true) or approximate equality (false) for floating point values.
    pub check_exact: bool,
    /// Relative tolerance for approximate equality of floating point values.
    pub rel_tol: f64,
    /// Absolute tolerance for approximate equality of floating point values.
    pub abs_tol: f64,
    /// Whether to compare categorical values as strings.
    pub categorical_as_str: bool,
}
596
597
impl Default for DataFrameEqualOptions {
598
/// Creates a new `DataFrameEqualOptions` with default settings.
599
///
600
/// Default configuration:
601
/// - Checks row order, column order, and data types
602
/// - Uses approximate equality comparisons for floating point values
603
/// - Sets relative tolerance to 1e-5 and absolute tolerance to 1e-8 for floating point comparisons
604
/// - Does not convert categorical values to strings for comparison
605
fn default() -> Self {
606
Self {
607
check_row_order: true,
608
check_column_order: true,
609
check_dtypes: true,
610
check_exact: false,
611
rel_tol: 1e-5,
612
abs_tol: 1e-8,
613
categorical_as_str: false,
614
}
615
}
616
}
617
618
impl DataFrameEqualOptions {
619
/// Creates a new `DataFrameEqualOptions` with default settings.
620
pub fn new() -> Self {
621
Self::default()
622
}
623
624
/// Sets whether to check that rows appear in the same order.
625
pub fn with_check_row_order(mut self, value: bool) -> Self {
626
self.check_row_order = value;
627
self
628
}
629
630
/// Sets whether to check that columns appear in the same order.
631
pub fn with_check_column_order(mut self, value: bool) -> Self {
632
self.check_column_order = value;
633
self
634
}
635
636
/// Sets whether to check that data types match for corresponding columns.
637
pub fn with_check_dtypes(mut self, value: bool) -> Self {
638
self.check_dtypes = value;
639
self
640
}
641
642
/// Sets whether to check for exact equality (true) or approximate equality (false) for floating point values.
643
pub fn with_check_exact(mut self, value: bool) -> Self {
644
self.check_exact = value;
645
self
646
}
647
648
/// Sets the relative tolerance for approximate equality of floating point values.
649
pub fn with_rel_tol(mut self, value: f64) -> Self {
650
self.rel_tol = value;
651
self
652
}
653
654
/// Sets the absolute tolerance for approximate equality of floating point values.
655
pub fn with_abs_tol(mut self, value: f64) -> Self {
656
self.abs_tol = value;
657
self
658
}
659
660
/// Sets whether to compare categorical values as strings.
661
pub fn with_categorical_as_str(mut self, value: bool) -> Self {
662
self.categorical_as_str = value;
663
self
664
}
665
}
666
667
/// Compares DataFrame schemas for equality based on specified criteria.
668
///
669
/// This function validates that two DataFrames have compatible schemas by checking
670
/// column names, their order, and optionally their data types according to the
671
/// provided configuration parameters.
672
///
673
/// # Arguments
674
///
675
/// * `left` - The first DataFrame to compare
676
/// * `right` - The second DataFrame to compare
677
/// * `check_dtypes` - If true, requires data types to match for corresponding columns
678
/// * `check_column_order` - If true, requires columns to appear in the same order
679
///
680
/// # Returns
681
///
682
/// * `Ok(())` if DataFrame schemas match according to specified criteria
683
/// * `Err` with details about schema mismatches if DataFrames differ
684
///
685
/// # Behavior
686
///
687
/// The function performs schema validation in the following order:
688
///
689
/// 1. **Fast path**: Returns immediately if schemas are identical
690
/// 2. **Column name validation**: Ensures both DataFrames have the same set of column names
691
/// - Reports columns present in left but missing in right
692
/// - Reports columns present in right but missing in left
693
/// 3. **Column order validation**: If `check_column_order` is true, verifies columns appear in the same sequence
694
/// 4. **Data type validation**: If `check_dtypes` is true, ensures corresponding columns have matching data types
695
/// - When `check_column_order` is false, compares data type sets for equality
696
/// - When `check_column_order` is true, performs more precise type checking
697
///
698
fn assert_dataframe_schema_equal(
699
left: &DataFrame,
700
right: &DataFrame,
701
check_dtypes: bool,
702
check_column_order: bool,
703
) -> PolarsResult<()> {
704
let left_schema = left.schema();
705
let right_schema = right.schema();
706
707
let ordered_left_cols = left.get_column_names();
708
let ordered_right_cols = right.get_column_names();
709
710
let left_set: PlHashSet<&PlSmallStr> = ordered_left_cols.iter().copied().collect();
711
let right_set: PlHashSet<&PlSmallStr> = ordered_right_cols.iter().copied().collect();
712
713
// Fast path for equal DataFrames
714
if left_schema == right_schema {
715
return Ok(());
716
}
717
718
if left_set != right_set {
719
let left_not_right: Vec<_> = left_set
720
.iter()
721
.filter(|col| !right_set.contains(*col))
722
.collect();
723
724
if !left_not_right.is_empty() {
725
return Err(polars_err!(
726
assertion_error = "DataFrames",
727
format!(
728
"columns mismatch: {:?} in left, but not in right",
729
left_not_right
730
),
731
format!("{:?}", left_set),
732
format!("{:?}", right_set)
733
));
734
} else {
735
let right_not_left: Vec<_> = right_set
736
.iter()
737
.filter(|col| !left_set.contains(*col))
738
.collect();
739
740
return Err(polars_err!(
741
assertion_error = "DataFrames",
742
format!(
743
"columns mismatch: {:?} in right, but not in left",
744
right_not_left
745
),
746
format!("{:?}", left_set),
747
format!("{:?}", right_set)
748
));
749
}
750
}
751
752
if check_column_order && ordered_left_cols != ordered_right_cols {
753
return Err(polars_err!(
754
assertion_error = "DataFrames",
755
"columns are not in the same order",
756
format!("{:?}", ordered_left_cols),
757
format!("{:?}", ordered_right_cols)
758
));
759
}
760
761
if check_dtypes {
762
if check_column_order {
763
let left_dtypes_ordered = left.dtypes();
764
let right_dtypes_ordered = right.dtypes();
765
if left_dtypes_ordered != right_dtypes_ordered {
766
return Err(polars_err!(
767
assertion_error = "DataFrames",
768
"dtypes do not match",
769
format!("{:?}", left_dtypes_ordered),
770
format!("{:?}", right_dtypes_ordered)
771
));
772
}
773
} else {
774
let left_dtypes: PlHashSet<DataType> = left.dtypes().into_iter().collect();
775
let right_dtypes: PlHashSet<DataType> = right.dtypes().into_iter().collect();
776
if left_dtypes != right_dtypes {
777
return Err(polars_err!(
778
assertion_error = "DataFrames",
779
"dtypes do not match",
780
format!("{:?}", left_dtypes),
781
format!("{:?}", right_dtypes)
782
));
783
}
784
}
785
}
786
787
Ok(())
788
}
789
790
/// Verifies that two DataFrames are equal according to a set of configurable criteria.
791
///
792
/// This function serves as the main entry point for comparing DataFrames, first validating
793
/// schema compatibility and then comparing the actual data values column by column.
794
///
795
/// # Arguments
796
///
797
/// * `left` - The first DataFrame to compare
798
/// * `right` - The second DataFrame to compare
799
/// * `options` - A `DataFrameEqualOptions` struct containing configuration parameters:
800
/// * `check_row_order` - If true, rows must be in the same order
801
/// * `check_column_order` - If true, columns must be in the same order
802
/// * `check_dtypes` - If true, verifies data types match for corresponding columns
803
/// * `check_exact` - If true, requires exact equality for float values
804
/// * `rel_tol` - Relative tolerance for float comparison
805
/// * `abs_tol` - Absolute tolerance for float comparison
806
/// * `categorical_as_str` - If true, converts categorical values to strings before comparison
807
///
808
/// # Returns
809
///
810
/// * `Ok(())` if DataFrames match according to all specified criteria
811
/// * `Err` with details about the first mismatch encountered:
812
/// * Schema mismatches (column names, order, or data types)
813
/// * Height (row count) mismatch
814
/// * Value mismatches in specific columns
815
///
816
/// # Order of Checks
817
///
818
/// 1. Schema validation (column names, order, and data types via `assert_dataframe_schema_equal`)
819
/// 2. DataFrame height (row count)
820
/// 3. Row ordering (sorts both DataFrames if `check_row_order` is false)
821
/// 4. Column-by-column value comparison (delegated to `assert_series_values_equal`)
822
///
823
/// # Behavior
824
///
825
/// When `check_row_order` is false, both DataFrames are sorted using all columns to ensure
826
/// consistent ordering before value comparison. This allows for row-order-independent equality
827
/// checking while maintaining deterministic results.
828
///
829
pub fn assert_dataframe_equal(
830
left: &DataFrame,
831
right: &DataFrame,
832
options: DataFrameEqualOptions,
833
) -> PolarsResult<()> {
834
// Short-circuit if they're the same DataFrame object
835
if std::ptr::eq(left, right) {
836
return Ok(());
837
}
838
839
assert_dataframe_schema_equal(
840
left,
841
right,
842
options.check_dtypes,
843
options.check_column_order,
844
)?;
845
846
if left.height() != right.height() {
847
return Err(polars_err!(
848
assertion_error = "DataFrames",
849
"height (row count) mismatch",
850
left.height(),
851
right.height()
852
));
853
}
854
855
let left_cols = left.get_column_names_owned();
856
857
let (left, right) = if !options.check_row_order {
858
(
859
left.sort(left_cols.clone(), SortMultipleOptions::default())?,
860
right.sort(left_cols.clone(), SortMultipleOptions::default())?,
861
)
862
} else {
863
(left.clone(), right.clone())
864
};
865
866
for col in left_cols.iter() {
867
let s_left = left.column(col)?;
868
let s_right = right.column(col)?;
869
870
let s_left_series = s_left.as_materialized_series();
871
let s_right_series = s_right.as_materialized_series();
872
873
match assert_series_values_equal(
874
s_left_series,
875
s_right_series,
876
true,
877
options.check_exact,
878
options.check_dtypes,
879
options.rel_tol,
880
options.abs_tol,
881
options.categorical_as_str,
882
) {
883
Ok(_) => {},
884
Err(_) => {
885
return Err(polars_err!(
886
assertion_error = "DataFrames",
887
format!("value mismatch for column {:?}", col),
888
format!("{:?}", s_left_series),
889
format!("{:?}", s_right_series)
890
));
891
},
892
}
893
}
894
895
Ok(())
896
}
897
898