Path: blob/main/py-polars/tests/unit/series/test_scatter.py
8424 views
"""Tests for ``Series.scatter`` and index assignment (``series[idx] = value``)."""

from datetime import date, datetime
from typing import Any

import numpy as np
import pytest

import polars as pl
from polars._typing import PolarsDataType
from polars.exceptions import ComputeError, InvalidOperationError, OutOfBoundsError
from polars.testing import assert_series_equal


@pytest.mark.parametrize(
    "indices",
    [
        (),
        [],
        pl.Series(),
        pl.Series(dtype=pl.Int8),
        np.array([]),
    ],
)
def test_scatter_noop(indices: Any) -> None:
    """Scattering with an empty collection of indices leaves the Series unchanged."""
    # NOTE: the parameter was previously named ``input``, shadowing the builtin.
    s = pl.Series("s", [1, 2, 3])
    s.scatter(indices, 8)
    assert s.to_list() == [1, 2, 3]


def test_scatter() -> None:
    """Scatter at single/multiple indices, across dtypes, and via negative indices."""
    s = pl.Series("s", [1, 2, 3])

    # set new values, one index at a time
    s.scatter(0, 8)
    s.scatter([1], None)
    assert s.to_list() == [8, None, 3]

    # set new value at multiple indexes in one go
    s.scatter([0, 2], None)
    assert s.to_list() == [None, None, None]

    # try with different series dtype
    s = pl.Series("s", ["a", "b", "c"])
    s.scatter((1, 2), "x")
    assert s.to_list() == ["a", "x", "x"]
    # a float scattered into a String series is stored as its string form
    assert s.scatter([0, 2], 0.12345).to_list() == ["0.12345", "x", "0.12345"]

    # set multiple values
    s = pl.Series(["z", "z", "z"])
    assert s.scatter([0, 1], ["a", "b"]).to_list() == ["a", "b", "z"]
    s = pl.Series([True, False, True])
    assert s.scatter([0, 1], [False, True]).to_list() == [False, True, True]

    # set negative indices
    a = pl.Series("r", range(5))
    a[-2] = None
    a[-5] = None
    assert a.to_list() == [None, 1, 2, None, 4]

    # an out-of-range negative index raises and leaves the series intact
    a = pl.Series("x", [1, 2])
    with pytest.raises(OutOfBoundsError):
        a[-100] = None
    assert_series_equal(a, pl.Series("x", [1, 2]))


def test_index_with_None_errors_16905() -> None:
    """A null index value raises ComputeError without modifying the Series."""
    s = pl.Series("s", [1, 2, 3])
    with pytest.raises(ComputeError, match="index values should not be null"):
        s[[1, None]] = 5
    # The error doesn't trash the series, as it used to:
    assert_series_equal(s, pl.Series("s", [1, 2, 3]))


def test_object_dtype_16905() -> None:
    """Assigning into an Object Series raises and leaves the Series intact."""
    obj = object()
    s = pl.Series("s", [obj, 27], dtype=pl.Object)
    # This operation is not semantically wrong, it might be supported in the
    # future, but for now it isn't.
    with pytest.raises(InvalidOperationError):
        s[0] = 5
    # The error doesn't trash the series, as it used to:
    assert s.dtype.is_object()
    assert s.name == "s"
    assert s.to_list() == [obj, 27]


def test_scatter_datetime() -> None:
    """Scatter a datetime over a null entry of a datetime Series."""
    s = pl.Series("dt", [None, datetime(2024, 1, 31)])
    result = s.scatter(0, datetime(2022, 2, 2))
    expected = pl.Series("dt", [datetime(2022, 2, 2), datetime(2024, 1, 31)])
    assert_series_equal(result, expected)


def test_scatter_logical_all_null() -> None:
    """Scatter a date into a Date Series whose values are all null."""
    s = pl.Series("dt", [None, None], dtype=pl.Date)
    result = s.scatter(0, date(2022, 2, 2))
    expected = pl.Series("dt", [date(2022, 2, 2), None])
    assert_series_equal(result, expected)


def test_scatter_categorical_21175() -> None:
    """Scatter a string or a Categorical Series into a Categorical Series.

    An integer value is rejected with InvalidOperationError.
    """
    s = pl.Series(["a", "b", "c"], dtype=pl.Categorical)
    assert_series_equal(
        s.scatter(0, "b"), pl.Series(["b", "b", "c"], dtype=pl.Categorical)
    )
    v = pl.Series(["v"], dtype=pl.Categorical)
    assert_series_equal(
        s.scatter([0, 2], v), pl.Series(["v", "b", "v"], dtype=pl.Categorical)
    )

    with pytest.raises(InvalidOperationError):
        s.scatter(1, 2)


def test_scatter_enum() -> None:
    """Scatter a string or a Categorical Series into an Enum Series.

    A string outside the Enum's categories and an integer value are both
    rejected with InvalidOperationError.
    """
    e = pl.Enum(["a", "b", "c", "v"])
    s = pl.Series(["a", "b", "c"], dtype=e)
    assert_series_equal(s.scatter(0, "b"), pl.Series(["b", "b", "c"], dtype=e))
    v = pl.Series(["v"], dtype=pl.Categorical)
    assert_series_equal(s.scatter([0, 2], v), pl.Series(["v", "b", "v"], dtype=e))

    # "d" is not one of the Enum's categories
    with pytest.raises(InvalidOperationError):
        s.scatter(1, "d")

    with pytest.raises(InvalidOperationError):
        s.scatter(1, 2)


@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b"])])
def test_scatter_null(dtype: PolarsDataType) -> None:
    """Scatter a null into a Categorical or Enum Series, preserving the dtype."""
    s = pl.Series("a", ["a", "b"], dtype=dtype)
    result = s.scatter(0, None)
    expected = pl.Series("a", [None, "b"], dtype=dtype)
    assert_series_equal(result, expected)


def test_scatter_decimal_25869() -> None:
    """Scatter an integer into a Decimal Series; a string value is rejected."""
    s = pl.Series([1, 2, 3], dtype=pl.Decimal(scale=10))
    assert_series_equal(
        s.scatter(0, 15), pl.Series([15, 2, 3], dtype=pl.Decimal(scale=10))
    )

    with pytest.raises(InvalidOperationError):
        s.scatter(1, "test")