# Path: blob/main/py-polars/tests/unit/functions/test_col.py
# (8424 views)
# Tests for ``pl.col``: attribute-style access, class-body name mangling,
# and the different kinds of selector input it accepts.
from __future__ import annotations

import polars as pl
from polars import col
from polars.datatypes.group import NUMERIC_DTYPES
from polars.testing import assert_frame_equal


def test_col_as_attribute() -> None:
    """Attribute access on the imported ``col`` selects the same-named column."""
    frame = pl.DataFrame({"lower": 1, "UPPER": 2, "_underscored": 3})
    names = ("lower", "UPPER", "_underscored")
    assert_frame_equal(
        frame.select(col.lower, col.UPPER, col._underscored),
        frame.select(*names),
    )


def test_col_as_attribute_edge_cases() -> None:
    """Dunder-looking column names behave the same via call or attribute."""
    frame = pl.DataFrame({"__misc": "x", "__wrapped__col": "y", "_other__col__": "z"})
    wanted = ["_other__col__", "__wrapped__col", "__misc"]

    # Both selector tuples are built up front; they must pick identical columns.
    by_name = tuple(pl.col(name) for name in wanted)
    by_attr = (pl.col._other__col__, pl.col.__wrapped__col, pl.col.__misc)
    for exprs in (by_name, by_attr):
        assert frame.select(exprs).columns == wanted


def test_col_as_attribute_class_mangling_25129() -> None:
    """``pl.col.__name`` written inside a class body still selects ``__name``.

    Inside ``class Mangler`` Python mangles ``__foo`` to ``_Mangler__foo``.
    The child script asserts that the original ``__foo`` column is selected
    anyway. The classmethod/staticmethod cases are only asserted on
    Python >= 3.11.
    """
    # Run in a fresh interpreter: under pytest, pytest itself can end up
    # injecting "_pytestfixturefunction" as the col name/attribute, and the
    # original name could not be recovered from that.
    import subprocess
    import sys

    script = """\
from sys import version_info
import polars as pl
df = pl.DataFrame({"__foo": [0]})

class Mangler:
    def __init__(self):
        self._selected = df.select(pl.col.__foo)

    def foo(self):
        return df.select(pl.col.__foo)

    @classmethod
    def misc(cls):
        def _nested():
            return df.select(pl.col.__foo)
        return _nested()

    @staticmethod
    def indirect():
        return Mangler.misc()

    @staticmethod
    def testing1234():
        return df.select(pl.col.__foo)


# detect mangling in init/instancemethod
assert Mangler()._selected.columns == ["__foo"]
assert Mangler().foo().columns == ["__foo"]

# additionally detect mangling in classmethod/staticmethod
if version_info >= (3, 11):
    assert Mangler.misc().columns == ["__foo"]
    assert Mangler.indirect().columns == ["__foo"]
    assert Mangler.testing1234().columns == ["__foo"]

print("OK", end="")
"""
    # check=True raises CalledProcessError if any assertion in the child fails.
    proc = subprocess.run(
        [sys.executable, "-c", script],
        stdout=subprocess.PIPE,
        check=True,
    )
    # The child prints exactly "OK" (no trailing newline) once it has finished.
    assert proc.stdout == b"OK"


def test_col_select() -> None:
    """Check which columns ``pl.col`` picks for each supported input form."""
    frame = pl.DataFrame(
        {
            "ham": [1, 2, 3],
            "hamburger": [11, 22, 33],
            "foo": [3, 2, 1],
            "bar": ["a", "b", "c"],
        }
    )

    def picked(expr: pl.Expr) -> list[str]:
        # Names of the columns that remain after selecting `expr`.
        return frame.select(expr).columns

    # A single column, by exact name.
    assert picked(pl.col("foo")) == ["foo"]

    # "*" wildcard and ^...$ regex patterns, plus exclusion from a wildcard.
    assert picked(pl.col("*")) == ["ham", "hamburger", "foo", "bar"]
    assert picked(pl.col("^ham.*$")) == ["ham", "hamburger"]
    assert picked(pl.col("*").exclude("ham")) == ["hamburger", "foo", "bar"]

    # Several names at once: as a list, as varargs, or as a Series.
    assert picked(pl.col(["hamburger", "foo"])) == ["hamburger", "foo"]
    assert picked(pl.col("hamburger", "foo")) == ["hamburger", "foo"]
    assert picked(pl.col(pl.Series(["ham", "foo"]))) == ["ham", "foo"]

    # By dtype: both numeric spellings must match the three integer columns.
    assert picked(pl.col(pl.String)) == ["bar"]
    numeric_selectors = (pl.col(NUMERIC_DTYPES), pl.col(pl.Int64, pl.Float64))
    for selector in numeric_selectors:
        assert picked(selector) == ["ham", "hamburger", "foo"]


def test_col_series_selection() -> None:
    """A ``pl.Series`` of names works as a ``pl.col`` selector on a LazyFrame."""
    lazy = pl.LazyFrame({"a": [1], "b": [1], "c": [1]})
    wanted = pl.Series(["b", "c"])
    schema = lazy.select(pl.col(wanted)).collect_schema()
    assert schema.names() == ["b", "c"]