# Path: py-polars/tests/unit/operations/test_fill_null.py
"""Tests for ``fill_null`` on polars Series, DataFrames and LazyFrames."""

import datetime

import pytest

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal


def test_fill_null_minimal_upcast_4056() -> None:
    """An Int8 column stays Int8 when the fill value fits, else upcasts to Int16."""
    df = pl.DataFrame({"a": [-1, 2, None]})
    df = df.with_columns(pl.col("a").cast(pl.Int8))
    assert df.with_columns(pl.col(pl.Int8).fill_null(-1)).dtypes[0] == pl.Int8
    # -1000 does not fit in Int8, so the smallest wider integer type is expected.
    assert df.with_columns(pl.col(pl.Int8).fill_null(-1000)).dtypes[0] == pl.Int16


def test_fill_enum_upcast() -> None:
    """Filling an Enum Series with one of its categories keeps the Enum dtype."""
    dtype = pl.Enum(["a", "b"])
    s = pl.Series(["a", "b", None], dtype=dtype)
    s_filled = s.fill_null("b")
    expected = pl.Series(["a", "b", "b"], dtype=dtype)
    assert s_filled.dtype == dtype
    assert_series_equal(s_filled, expected)


def test_fill_null_static_schema_4843() -> None:
    """A lazy fill_null on Int64 columns keeps them Int64 in the resolved schema."""
    df1 = pl.DataFrame(
        {
            "a": [1, 2, None],
            "b": [1, None, 4],
        }
    ).lazy()

    df2 = df1.select([pl.col(pl.Int64).fill_null(0)])
    df3 = df2.select(pl.col(pl.Int64))
    assert df3.collect_schema() == {"a": pl.Int64, "b": pl.Int64}


def test_fill_null_non_lit() -> None:
    """DataFrame.fill_null with a plain int fills every integer and Decimal column."""
    df = pl.DataFrame(
        {
            "a": pl.Series([1, None], dtype=pl.Int32),
            "b": pl.Series([None, 2], dtype=pl.UInt32),
            "c": pl.Series([None, 2], dtype=pl.Int64),
            "d": pl.Series([None, 2], dtype=pl.Decimal),
        }
    )
    # Total null count across all columns must drop to zero.
    assert df.fill_null(0).select(pl.all().null_count()).transpose().sum().item() == 0


def test_fill_null_f32_with_lit() -> None:
    """Filling a Float32 column with an integer literal keeps it Float32."""
    # ensure the literal integer does not upcast the f32 to an f64; the column
    # contains a null so the fill path is actually exercised.
    df = pl.DataFrame({"a": [1.1, None]}, schema=[("a", pl.Float32)])
    assert df.fill_null(value=0).dtypes == [pl.Float32]


def test_fill_null_lit_() -> None:
    """DataFrame.fill_null with a ``pl.lit`` expression fills every integer column."""
    df = pl.DataFrame(
        {
            "a": pl.Series([1, None], dtype=pl.Int32),
            "b": pl.Series([None, 2], dtype=pl.UInt32),
            "c": pl.Series([None, 2], dtype=pl.Int64),
        }
    )
    assert (
        df.fill_null(pl.lit(0)).select(pl.all().null_count()).transpose().sum().item()
        == 0
    )


def test_fill_null_decimal_with_int_14331() -> None:
    """Filling a Decimal Series with an int keeps the Decimal dtype and scale."""
    dtype = pl.Decimal(precision=None, scale=5)
    s = pl.Series("a", ["1.1", None], dtype=dtype)
    result = s.fill_null(0)
    expected = pl.Series("a", ["1.1", "0.0"], dtype=dtype)
    assert_series_equal(result, expected)


def test_fill_null_date_with_int_11362() -> None:
    """Filling a Date Series with an int raises InvalidOperationError."""
    match = "got invalid or ambiguous dtypes"

    s = pl.Series([datetime.date(2000, 1, 1)])
    with pytest.raises(pl.exceptions.InvalidOperationError, match=match):
        s.fill_null(0)

    # Same error even when the Series consists only of nulls.
    s = pl.Series([None], dtype=pl.Date)
    with pytest.raises(pl.exceptions.InvalidOperationError, match=match):
        s.fill_null(1)


def test_fill_null_int_dtype_15546() -> None:
    """A lazy fill_null with an int on an Int8 column keeps the result Int8."""
    df = pl.Series("a", [1, 2, None], dtype=pl.Int8).to_frame().lazy()
    result = df.fill_null(0).collect()
    expected = pl.Series("a", [1, 2, 0], dtype=pl.Int8).to_frame()
    assert_frame_equal(result, expected)


def test_fill_null_with_list_10869() -> None:
    """A list fill value works for list Series but not for integer Series."""
    assert_series_equal(
        pl.Series([[1], None]).fill_null([2]),
        pl.Series([[1], [2]]),
    )

    match = "failed to determine supertype"
    with pytest.raises(pl.exceptions.SchemaError, match=match):
        pl.Series([1, None]).fill_null([2])


def test_unequal_lengths_22018() -> None:
    """A fill Series whose length differs from the target raises ShapeError."""
    with pytest.raises(pl.exceptions.ShapeError):
        pl.Series([1, None]).fill_null(pl.Series([1] * 3))
    # Raised even when the target contains no nulls.
    with pytest.raises(pl.exceptions.ShapeError):
        pl.Series([1, 2]).fill_null(pl.Series([1] * 3))


def test_self_broadcast() -> None:
    """A length-1 target Series is broadcast to the length of the fill Series."""
    assert_series_equal(
        pl.Series([1]).fill_null(pl.Series(range(3))),
        pl.Series([1] * 3),
    )

    assert_series_equal(
        pl.Series([None]).fill_null(pl.Series(range(3))),
        pl.Series(range(3)),
    )