Path: blob/main/py-polars/tests/unit/operations/test_replace.py
6939 views
"""Tests for ``Series.replace`` / ``Expr.replace``.

Covers the mapping and ``old``/``new`` call forms, dtype handling (Enum,
Categorical, integer, string), null keys, the many-to-one fast paths,
input-validation errors, and the deprecated ``return_dtype``/``default``
parameters.
"""

from __future__ import annotations

from typing import Any

import pytest

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal


@pytest.fixture(scope="module")
def str_mapping() -> dict[str | None, str]:
    """Country-code -> country-name mapping, including a ``None`` key."""
    return {
        "CA": "Canada",
        "DE": "Germany",
        "FR": "France",
        None: "Not specified",
    }


def test_replace_str_to_str(str_mapping: dict[str | None, str]) -> None:
    """Mapped values (including null) are replaced; unmapped ones are kept."""
    df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]})
    result = df.select(replaced=pl.col("country_code").replace(str_mapping))
    expected = pl.DataFrame({"replaced": ["France", "Not specified", "ES", "Germany"]})
    assert_frame_equal(result, expected)


def test_replace_enum() -> None:
    """Replacing with a same-dtype Enum ``new`` Series keeps the Enum dtype."""
    dtype = pl.Enum(["a", "b", "c", "d"])
    s = pl.Series(["a", "b", "c"], dtype=dtype)
    old = ["a", "b"]
    new = pl.Series(["c", "d"], dtype=dtype)

    result = s.replace(old, new)

    expected = pl.Series(["c", "d", "c"], dtype=dtype)
    assert_series_equal(result, expected)


def test_replace_enum_to_str() -> None:
    """A plain str -> str mapping on an Enum Series keeps the Enum dtype."""
    dtype = pl.Enum(["a", "b", "c", "d"])
    s = pl.Series(["a", "b", "c"], dtype=dtype)

    result = s.replace({"a": "c", "b": "d"})

    expected = pl.Series(["c", "d", "c"], dtype=dtype)
    assert_series_equal(result, expected)


def test_replace_cat_to_cat(str_mapping: dict[str | None, str]) -> None:
    """Categorical ``old``/``new`` Series replace values and keep Categorical."""
    lf = pl.LazyFrame(
        {"country_code": ["FR", None, "ES", "DE"]},
        schema={"country_code": pl.Categorical},
    )
    # Derive old/new from the shared fixture (insertion order CA, DE, FR,
    # None) so this test stays in sync with test_replace_str_to_str.
    old = pl.Series(list(str_mapping.keys()), dtype=pl.Categorical)
    new = pl.Series(list(str_mapping.values()), dtype=pl.Categorical)

    result = lf.select(replaced=pl.col("country_code").replace(old, new))

    expected = pl.LazyFrame(
        {"replaced": ["France", "Not specified", "ES", "Germany"]},
        schema_overrides={"replaced": pl.Categorical},
    )
    assert_frame_equal(result, expected)


def test_replace_invalid_old_dtype() -> None:
    """String keys cannot be cast to the i64 column, so collecting raises."""
    lf = pl.LazyFrame({"a": [1, 2, 3]})
    mapping = {"a": 10, "b": 20}
    with pytest.raises(
        InvalidOperationError, match="conversion from `str` to `i64` failed"
    ):
        lf.select(pl.col("a").replace(mapping)).collect()


def test_replace_int_to_int() -> None:
    """An int mapping replaces matches; nulls and the Int16 dtype are kept."""
    df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16})
    mapping = {1: 5, 3: 7}
    result = df.select(replaced=pl.col("int").replace(mapping))
    expected = pl.DataFrame(
        {"replaced": [None, 5, None, 7]}, schema={"replaced": pl.Int16}
    )
    assert_frame_equal(result, expected)


def test_replace_int_to_int_keep_dtype() -> None:
    """With an Int16 ``new`` Series the result stays Int16."""
    df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16})
    old = [1, 3]
    new = pl.Series([5, 7], dtype=pl.Int16)

    result = df.select(replaced=pl.col("int").replace(old, new))
    expected = pl.DataFrame(
        {"replaced": [None, 5, None, 7]}, schema={"replaced": pl.Int16}
    )
    assert_frame_equal(result, expected)


def test_replace_int_to_str() -> None:
    """String replacement values cannot be cast to the Int16 column dtype."""
    df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16})
    mapping = {1: "b", 3: "d"}
    with pytest.raises(
        InvalidOperationError, match="conversion from `str` to `i16` failed"
    ):
        df.select(replaced=pl.col("int").replace(mapping))


def test_replace_int_to_str_with_null() -> None:
    """Same as above; a ``None`` key does not change the cast failure."""
    df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16})
    mapping = {1: "b", 3: "d", None: "e"}
    with pytest.raises(
        InvalidOperationError, match="conversion from `str` to `i16` failed"
    ):
        df.select(replaced=pl.col("int").replace(mapping))


def test_replace_empty_mapping() -> None:
    """An empty mapping leaves the column unchanged."""
    df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16})
    mapping: dict[Any, Any] = {}
    result = df.select(pl.col("int").replace(mapping))
    assert_frame_equal(result, df)


def test_replace_mapping_different_dtype_str_int() -> None:
    """Integer keys match the string values "1" and "3" of a String column."""
    df = pl.DataFrame({"int": [None, "1", None, "3"]})
    mapping = {1: "b", 3: "d"}

    result = df.select(pl.col("int").replace(mapping))
    expected = pl.DataFrame({"int": [None, "b", None, "d"]})
    assert_frame_equal(result, expected)


def test_replace_mapping_different_dtype_map_none() -> None:
    """A ``None`` key replaces nulls alongside integer keys on a String column."""
    df = pl.DataFrame({"int": [None, "1", None, "3"]})
    mapping = {1: "b", 3: "d", None: "e"}
    result = df.select(pl.col("int").replace(mapping))
    expected = pl.DataFrame({"int": ["e", "b", "e", "d"]})
    assert_frame_equal(result, expected)


def test_replace_mapping_different_dtype_str_float() -> None:
    """Float keys 1.0/3.0 do not match the strings "1"/"3"; nothing changes."""
    df = pl.DataFrame({"int": [None, "1", None, "3"]})
    mapping = {1.0: "b", 3.0: "d"}

    result = df.select(pl.col("int").replace(mapping))
    assert_frame_equal(result, df)


# https://github.com/pola-rs/polars/issues/7132
def test_replace_str_to_str_replace_all() -> None:
    """The output of ``replace`` can be chained with ``str.replace_all``."""
    df = pl.DataFrame({"text": ["abc"]})
    mapping = {"abc": "123"}
    result = df.select(pl.col("text").replace(mapping).str.replace_all("1", "-"))
    expected = pl.DataFrame({"text": ["-23"]})
    assert_frame_equal(result, expected)


@pytest.fixture(scope="module")
def int_mapping() -> dict[int, int]:
    """Map each of 1..5 to that digit repeated twice (1 -> 11, ..., 5 -> 55)."""
    return {1: 11, 2: 22, 3: 33, 4: 44, 5: 55}


def test_replace_int_to_int1(int_mapping: dict[int, int]) -> None:
    """Values that are not keys (even if equal to mapping values) are kept."""
    s = pl.Series([-1, 22, None, 44, -5])
    result = s.replace(int_mapping)
    expected = pl.Series([-1, 22, None, 44, -5])
    assert_series_equal(result, expected)


def test_replace_int_to_int4(int_mapping: dict[int, int]) -> None:
    """Keys present in the Series are replaced; other values and nulls are kept."""
    # Previously an exact copy of test_replace_int_to_int1, which never hits
    # a key; mix matching keys (1, 5) with non-keys so replacement is exercised.
    s = pl.Series([1, 22, None, 44, 5])
    result = s.replace(int_mapping)
    expected = pl.Series([11, 22, None, 44, 55])
    assert_series_equal(result, expected)


# https://github.com/pola-rs/polars/issues/12728
def test_replace_str_to_int2() -> None:
    """Integer replacement values are cast to the String dtype of the input."""
    s = pl.Series(["a", "b"])
    mapping = {"a": 1, "b": 2}
    result = s.replace(mapping)
    expected = pl.Series(["1", "2"])
    assert_series_equal(result, expected)


def test_replace_str_to_bool_without_default() -> None:
    """Boolean replacement values are cast to the strings "true"/"false"."""
    s = pl.Series(["True", "False", "False", None])
    mapping = {"True": True, "False": False}
    result = s.replace(mapping)
    expected = pl.Series(["true", "false", "false", None])
    assert_series_equal(result, expected)


def test_replace_old_new() -> None:
    """Scalar ``old``/``new`` replaces every occurrence of ``old``."""
    s = pl.Series([1, 2, 2, 3])
    result = s.replace(2, 9)
    expected = pl.Series([1, 9, 9, 3])
    assert_series_equal(result, expected)


def test_replace_old_new_many_to_one() -> None:
    """A list ``old`` with a scalar ``new`` maps all listed values to ``new``."""
    s = pl.Series([1, 2, 2, 3])
    result = s.replace([2, 3], 9)
    expected = pl.Series([1, 9, 9, 9])
    assert_series_equal(result, expected)


def test_replace_old_new_mismatched_lengths() -> None:
    """``old`` and ``new`` lists of different lengths raise."""
    s = pl.Series([1, 2, 2, 3, 4])
    with pytest.raises(InvalidOperationError):
        s.replace([2, 3, 4], [8, 9])


def test_replace_fast_path_one_to_one() -> None:
    """Lazy one-to-one scalar replacement."""
    lf = pl.LazyFrame({"a": [1, 2, 2, 3]})
    result = lf.select(pl.col("a").replace(2, 100))
    expected = pl.LazyFrame({"a": [1, 100, 100, 3]})
    assert_frame_equal(result, expected)


def test_replace_fast_path_one_null_to_one() -> None:
    """Replacing ``None`` with a scalar fills the nulls."""
    # https://github.com/pola-rs/polars/issues/13391
    lf = pl.LazyFrame({"a": [1, None]})
    result = lf.select(pl.col("a").replace(None, 100))
    expected = pl.LazyFrame({"a": [1, 100]})
    assert_frame_equal(result, expected)


def test_replace_fast_path_many_with_null_to_one() -> None:
    """A list ``old`` containing ``None`` also replaces nulls."""
    lf = pl.LazyFrame({"a": [1, 2, None]})
    result = lf.select(pl.col("a").replace([1, None], 100))
    expected = pl.LazyFrame({"a": [100, 2, 100]})
    assert_frame_equal(result, expected)


def test_replace_fast_path_many_to_one() -> None:
    """Lazy many-to-one replacement with a scalar ``new``."""
    lf = pl.LazyFrame({"a": [1, 2, 2, 3]})
    result = lf.select(pl.col("a").replace([2, 3], 100))
    expected = pl.LazyFrame({"a": [1, 100, 100, 100]})
    assert_frame_equal(result, expected)


@pytest.mark.parametrize(
    ("old", "new"),
    [
        ([2, 2], 100),
        ([2, 2], [100, 200]),
        ([2, 2], [100, 100]),
    ],
)
def test_replace_duplicates_old(old: list[int], new: int | list[int]) -> None:
    """Duplicate entries in ``old`` raise, whatever the shape of ``new``."""
    s = pl.Series([1, 2, 3, 2, 3])
    with pytest.raises(
        InvalidOperationError,
        match="`old` input for `replace` must not contain duplicates",
    ):
        s.replace(old, new)


def test_replace_duplicates_new() -> None:
    """Duplicate entries in ``new`` are allowed."""
    s = pl.Series([1, 2, 3, 2, 3])
    result = s.replace([1, 2], [100, 100])
    expected = pl.Series([100, 100, 3, 100, 3])
    assert_series_equal(result, expected)


def test_replace_return_dtype_deprecated() -> None:
    """``return_dtype`` warns as deprecated but still casts the result."""
    s = pl.Series([1, 2, 3])
    with pytest.deprecated_call():
        result = s.replace(1, 10, return_dtype=pl.Int8)
    expected = pl.Series([10, 2, 3], dtype=pl.Int8)
    assert_series_equal(result, expected)


def test_replace_default_deprecated() -> None:
    """``default`` warns as deprecated; ``default=None`` nulls unmatched values."""
    s = pl.Series([1, 2, 3])
    with pytest.deprecated_call():
        result = s.replace(1, 10, default=None)
    expected = pl.Series([10, None, None], dtype=pl.Int32)
    assert_series_equal(result, expected)


def test_replace_single_argument_not_mapping() -> None:
    """Passing only a non-mapping ``old`` raises TypeError."""
    df = pl.DataFrame({"a": ["a", "b", "c"]})
    with pytest.raises(
        TypeError,
        match="`new` argument is required if `old` argument is not a Mapping type",
    ):
        df.select(pl.col("a").replace("b"))