# Path: blob/main/py-polars/tests/unit/series/test_zip_with.py
# (scraped page metadata: 6939 views)
"""Unit tests for ``Series.zip_with``.

``s1.zip_with(mask, s2)`` builds a new Series that takes each element from
``s1`` where ``mask`` is True and from ``s2`` where ``mask`` is False.
"""

import pytest

import polars as pl
from polars.testing import assert_series_equal


def test_zip_with_all_true_mask() -> None:
    """An all-True mask selects every element from the calling Series."""
    s1 = pl.Series([1, 2, 3])
    s2 = pl.Series([4, 5, 6])
    mask = pl.Series([True, True, True])

    result = s1.zip_with(mask, s2)
    assert_series_equal(result, s1)


def test_zip_with_all_false_mask() -> None:
    """An all-False mask selects every element from ``other``."""
    s1 = pl.Series([1, 2, 3])
    s2 = pl.Series([4, 5, 6])
    mask = pl.Series([False, False, False])

    result = s1.zip_with(mask, s2)
    assert_series_equal(result, s2)


def test_zip_with_mixed_mask() -> None:
    """A mixed mask picks element-wise between the two Series."""
    s1 = pl.Series([1, 2, 3, 4, 5])
    s2 = pl.Series([5, 4, 3, 2, 1])
    mask = pl.Series([True, False, True, False, True])

    result = s1.zip_with(mask, s2)
    expected = pl.Series([1, 4, 3, 2, 5])
    assert_series_equal(result, expected)


def test_zip_with_series_comparison() -> None:
    """The mask may be the boolean result of comparing the two Series."""
    s1 = pl.Series([1, 2, 3, 4, 5])
    s2 = pl.Series([5, 4, 3, 2, 1])

    # Taking s1 where s1 < s2, else s2, yields the element-wise minimum.
    result = s1.zip_with(s1 < s2, s2)
    expected = pl.Series([1, 2, 3, 2, 1])
    assert_series_equal(result, expected)


def test_zip_with_null_values() -> None:
    """Nulls are passed through from whichever Series the mask selects."""
    s1 = pl.Series([1, None, 3, 4])
    s2 = pl.Series([5, 6, None, 8])
    mask = pl.Series([True, True, False, False])

    result = s1.zip_with(mask, s2)
    expected = pl.Series([1, None, None, 8])
    assert_series_equal(result, expected)


def test_zip_with_length_mismatch() -> None:
    """Series of different lengths raise ``ShapeError``."""
    s1 = pl.Series([1, 2, 3])
    s2 = pl.Series([4, 5])
    mask = pl.Series([True, False, True])

    with pytest.raises(pl.exceptions.ShapeError):
        s1.zip_with(mask, s2)


def test_zip_with_bad_input_type() -> None:
    """``other`` must be a Series.

    DataFrame and LazyFrame arguments raise ``TypeError``. Subclasses of
    ``pl.Series`` are accepted.
    """
    s1 = pl.Series([1, 2, 3])
    s2 = pl.Series([4, 5, 6])
    mask = pl.Series([True, False, True])

    with pytest.raises(
        TypeError,
        match="expected `other` .*to be a 'Series'.* not 'DataFrame'",
    ):
        s1.zip_with(mask, pl.DataFrame(s2))  # type: ignore[arg-type]

    with pytest.raises(
        TypeError,
        match="expected `other` .*to be a 'Series'.* not 'LazyFrame'",
    ):
        s1.zip_with(mask, pl.DataFrame(s2).lazy())  # type: ignore[arg-type]

    class DummySeriesSubclass(pl.Series):
        pass

    s1 = DummySeriesSubclass(s1)
    s2 = DummySeriesSubclass(s2)
    mask = DummySeriesSubclass(mask)

    # Check the selected values as well, so a subclass that is accepted but
    # handled incorrectly still fails this test.
    result = s1.zip_with(mask, s2)
    assert_series_equal(result, pl.Series([1, 5, 3]))