# Path: blob/main/py-polars/tests/unit/sql/test_literals.py
# 6939 views
from __future__ import annotations

from datetime import date, datetime, timedelta

import pytest

import polars as pl
from polars.exceptions import SQLInterfaceError, SQLSyntaxError
from polars.testing import assert_frame_equal


def test_bit_hex_literals() -> None:
    """Check that bit-string and hex-string SQL literals produce the expected bytes."""
    with pl.SQLContext(df=None, eager=True) as ctx:
        out = ctx.execute(
            """
            SELECT *,
              -- bit strings
              b'' AS b0,
              b'1001' AS b1,
              b'11101011' AS b2,
              b'1111110100110010' AS b3,
              -- hex strings
              x'' AS x0,
              x'FF' AS x1,
              x'4142' AS x2,
              x'DeadBeef' AS x3,
            FROM df
            """
        )

    assert out.to_dict(as_series=False) == {
        "b0": [b""],
        "b1": [b"\t"],
        "b2": [b"\xeb"],
        "b3": [b"\xfd2"],
        "x0": [b""],
        "x1": [b"\xff"],
        "x2": [b"AB"],
        "x3": [b"\xde\xad\xbe\xef"],
    }


def test_bit_hex_filter() -> None:
    """Check that a binary column can be compared against bit/hex/string literals."""
    df = pl.DataFrame(
        {"bin": [b"\x01", b"\x02", b"\x03", b"\x04"], "val": [9, 8, 7, 6]}
    )
    with pl.SQLContext(test=df) as ctx:
        # each entry is a different SQL spelling of the single byte 0x02
        for two in ("b'10'", "x'02'", "'\x02'", "b'0010'"):
            out = ctx.execute(f"SELECT val FROM test WHERE bin > {two}", eager=True)
            assert out.to_series().to_list() == [7, 6]


def test_bit_hex_errors() -> None:
    """Check the errors raised for malformed bit/hex and unsupported literals."""
    with pl.SQLContext(test=None) as ctx:
        with pytest.raises(
            SQLSyntaxError,
            match="bit string literal should contain only 0s and 1s",
        ):
            ctx.execute("SELECT b'007' FROM test", eager=True)

        with pytest.raises(
            SQLSyntaxError,
            match="hex string literal must have an even number of digits",
        ):
            ctx.execute("SELECT x'00F' FROM test", eager=True)

    with pytest.raises(
        SQLSyntaxError,
        match="hex string literal must have an even number of digits",
    ):
        pl.sql_expr("colx IN (x'FF',x'123')")

    with pytest.raises(
        SQLInterfaceError,
        match=r'NationalStringLiteral\("hmmm"\) is not a supported literal',
    ):
        pl.sql_expr("N'hmmm'")


def test_bit_hex_membership() -> None:
    """Check `IN (...)` membership of a binary column against bit/hex literals."""
    df = pl.DataFrame(
        {
            "x": [b"\x05", b"\xff", b"\xcc", b"\x0b"],
            "y": [1, 2, 3, 4],
        }
    )
    # this checks the internal `visit_any_value` codepath
    for values in (
        "b'0101', b'1011'",
        "x'05', x'0b'",
    ):
        dff = df.filter(pl.sql_expr(f"x IN ({values})"))
        assert dff["y"].to_list() == [1, 4]


def test_dollar_quoted_literals() -> None:
    """Check that dollar-quoted strings (with and without tags) parse as strings."""
    df = pl.sql(
        """
        SELECT
          $$xyz$$ AS dq1,
          $q$xyz$q$ AS dq2,
          $tag$xyz$tag$ AS dq3,
          $QUOTE$xyz$QUOTE$ AS dq4,
        """
    ).collect()
    assert df.to_dict(as_series=False) == {f"dq{n}": ["xyz"] for n in range(1, 5)}

    # a lone '$' inside the quoted body is kept as part of the value
    df = pl.sql("SELECT $$x$z$$ AS dq").collect()
    assert df.item() == "x$z"


def test_fixed_intervals() -> None:
    """Check fixed-duration INTERVAL literals and the errors for invalid ones."""
    with pl.SQLContext(df=None, eager=True) as ctx:
        out = ctx.execute(
            """
            SELECT
              -- short form with/without spaces
              INTERVAL '1w2h3m4s' AS i1,
              INTERVAL '100ms 100us' AS i2,
              -- long form with/without commas (case-insensitive)
              INTERVAL '1 week, 2 hours, 3 minutes, 4 seconds' AS i3
            FROM df
            """
        )
        expected = pl.DataFrame(
            {
                "i1": [timedelta(weeks=1, hours=2, minutes=3, seconds=4)],
                "i2": [timedelta(microseconds=100100)],
                "i3": [timedelta(weeks=1, hours=2, minutes=3, seconds=4)],
            },
        ).cast(pl.Duration("ns"))

        assert_frame_equal(expected, out)

        # TODO: negative intervals
        with pytest.raises(
            SQLInterfaceError,
            match="minus signs are not yet supported in interval strings; found '-7d'",
        ):
            ctx.execute("SELECT INTERVAL '-7d' AS one_week_ago FROM df")

        with pytest.raises(
            SQLSyntaxError,
            match="unary ops are not valid on interval strings; found -'7d'",
        ):
            ctx.execute("SELECT INTERVAL -'7d' AS one_week_ago FROM df")

        with pytest.raises(
            SQLSyntaxError,
            match="fixed-duration interval cannot contain years, quarters, or months",
        ):
            ctx.execute("SELECT INTERVAL '1 quarter 1 month' AS q FROM df")


def test_interval_offsets() -> None:
    """Check adding/subtracting INTERVAL literals to datetime and date columns."""
    df = pl.DataFrame(
        {
            "dtm": [
                datetime(1899, 12, 31, 8),
                datetime(1999, 6, 8, 10, 30),
                datetime(2010, 5, 7, 20, 20, 20),
            ],
            "dt": [
                date(1950, 4, 10),
                date(2048, 1, 20),
                date(2026, 8, 5),
            ],
        }
    )

    out = df.sql(
        """
        SELECT
          dtm + INTERVAL '2 months, 30 minutes' AS dtm_plus_2mo30m,
          dt + INTERVAL '100 years' AS dt_plus_100y,
          dt - INTERVAL '1 quarter' AS dt_minus_1q
        FROM self
        ORDER BY 1
        """
    )
    assert out.to_dict(as_series=False) == {
        "dtm_plus_2mo30m": [
            datetime(1900, 2, 28, 8, 30),
            datetime(1999, 8, 8, 11, 0),
            datetime(2010, 7, 7, 20, 50, 20),
        ],
        "dt_plus_100y": [
            date(2050, 4, 10),
            date(2148, 1, 20),
            date(2126, 8, 5),
        ],
        "dt_minus_1q": [
            date(1950, 1, 10),
            date(2047, 10, 20),
            date(2026, 5, 5),
        ],
    }


@pytest.mark.parametrize(
    ("interval_comparison", "expected_result"),
    [
        ("INTERVAL '3 days' <= INTERVAL '3 days, 1 microsecond'", True),
        ("INTERVAL '3 days, 1 microsecond' <= INTERVAL '3 days'", False),
        ("INTERVAL '3 months' >= INTERVAL '3 months'", True),
        ("INTERVAL '2 quarters' < INTERVAL '2 quarters'", False),
        ("INTERVAL '2 quarters' > INTERVAL '2 quarters'", False),
        ("INTERVAL '3 years' <=> INTERVAL '3 years'", True),
        ("INTERVAL '3 years' == INTERVAL '1008 weeks'", False),
        ("INTERVAL '8 weeks' != INTERVAL '2 months'", True),
        ("INTERVAL '8 weeks' = INTERVAL '2 months'", False),
        ("INTERVAL '1 year' != INTERVAL '365 days'", True),
        ("INTERVAL '1 year' = INTERVAL '1 year'", True),
    ],
)
def test_interval_comparisons(interval_comparison: str, expected_result: bool) -> None:
    """Check the boolean result of comparing two INTERVAL literals."""
    with pl.SQLContext() as ctx:
        res = ctx.execute(f"SELECT {interval_comparison} AS res")
        assert res.collect().to_dict(as_series=False) == {"res": [expected_result]}


def test_select_literals_no_table() -> None:
    """Check that literals can be selected without a FROM clause."""
    res = pl.sql("SELECT 1 AS one, '2' AS two, 3.0 AS three", eager=True)
    assert res.to_dict(as_series=False) == {
        "one": [1],
        "two": ["2"],
        "three": [3.0],
    }


def test_select_from_table_with_reserved_names() -> None:
    """Check that quoted reserved words work as table and column names."""
    # `select` is never used from Python; the query below refers to it by name
    select = pl.DataFrame({"select": [1, 2, 3], "from": [4, 5, 6]})  # noqa: F841
    out = pl.sql(
        """
        SELECT "from", "select"
        FROM "select"
        WHERE "from" >= 5 AND "select" % 2 != 1
        """,
        eager=True,
    )
    assert out.rows() == [(5, 2)]