Path: blob/main/py-polars/tests/unit/functions/test_lit.py
6939 views
# mypy: disable-error-code="redundant-expr"1from __future__ import annotations23import enum4import sys5from datetime import datetime, timedelta6from decimal import Decimal7from typing import TYPE_CHECKING, Any89import numpy as np10import pytest11from hypothesis import given1213import polars as pl14from polars.testing import assert_frame_equal15from polars.testing.parametric.strategies import series16from polars.testing.parametric.strategies.data import datetimes1718if TYPE_CHECKING:19from polars._typing import PolarsDataType202122if sys.version_info >= (3, 11):23from enum import StrEnum2425PyStrEnum: type[enum.Enum] | None = StrEnum26else:27PyStrEnum = None282930@pytest.mark.parametrize(31"input",32[33[[1, 2], [3, 4, 5]],34[1, 2, 3],35],36)37def test_lit_list_input(input: list[Any]) -> None:38df = pl.DataFrame({"a": [1, 2]})39result = df.with_columns(pl.lit(input).first())40expected = pl.DataFrame({"a": [1, 2], "literal": [input, input]})41assert_frame_equal(result, expected)424344@pytest.mark.parametrize(45"input",46[47([1, 2], [3, 4, 5]),48(1, 2, 3),49],50)51def test_lit_tuple_input(input: tuple[Any, ...]) -> None:52df = pl.DataFrame({"a": [1, 2]})53result = df.with_columns(pl.lit(input).first())5455expected = pl.DataFrame({"a": [1, 2], "literal": [list(input), list(input)]})56assert_frame_equal(result, expected)575859def test_lit_numpy_array_input() -> None:60df = pl.DataFrame({"a": [1, 2]})61input = np.array([3, 4])6263result = df.with_columns(pl.lit(input, dtype=pl.Int64))6465expected = pl.DataFrame({"a": [1, 2], "literal": [3, 4]})66assert_frame_equal(result, expected)676869def test_lit_ambiguous_datetimes_11379() -> None:70df = pl.DataFrame(71{72"ts": pl.datetime_range(73datetime(2020, 10, 25),74datetime(2020, 10, 25, 2),75"1h",76time_zone="Europe/London",77eager=True,78)79}80)81for i in range(df.height):82result = df.filter(pl.col("ts") >= df["ts"][i])83expected = df[i:]84assert_frame_equal(result, expected)858687def test_list_datetime_11571() -> None:88sec_np_ns = np.timedelta64(1_000_000_000, "ns")89sec_np_us = np.timedelta64(1_000_000, "us")90assert pl.select(pl.lit(sec_np_ns))[0, 0] == timedelta(seconds=1)91assert pl.select(pl.lit(sec_np_us))[0, 0] == timedelta(seconds=1)929394@pytest.mark.parametrize(95("input", "dtype"),96[97pytest.param(-(2**31), pl.Int32, id="i32 min"),98pytest.param(-(2**31) - 1, pl.Int64, id="below i32 min"),99pytest.param(2**31 - 1, pl.Int32, id="i32 max"),100pytest.param(2**31, pl.Int64, id="above i32 max"),101pytest.param(2**63 - 1, pl.Int64, id="i64 max"),102pytest.param(2**63, pl.UInt64, id="above i64 max"),103],104)105def test_lit_int_return_type(input: int, dtype: PolarsDataType) -> None:106assert pl.select(pl.lit(input)).to_series().dtype == dtype107108109def test_lit_unsupported_type() -> None:110with pytest.raises(111TypeError,112match="cannot create expression literal for value of type LazyFrame",113):114pl.lit(pl.LazyFrame({"a": [1, 2, 3]}))115116117@pytest.mark.parametrize(118"EnumBase",119[120(enum.Enum,),121(str, enum.Enum),122*([(PyStrEnum,)] if PyStrEnum is not None else []),123],124)125def test_lit_enum_input_16668(EnumBase: tuple[type, ...]) -> None:126# https://github.com/pola-rs/polars/issues/16668127128class State(*EnumBase): # type: ignore[misc]129NSW = "New South Wales"130QLD = "Queensland"131VIC = "Victoria"132133# validate that frame schema has inferred the enum134df = pl.DataFrame({"state": [State.NSW, State.VIC]})135assert df.schema == {136"state": pl.Enum(["New South Wales", "Queensland", "Victoria"])137}138139# check use of enum as lit/constraint140value = State.VIC141expected = "Victoria"142143for lit_value in (144pl.lit(value),145pl.lit(value.value), # type: ignore[attr-defined]146):147assert pl.select(lit_value).item() == expected148assert df.filter(state=value).item() == expected149assert df.filter(state=lit_value).item() == expected150151assert df.filter(pl.col("state") == State.QLD).is_empty()152assert df.filter(pl.col("state") != State.QLD).height == 2153154155@pytest.mark.parametrize(156"EnumBase",157[158(enum.Enum,),159(enum.Flag,),160(enum.IntEnum,),161(enum.IntFlag,),162(int, enum.Enum),163],164)165def test_lit_enum_input_non_string(EnumBase: tuple[type, ...]) -> None:166# https://github.com/pola-rs/polars/issues/16668167168class Number(*EnumBase): # type: ignore[misc]169ONE = 1170TWO = 2171172value = Number.ONE173174result = pl.lit(value)175assert pl.select(result).dtypes[0] == pl.Int32176assert pl.select(result).item() == 1177178result = pl.lit(value, dtype=pl.Int8)179assert pl.select(result).dtypes[0] == pl.Int8180assert pl.select(result).item() == 1181182183@given(value=datetimes("ns"))184def test_datetime_ns(value: datetime) -> None:185result = pl.select(pl.lit(value, dtype=pl.Datetime("ns")))["literal"][0]186assert result == value187188189@given(value=datetimes("us"))190def test_datetime_us(value: datetime) -> None:191result = pl.select(pl.lit(value, dtype=pl.Datetime("us")))["literal"][0]192assert result == value193result = pl.select(pl.lit(value, dtype=pl.Datetime))["literal"][0]194assert result == value195196197@given(value=datetimes("ms"))198def test_datetime_ms(value: datetime) -> None:199result = pl.select(pl.lit(value, dtype=pl.Datetime("ms")))["literal"][0]200expected_microsecond = value.microsecond // 1000 * 1000201assert result == value.replace(microsecond=expected_microsecond)202203204@pytest.mark.may_fail_cloud # @cloud-decimal205def test_lit_decimal() -> None:206value = Decimal("0.1")207208expr = pl.lit(value)209df = pl.select(expr)210result = df.item()211212assert df.dtypes[0] == pl.Decimal(None, 1)213assert result == value214215216def test_lit_string_float() -> None:217value = 3.2218219expr = pl.lit(value, dtype=pl.Utf8)220df = pl.select(expr)221result = df.item()222223assert df.dtypes[0] == pl.String224assert result == str(value)225226227@pytest.mark.may_fail_cloud # @cloud-decimal228@given(s=series(min_size=1, max_size=1, allow_null=False, allowed_dtypes=pl.Decimal))229def test_lit_decimal_parametric(s: pl.Series) -> None:230scale = s.dtype.scale # type: ignore[attr-defined]231value = s.item()232233expr = pl.lit(value)234df = pl.select(expr)235result = df.item()236237assert df.dtypes[0] == pl.Decimal(None, scale)238assert result == value239240241@pytest.mark.parametrize(242"item",243[pytest.param({}, marks=pytest.mark.may_fail_cloud), {"foo": 1}],244)245def test_lit_structs(item: Any) -> None:246assert pl.select(pl.lit(item)).to_dict(as_series=False) == {"literal": [item]}247248249@pytest.mark.parametrize(250("value", "expected_dtype"),251[252(np.float32(1.2), pl.Float32),253(np.float64(1.2), pl.Float64),254(np.int8(1), pl.Int8),255(np.uint8(1), pl.UInt8),256(np.int16(1), pl.Int16),257(np.uint16(1), pl.UInt16),258(np.int32(1), pl.Int32),259(np.uint32(1), pl.UInt32),260(np.int64(1), pl.Int64),261(np.uint64(1), pl.UInt64),262],263)264def test_numpy_lit(value: Any, expected_dtype: PolarsDataType) -> None:265result = pl.select(pl.lit(value)).get_column("literal")266assert result.dtype == expected_dtype267268269