Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-testing/src/asserts/utils.rs
8395 views
1
use std::ops::Not;
2
3
use polars_core::datatypes::unpack_dtypes;
4
use polars_core::prelude::*;
5
use polars_ops::series::is_close;
6
7
/// Configuration options for comparing Series equality.
///
/// Controls the behavior of Series equality comparisons by specifying
/// which aspects to check and the tolerance for floating point comparisons.
///
/// Every field is a plain `bool` or `f64`, so the options are cheap to copy and
/// can be compared and printed (e.g. when reporting a failing assertion).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SeriesEqualOptions {
    /// Whether to check that the data types match.
    pub check_dtypes: bool,
    /// Whether to check that the Series names match.
    pub check_names: bool,
    /// Whether to check that elements appear in the same order.
    pub check_order: bool,
    /// Whether to check for exact equality (true) or approximate equality (false) for floating point values.
    pub check_exact: bool,
    /// Relative tolerance for approximate equality of floating point values.
    pub rel_tol: f64,
    /// Absolute tolerance for approximate equality of floating point values.
    pub abs_tol: f64,
    /// Whether to compare categorical values as strings.
    pub categorical_as_str: bool,
}
27
28
impl Default for SeriesEqualOptions {
29
/// Creates a new `SeriesEqualOptions` with default settings.
30
///
31
/// Default configuration:
32
/// - Checks data types, names, and order
33
/// - Uses exact equality comparisons
34
/// - Sets relative tolerance to 1e-5 and absolute tolerance to 1e-8 for floating point comparisons
35
/// - Does not convert categorical values to strings for comparison
36
fn default() -> Self {
37
Self {
38
check_dtypes: true,
39
check_names: true,
40
check_order: true,
41
check_exact: true,
42
rel_tol: 1e-5,
43
abs_tol: 1e-8,
44
categorical_as_str: false,
45
}
46
}
47
}
48
49
impl SeriesEqualOptions {
50
/// Creates a new `SeriesEqualOptions` with default settings.
51
pub fn new() -> Self {
52
Self::default()
53
}
54
55
/// Sets whether to check that data types match.
56
pub fn with_check_dtypes(mut self, value: bool) -> Self {
57
self.check_dtypes = value;
58
self
59
}
60
61
/// Sets whether to check that Series names match.
62
pub fn with_check_names(mut self, value: bool) -> Self {
63
self.check_names = value;
64
self
65
}
66
67
/// Sets whether to check that elements appear in the same order.
68
pub fn with_check_order(mut self, value: bool) -> Self {
69
self.check_order = value;
70
self
71
}
72
73
/// Sets whether to check for exact equality (true) or approximate equality (false) for floating point values.
74
pub fn with_check_exact(mut self, value: bool) -> Self {
75
self.check_exact = value;
76
self
77
}
78
79
/// Sets the relative tolerance for approximate equality of floating point values.
80
pub fn with_rel_tol(mut self, value: f64) -> Self {
81
self.rel_tol = value;
82
self
83
}
84
85
/// Sets the absolute tolerance for approximate equality of floating point values.
86
pub fn with_abs_tol(mut self, value: f64) -> Self {
87
self.abs_tol = value;
88
self
89
}
90
91
/// Sets whether to compare categorical values as strings.
92
pub fn with_categorical_as_str(mut self, value: bool) -> Self {
93
self.categorical_as_str = value;
94
self
95
}
96
}
97
98
/// Change a (possibly nested) Categorical data type to a String data type.
99
fn categorical_dtype_to_string_dtype(dtype: &DataType) -> DataType {
100
match dtype {
101
DataType::Categorical(..) => DataType::String,
102
DataType::List(inner) => {
103
let inner_cast = categorical_dtype_to_string_dtype(inner);
104
DataType::List(Box::new(inner_cast))
105
},
106
DataType::Array(inner, size) => {
107
let inner_cast = categorical_dtype_to_string_dtype(inner);
108
DataType::Array(Box::new(inner_cast), *size)
109
},
110
DataType::Struct(fields) => {
111
let transformed_fields = fields
112
.iter()
113
.map(|field| {
114
Field::new(
115
field.name().clone(),
116
categorical_dtype_to_string_dtype(field.dtype()),
117
)
118
})
119
.collect::<Vec<Field>>();
120
121
DataType::Struct(transformed_fields)
122
},
123
_ => dtype.clone(),
124
}
125
}
126
127
/// Cast a (possibly nested) Categorical Series to a String Series.
128
fn categorical_series_to_string(s: &Series) -> PolarsResult<Series> {
129
let dtype = s.dtype();
130
let noncat_dtype = categorical_dtype_to_string_dtype(dtype);
131
132
if *dtype != noncat_dtype {
133
Ok(s.cast(&noncat_dtype)?)
134
} else {
135
Ok(s.clone())
136
}
137
}
138
139
/// Returns true if both DataTypes are floating point types.
140
fn are_both_floats(left: &DataType, right: &DataType) -> bool {
141
left.is_float() && right.is_float()
142
}
143
144
/// Returns true if both DataTypes are list-like (either List or Array types).
145
fn are_both_lists(left: &DataType, right: &DataType) -> bool {
146
matches!(left, DataType::List(_) | DataType::Array(_, _))
147
&& matches!(right, DataType::List(_) | DataType::Array(_, _))
148
}
149
150
/// Returns true if both DataTypes are struct types.
151
fn are_both_structs(left: &DataType, right: &DataType) -> bool {
152
left.is_struct() && right.is_struct()
153
}
154
155
/// Returns true if both DataTypes are nested types (lists or structs) that contain floating point types within them.
156
/// First checks if both types are either lists or structs, then unpacks their nested DataTypes to determine if
157
/// at least one floating point type exists in each of the nested structures.
158
fn comparing_nested_floats(left: &DataType, right: &DataType) -> bool {
159
if !are_both_lists(left, right) && !are_both_structs(left, right) {
160
return false;
161
}
162
163
let left_dtypes = unpack_dtypes(left, false);
164
let right_dtypes = unpack_dtypes(right, false);
165
166
let left_has_floats = left_dtypes.iter().any(|dt| dt.is_float());
167
let right_has_floats = right_dtypes.iter().any(|dt| dt.is_float());
168
169
left_has_floats && right_has_floats
170
}
171
172
/// Ensures that null values in two Series match exactly and returns an error if any mismatches are found.
173
fn assert_series_null_values_match(left: &Series, right: &Series) -> PolarsResult<()> {
174
let null_value_mismatch = left.is_null().not_equal(&right.is_null());
175
176
if null_value_mismatch.any() {
177
return Err(polars_err!(
178
assertion_error = "Series",
179
"null value mismatch",
180
left.null_count(),
181
right.null_count()
182
));
183
}
184
185
Ok(())
186
}
187
188
/// Validates that NaN patterns are identical between two float Series, returning error if any mismatches are found.
189
fn assert_series_nan_values_match(left: &Series, right: &Series) -> PolarsResult<()> {
190
if !are_both_floats(left.dtype(), right.dtype()) {
191
return Ok(());
192
}
193
let left_nan = left.is_nan()?;
194
let right_nan = right.is_nan()?;
195
196
let nan_value_mismatch = left_nan.not_equal(&right_nan);
197
198
let left_nan_count = left_nan.sum().unwrap_or(0);
199
let right_nan_count = right_nan.sum().unwrap_or(0);
200
201
if nan_value_mismatch.any() {
202
return Err(polars_err!(
203
assertion_error = "Series",
204
"nan value mismatch",
205
left_nan_count,
206
right_nan_count
207
));
208
}
209
210
Ok(())
211
}
212
213
/// Verifies that two Series have values within a specified tolerance.
214
///
215
/// This function checks if the values in `left` and `right` Series that are marked as unequal
216
/// in the `unequal` boolean array are within the specified relative and absolute tolerances.
217
///
218
/// # Arguments
219
///
220
/// * `left` - The first Series to compare
221
/// * `right` - The second Series to compare
222
/// * `unequal` - Boolean ChunkedArray indicating which elements to check (true = check this element)
223
/// * `rel_tol` - Relative tolerance (relative to the maximum absolute value of the two Series)
224
/// * `abs_tol` - Absolute tolerance added to the relative tolerance
225
///
226
/// # Returns
227
///
228
/// * `Ok(())` if all values are within tolerance
229
/// * `Err` with details about problematic values if any values exceed the tolerance
230
///
231
/// # Formula
232
///
233
/// Values are considered within tolerance if:
234
/// `|left - right| <= max(rel_tol * max(abs(left), abs(right)), abs_tol)` OR values are exactly equal
235
///
236
fn assert_series_values_within_tolerance(
237
left: &Series,
238
right: &Series,
239
unequal: &ChunkedArray<BooleanType>,
240
rel_tol: f64,
241
abs_tol: f64,
242
) -> PolarsResult<()> {
243
let left_unequal = left.filter(unequal)?;
244
let right_unequal = right.filter(unequal)?;
245
246
let within_tolerance = is_close(&left_unequal, &right_unequal, abs_tol, rel_tol, false)?;
247
if within_tolerance.all() {
248
Ok(())
249
} else {
250
let exceeded_indices = within_tolerance.not();
251
let problematic_left = left_unequal.filter(&exceeded_indices)?;
252
let problematic_right = right_unequal.filter(&exceeded_indices)?;
253
254
Err(polars_err!(
255
assertion_error = "Series",
256
"values not within tolerance",
257
problematic_left,
258
problematic_right
259
))
260
}
261
}
262
263
/// Compares two Series for equality with configurable options for ordering, exact matching, and tolerance.
264
///
265
/// This function verifies that the values in `left` and `right` Series are equal according to
266
/// the specified comparison criteria. It handles different types including floats and nested types
267
/// with appropriate equality checks.
268
///
269
/// # Arguments
270
///
271
/// * `left` - The first Series to compare
272
/// * `right` - The second Series to compare
273
/// * `check_order` - If true, elements must be in the same order; if false, Series will be sorted before comparison
274
/// * `check_exact` - If true, requires exact equality; if false, allows approximate equality for floats within tolerance
275
/// * `rel_tol` - Relative tolerance for float comparison (used when `check_exact` is false)
276
/// * `abs_tol` - Absolute tolerance for float comparison (used when `check_exact` is false)
277
/// * `categorical_as_str` - If true, converts categorical Series to strings before comparison
278
///
279
/// # Returns
280
///
281
/// * `Ok(())` if Series match according to specified criteria
282
/// * `Err` with details about mismatches if Series differ
283
///
284
/// # Behavior
285
///
286
/// 1. Handles categorical Series based on `categorical_as_str` flag
287
/// 2. Sorts Series if `check_order` is false
288
/// 3. For nested float types, delegates to `assert_series_nested_values_equal`
289
/// 4. For non-float types or when `check_exact` is true, requires exact match
290
/// 5. For float types with approximate matching:
291
/// - Verifies null values match using `assert_series_null_values_match`
292
/// - Verifies NaN values match using `assert_series_nan_values_match`
293
/// - Verifies float values are within tolerance using `assert_series_values_within_tolerance`
294
///
295
#[allow(clippy::too_many_arguments)]
296
fn assert_series_values_equal(
297
left: &Series,
298
right: &Series,
299
check_order: bool,
300
check_exact: bool,
301
check_dtypes: bool,
302
rel_tol: f64,
303
abs_tol: f64,
304
categorical_as_str: bool,
305
) -> PolarsResult<()> {
306
// When `check_dtypes` is `false` and both series are entirely null,
307
// consider them equal regardless of their underlying data types
308
if !check_dtypes && left.dtype() != right.dtype() {
309
if left.null_count() == left.len() && right.null_count() == right.len() {
310
return Ok(());
311
}
312
}
313
314
let (left, right) = if categorical_as_str {
315
(
316
categorical_series_to_string(left)?,
317
categorical_series_to_string(right)?,
318
)
319
} else {
320
(left.clone(), right.clone())
321
};
322
323
let (left, right) = if !check_order {
324
(
325
left.sort(SortOptions::default())?,
326
right.sort(SortOptions::default())?,
327
)
328
} else {
329
(left, right)
330
};
331
332
let unequal = match left.not_equal_missing(&right) {
333
Ok(result) => result,
334
Err(_) => {
335
return Err(polars_err!(
336
assertion_error = "Series",
337
"incompatible data types",
338
left.dtype(),
339
right.dtype()
340
));
341
},
342
};
343
344
if comparing_nested_floats(left.dtype(), right.dtype()) {
345
let filtered_left = left.filter(&unequal)?;
346
let filtered_right = right.filter(&unequal)?;
347
348
match assert_series_nested_values_equal(
349
&filtered_left,
350
&filtered_right,
351
check_exact,
352
check_dtypes,
353
rel_tol,
354
abs_tol,
355
categorical_as_str,
356
) {
357
Ok(_) => return Ok(()),
358
Err(_) => {
359
return Err(polars_err!(
360
assertion_error = "Series",
361
"nested value mismatch",
362
left,
363
right
364
));
365
},
366
}
367
}
368
369
if !unequal.any() {
370
return Ok(());
371
}
372
373
if check_exact || !left.dtype().is_float() || !right.dtype().is_float() {
374
return Err(polars_err!(
375
assertion_error = "Series",
376
"exact value mismatch",
377
left,
378
right
379
));
380
}
381
382
assert_series_null_values_match(&left, &right)?;
383
assert_series_nan_values_match(&left, &right)?;
384
assert_series_values_within_tolerance(&left, &right, &unequal, rel_tol, abs_tol)?;
385
386
Ok(())
387
}
388
389
/// Recursively compares nested Series structures (lists or structs) for equality.
390
///
391
/// This function handles the comparison of complex nested data structures by recursively
392
/// applying appropriate equality checks based on the nested data type.
393
///
394
/// # Arguments
395
///
396
/// * `left` - The first nested Series to compare
397
/// * `right` - The second nested Series to compare
398
/// * `check_exact` - If true, requires exact equality; if false, allows approximate equality for floats
399
/// * `rel_tol` - Relative tolerance for float comparison (used when `check_exact` is false)
400
/// * `abs_tol` - Absolute tolerance for float comparison (used when `check_exact` is false)
401
/// * `categorical_as_str` - If true, converts categorical Series to strings before comparison
402
///
403
/// # Returns
404
///
405
/// * `Ok(())` if nested Series match according to specified criteria
406
/// * `Err` with details about mismatches if Series differ
407
///
408
/// # Behavior
409
///
410
/// For List types:
411
/// 1. Iterates through corresponding elements in both Series
412
/// 2. Returns error if null values are encountered
413
/// 3. Creates single-element Series for each value and explodes them
414
/// 4. Recursively calls `assert_series_values_equal` on the exploded Series
415
///
416
/// For Struct types:
417
/// 1. Unnests both struct Series to access their columns
418
/// 2. Iterates through corresponding columns
419
/// 3. Recursively calls `assert_series_values_equal` on each column pair
420
///
421
fn assert_series_nested_values_equal(
422
left: &Series,
423
right: &Series,
424
check_exact: bool,
425
check_dtypes: bool,
426
rel_tol: f64,
427
abs_tol: f64,
428
categorical_as_str: bool,
429
) -> PolarsResult<()> {
430
if are_both_lists(left.dtype(), right.dtype()) {
431
let zipped = left.iter().zip(right.iter());
432
433
for (s1, s2) in zipped {
434
if s1.is_null() || s2.is_null() {
435
return Err(polars_err!(
436
assertion_error = "Series",
437
"nested value mismatch",
438
s1,
439
s2
440
));
441
} else {
442
let s1_series = Series::new("".into(), std::slice::from_ref(&s1));
443
let s2_series = Series::new("".into(), std::slice::from_ref(&s2));
444
445
assert_series_values_equal(
446
&s1_series.explode(ExplodeOptions {
447
empty_as_null: true,
448
keep_nulls: true,
449
})?,
450
&s2_series.explode(ExplodeOptions {
451
empty_as_null: true,
452
keep_nulls: true,
453
})?,
454
true,
455
check_exact,
456
check_dtypes,
457
rel_tol,
458
abs_tol,
459
categorical_as_str,
460
)?
461
}
462
}
463
} else {
464
let ls = left.struct_()?.clone().unnest();
465
let rs = right.struct_()?.clone().unnest();
466
467
for col_name in ls.get_column_names() {
468
let s1_column = ls.column(col_name)?;
469
let s2_column = rs.column(col_name)?;
470
471
let s1_series = s1_column.as_materialized_series();
472
let s2_series = s2_column.as_materialized_series();
473
474
assert_series_values_equal(
475
s1_series,
476
s2_series,
477
true,
478
check_exact,
479
check_dtypes,
480
rel_tol,
481
abs_tol,
482
categorical_as_str,
483
)?
484
}
485
}
486
487
Ok(())
488
}
489
490
/// Verifies that two Series are equal according to a set of configurable criteria.
491
///
492
/// This function serves as the main entry point for comparing Series, checking various
493
/// metadata properties before comparing the actual values.
494
///
495
/// # Arguments
496
///
497
/// * `left` - The first Series to compare
498
/// * `right` - The second Series to compare
499
/// * `options` - A `SeriesEqualOptions` struct containing configuration parameters:
500
/// * `check_names` - If true, verifies Series names match
501
/// * `check_dtypes` - If true, verifies data types match
502
/// * `check_order` - If true, elements must be in the same order
503
/// * `check_exact` - If true, requires exact equality for float values
504
/// * `rel_tol` - Relative tolerance for float comparison
505
/// * `abs_tol` - Absolute tolerance for float comparison
506
/// * `categorical_as_str` - If true, converts categorical Series to strings before comparison
507
///
508
/// # Returns
509
///
510
/// * `Ok(())` if Series match according to all specified criteria
511
/// * `Err` with details about the first mismatch encountered:
512
/// * Length mismatch
513
/// * Name mismatch (if checking names)
514
/// * Data type mismatch (if checking dtypes)
515
/// * Value mismatches (via `assert_series_values_equal`)
516
///
517
/// # Order of Checks
518
///
519
/// 1. Series length
520
/// 2. Series names (if `check_names` is true)
521
/// 3. Data types (if `check_dtypes` is true)
522
/// 4. Series values (delegated to `assert_series_values_equal`)
523
///
524
pub fn assert_series_equal(
525
left: &Series,
526
right: &Series,
527
options: SeriesEqualOptions,
528
) -> PolarsResult<()> {
529
// Short-circuit if they're the same series object
530
if std::ptr::eq(left, right) {
531
return Ok(());
532
}
533
534
if left.len() != right.len() {
535
return Err(polars_err!(
536
assertion_error = "Series",
537
"length mismatch",
538
left.len(),
539
right.len()
540
));
541
}
542
543
if options.check_names && left.name() != right.name() {
544
return Err(polars_err!(
545
assertion_error = "Series",
546
"name mismatch",
547
left.name(),
548
right.name()
549
));
550
}
551
552
if options.check_dtypes && left.dtype() != right.dtype() {
553
return Err(polars_err!(
554
assertion_error = "Series",
555
"dtype mismatch",
556
left.dtype(),
557
right.dtype()
558
));
559
}
560
561
assert_series_values_equal(
562
left,
563
right,
564
options.check_order,
565
options.check_exact,
566
options.check_dtypes,
567
options.rel_tol,
568
options.abs_tol,
569
options.categorical_as_str,
570
)
571
}
572
573
/// Configuration options for comparing DataFrame equality.
///
/// Controls the behavior of DataFrame equality comparisons by specifying
/// which aspects to check and the tolerance for floating point comparisons.
///
/// Every field is a plain `bool` or `f64`, so the options are cheap to copy and
/// can be compared and printed (e.g. when reporting a failing assertion).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DataFrameEqualOptions {
    /// Whether to check that rows appear in the same order.
    pub check_row_order: bool,
    /// Whether to check that columns appear in the same order.
    pub check_column_order: bool,
    /// Whether to check that the data types match for corresponding columns.
    pub check_dtypes: bool,
    /// Whether to check for exact equality (true) or approximate equality (false) for floating point values.
    pub check_exact: bool,
    /// Relative tolerance for approximate equality of floating point values.
    pub rel_tol: f64,
    /// Absolute tolerance for approximate equality of floating point values.
    pub abs_tol: f64,
    /// Whether to compare categorical values as strings.
    pub categorical_as_str: bool,
}
593
594
impl Default for DataFrameEqualOptions {
595
/// Creates a new `DataFrameEqualOptions` with default settings.
596
///
597
/// Default configuration:
598
/// - Checks row order, column order, and data types
599
/// - Uses approximate equality comparisons for floating point values
600
/// - Sets relative tolerance to 1e-5 and absolute tolerance to 1e-8 for floating point comparisons
601
/// - Does not convert categorical values to strings for comparison
602
fn default() -> Self {
603
Self {
604
check_row_order: true,
605
check_column_order: true,
606
check_dtypes: true,
607
check_exact: false,
608
rel_tol: 1e-5,
609
abs_tol: 1e-8,
610
categorical_as_str: false,
611
}
612
}
613
}
614
615
impl DataFrameEqualOptions {
616
/// Creates a new `DataFrameEqualOptions` with default settings.
617
pub fn new() -> Self {
618
Self::default()
619
}
620
621
/// Sets whether to check that rows appear in the same order.
622
pub fn with_check_row_order(mut self, value: bool) -> Self {
623
self.check_row_order = value;
624
self
625
}
626
627
/// Sets whether to check that columns appear in the same order.
628
pub fn with_check_column_order(mut self, value: bool) -> Self {
629
self.check_column_order = value;
630
self
631
}
632
633
/// Sets whether to check that data types match for corresponding columns.
634
pub fn with_check_dtypes(mut self, value: bool) -> Self {
635
self.check_dtypes = value;
636
self
637
}
638
639
/// Sets whether to check for exact equality (true) or approximate equality (false) for floating point values.
640
pub fn with_check_exact(mut self, value: bool) -> Self {
641
self.check_exact = value;
642
self
643
}
644
645
/// Sets the relative tolerance for approximate equality of floating point values.
646
pub fn with_rel_tol(mut self, value: f64) -> Self {
647
self.rel_tol = value;
648
self
649
}
650
651
/// Sets the absolute tolerance for approximate equality of floating point values.
652
pub fn with_abs_tol(mut self, value: f64) -> Self {
653
self.abs_tol = value;
654
self
655
}
656
657
/// Sets whether to compare categorical values as strings.
658
pub fn with_categorical_as_str(mut self, value: bool) -> Self {
659
self.categorical_as_str = value;
660
self
661
}
662
}
663
664
/// Compares DataFrame schemas for equality based on specified criteria.
665
///
666
/// This function validates that two DataFrames have compatible schemas by checking
667
/// column names, their order, and optionally their data types according to the
668
/// provided configuration parameters.
669
///
670
/// # Arguments
671
///
672
/// * `left` - The first DataFrame to compare
673
/// * `right` - The second DataFrame to compare
674
/// * `check_dtypes` - If true, requires data types to match for corresponding columns
675
/// * `check_column_order` - If true, requires columns to appear in the same order
676
///
677
/// # Returns
678
///
679
/// * `Ok(())` if DataFrame schemas match according to specified criteria
680
/// * `Err` with details about schema mismatches if DataFrames differ
681
///
682
/// # Behavior
683
///
684
/// The function performs schema validation in the following order:
685
///
686
/// 1. **Fast path**: Returns immediately if schemas are identical
687
/// 2. **Column name validation**: Ensures both DataFrames have the same set of column names
688
/// - Reports columns present in left but missing in right
689
/// - Reports columns present in right but missing in left
690
/// 3. **Column order validation**: If `check_column_order` is true, verifies columns appear in the same sequence
691
/// 4. **Data type validation**: If `check_dtypes` is true, ensures corresponding columns have matching data types
692
/// - When `check_column_order` is false, compares data type sets for equality
693
/// - When `check_column_order` is true, performs more precise type checking
694
///
695
fn assert_dataframe_schema_equal(
696
left: &DataFrame,
697
right: &DataFrame,
698
check_dtypes: bool,
699
check_column_order: bool,
700
) -> PolarsResult<()> {
701
let left_schema = left.schema();
702
let right_schema = right.schema();
703
704
let ordered_left_cols = left.get_column_names();
705
let ordered_right_cols = right.get_column_names();
706
707
let left_set: PlHashSet<&PlSmallStr> = ordered_left_cols.iter().copied().collect();
708
let right_set: PlHashSet<&PlSmallStr> = ordered_right_cols.iter().copied().collect();
709
710
// Fast path for equal DataFrames
711
if left_schema == right_schema {
712
return Ok(());
713
}
714
715
if left_set != right_set {
716
let left_not_right: Vec<_> = left_set
717
.iter()
718
.filter(|col| !right_set.contains(*col))
719
.collect();
720
721
if !left_not_right.is_empty() {
722
return Err(polars_err!(
723
assertion_error = "DataFrames",
724
format!(
725
"columns mismatch: {:?} in left, but not in right",
726
left_not_right
727
),
728
format!("{:?}", left_set),
729
format!("{:?}", right_set)
730
));
731
} else {
732
let right_not_left: Vec<_> = right_set
733
.iter()
734
.filter(|col| !left_set.contains(*col))
735
.collect();
736
737
return Err(polars_err!(
738
assertion_error = "DataFrames",
739
format!(
740
"columns mismatch: {:?} in right, but not in left",
741
right_not_left
742
),
743
format!("{:?}", left_set),
744
format!("{:?}", right_set)
745
));
746
}
747
}
748
749
if check_column_order && ordered_left_cols != ordered_right_cols {
750
return Err(polars_err!(
751
assertion_error = "DataFrames",
752
"columns are not in the same order",
753
format!("{:?}", ordered_left_cols),
754
format!("{:?}", ordered_right_cols)
755
));
756
}
757
758
if check_dtypes {
759
if check_column_order {
760
let left_dtypes_ordered = left.dtypes();
761
let right_dtypes_ordered = right.dtypes();
762
if left_dtypes_ordered != right_dtypes_ordered {
763
return Err(polars_err!(
764
assertion_error = "DataFrames",
765
"dtypes do not match",
766
format!("{:?}", left_dtypes_ordered),
767
format!("{:?}", right_dtypes_ordered)
768
));
769
}
770
} else {
771
let left_dtypes: PlHashSet<DataType> = left.dtypes().into_iter().collect();
772
let right_dtypes: PlHashSet<DataType> = right.dtypes().into_iter().collect();
773
if left_dtypes != right_dtypes {
774
return Err(polars_err!(
775
assertion_error = "DataFrames",
776
"dtypes do not match",
777
format!("{:?}", left_dtypes),
778
format!("{:?}", right_dtypes)
779
));
780
}
781
}
782
}
783
784
Ok(())
785
}
786
787
/// Verifies that two DataFrames are equal according to a set of configurable criteria.
788
///
789
/// This function serves as the main entry point for comparing DataFrames, first validating
790
/// schema compatibility and then comparing the actual data values column by column.
791
///
792
/// # Arguments
793
///
794
/// * `left` - The first DataFrame to compare
795
/// * `right` - The second DataFrame to compare
796
/// * `options` - A `DataFrameEqualOptions` struct containing configuration parameters:
797
/// * `check_row_order` - If true, rows must be in the same order
798
/// * `check_column_order` - If true, columns must be in the same order
799
/// * `check_dtypes` - If true, verifies data types match for corresponding columns
800
/// * `check_exact` - If true, requires exact equality for float values
801
/// * `rel_tol` - Relative tolerance for float comparison
802
/// * `abs_tol` - Absolute tolerance for float comparison
803
/// * `categorical_as_str` - If true, converts categorical values to strings before comparison
804
///
805
/// # Returns
806
///
807
/// * `Ok(())` if DataFrames match according to all specified criteria
808
/// * `Err` with details about the first mismatch encountered:
809
/// * Schema mismatches (column names, order, or data types)
810
/// * Height (row count) mismatch
811
/// * Value mismatches in specific columns
812
///
813
/// # Order of Checks
814
///
815
/// 1. Schema validation (column names, order, and data types via `assert_dataframe_schema_equal`)
816
/// 2. DataFrame height (row count)
817
/// 3. Row ordering (sorts both DataFrames if `check_row_order` is false)
818
/// 4. Column-by-column value comparison (delegated to `assert_series_values_equal`)
819
///
820
/// # Behavior
821
///
822
/// When `check_row_order` is false, both DataFrames are sorted using all columns to ensure
823
/// consistent ordering before value comparison. This allows for row-order-independent equality
824
/// checking while maintaining deterministic results.
825
///
826
pub fn assert_dataframe_equal(
827
left: &DataFrame,
828
right: &DataFrame,
829
options: DataFrameEqualOptions,
830
) -> PolarsResult<()> {
831
// Short-circuit if they're the same DataFrame object
832
if std::ptr::eq(left, right) {
833
return Ok(());
834
}
835
836
assert_dataframe_schema_equal(
837
left,
838
right,
839
options.check_dtypes,
840
options.check_column_order,
841
)?;
842
843
if left.height() != right.height() {
844
return Err(polars_err!(
845
assertion_error = "DataFrames",
846
"height (row count) mismatch",
847
left.height(),
848
right.height()
849
));
850
}
851
852
let left_cols = left.get_column_names_owned();
853
854
let (left, right) = if !options.check_row_order {
855
(
856
left.sort(left_cols.clone(), SortMultipleOptions::default())?,
857
right.sort(left_cols.clone(), SortMultipleOptions::default())?,
858
)
859
} else {
860
(left.clone(), right.clone())
861
};
862
863
for col in left_cols.iter() {
864
let s_left = left.column(col)?;
865
let s_right = right.column(col)?;
866
867
let s_left_series = s_left.as_materialized_series();
868
let s_right_series = s_right.as_materialized_series();
869
870
match assert_series_values_equal(
871
s_left_series,
872
s_right_series,
873
true,
874
options.check_exact,
875
options.check_dtypes,
876
options.rel_tol,
877
options.abs_tol,
878
options.categorical_as_str,
879
) {
880
Ok(_) => {},
881
Err(_) => {
882
return Err(polars_err!(
883
assertion_error = "DataFrames",
884
format!("value mismatch for column {:?}", col),
885
format!("{:?}", s_left_series),
886
format!("{:?}", s_right_series)
887
));
888
},
889
}
890
}
891
892
Ok(())
893
}
894
895