# Path: blob/main/py-polars/tests/unit/operations/test_random.py
# 6939 views
from __future__ import annotations

import pytest

import polars as pl
from polars.exceptions import ShapeError
from polars.testing import assert_frame_equal, assert_series_equal


def test_shuffle_group_by_reseed() -> None:
    """Check how `Expr.shuffle` behaves across groups in `group_by().agg()`.

    Every group holds the same values [1, 2, 3]. Without a seed the groups are
    expected to end up in different orders; with a fixed seed every group is
    expected to end up in the same order.
    """

    def unique_shuffle_groups(n: int, seed: int | None) -> int:
        """Shuffle `n` groups of [1, 2, 3]; return how many distinct orders result."""
        ls = [1, 2, 3] * n  # 1, 2, 3, 1, 2, 3...
        groups = sorted(list(range(n)) * 3)  # 0, 0, 0, 1, 1, 1, ...
        df = pl.DataFrame({"l": ls, "group": groups})
        shuffled = df.group_by("group", maintain_order=True).agg(
            pl.col("l").shuffle(seed)
        )
        # Grouping on the shuffled list column yields one row per distinct
        # ordering, so the row count is the number of unique shuffles.
        num_unique = shuffled.group_by("l").agg(pl.lit(0)).select(pl.len())
        return int(num_unique[0, 0])

    assert unique_shuffle_groups(50, None) > 1  # Astronomically unlikely.
    assert (
        unique_shuffle_groups(50, 0xDEADBEEF) == 1
    )  # Fixed seed should be always the same.


def test_sample_expr() -> None:
    """Check `Expr.sample` by fraction and by `n`, and reproducibility via the global seed."""
    a = pl.Series("a", range(20))
    out = pl.select(
        pl.lit(a).sample(fraction=0.5, with_replacement=False, seed=1)
    ).to_series()

    assert out.shape == (10,)
    # The sample should not simply come back in the input (sorted) order.
    assert out.to_list() != out.sort().to_list()
    # Without replacement, every drawn value is distinct.
    assert out.unique().shape == (10,)
    assert set(out).issubset(set(a))

    out = pl.select(pl.lit(a).sample(n=10, with_replacement=False, seed=1)).to_series()
    assert out.shape == (10,)
    assert out.to_list() != out.sort().to_list()
    assert out.unique().shape == (10,)
    assert set(out).issubset(set(a))

    # pl.set_random_seed should lead to reproducible results.
    pl.set_random_seed(1)
    result1 = pl.select(pl.lit(a).sample(n=10)).to_series()
    pl.set_random_seed(1)
    result2 = pl.select(pl.lit(a).sample(n=10)).to_series()
    assert_series_equal(result1, result2)


def test_sample_df() -> None:
    """Check output shapes of `DataFrame.sample` and `Expr.sample`.

    `n` and `fraction` may be given as plain scalars or as single-element
    Series; passing both must raise ValueError.
    """
    df = pl.DataFrame({"foo": [1, 2, 3], "bar": [6, 7, 8], "ham": ["a", "b", "c"]})

    assert df.sample().shape == (1, 3)
    assert df.sample(n=2, seed=0).shape == (2, 3)
    assert df.sample(fraction=0.4, seed=0).shape == (1, 3)
    assert df.sample(n=pl.Series([2]), seed=0).shape == (2, 3)
    assert df.sample(fraction=pl.Series([0.4]), seed=0).shape == (1, 3)
    assert df.select(pl.col("foo").sample(n=pl.Series([2]), seed=0)).shape == (2, 1)
    assert df.select(pl.col("foo").sample(fraction=pl.Series([0.4]), seed=0)).shape == (
        1,
        1,
    )
    with pytest.raises(ValueError, match="cannot specify both `n` and `fraction`"):
        df.sample(n=2, fraction=0.4)


def test_sample_n_expr() -> None:
    """Check that `n` can be a Series or an expression.

    Inside `group_by().agg()` the expression is evaluated per group: group 1
    has max(val) == 3 and group 2 has max(val) == 2, so they sample 3 and 2
    rows respectively.
    """
    df = pl.DataFrame(
        {
            "group": [1, 1, 1, 2, 2, 2],
            "val": [1, 2, 3, 2, 1, 1],
        }
    )

    out_df = df.sample(pl.Series([3]), seed=0)
    expected_df = pl.DataFrame({"group": [2, 1, 1], "val": [1, 2, 3]})
    assert_frame_equal(out_df, expected_df)

    agg_df = df.group_by("group", maintain_order=True).agg(
        pl.col("val").sample(pl.col("val").max(), seed=0)
    )
    expected_df = pl.DataFrame({"group": [1, 2], "val": [[1, 2, 3], [2, 1]]})
    assert_frame_equal(agg_df, expected_df)

    select_df = df.select(pl.col("val").sample(pl.col("val").max(), seed=0))
    expected_df = pl.DataFrame({"val": [1, 2, 3]})
    assert_frame_equal(select_df, expected_df)


def test_sample_empty_df() -> None:
    """Check sampling from an empty DataFrame with and without replacement."""
    df = pl.DataFrame({"foo": []})

    # With replacement, sampling an empty frame yields an empty frame.
    assert df.sample(n=3, with_replacement=True).shape == (0, 1)
    assert df.sample(fraction=0.4, with_replacement=True).shape == (0, 1)

    # Without replacement, requesting more rows than exist (n=3) raises
    # ShapeError, whereas sampling by fraction yields an empty frame.
    with pytest.raises(ShapeError):
        df.sample(n=3, with_replacement=False)
    assert df.sample(fraction=0.4, with_replacement=False).shape == (0, 1)


def test_sample_series() -> None:
    """Check `Series.sample` lengths and the size limit when sampling without replacement."""
    s = pl.Series("a", [1, 2, 3, 4, 5])

    assert len(s.sample(n=2, seed=0)) == 2
    assert len(s.sample(fraction=0.4, seed=0)) == 2

    assert len(s.sample(n=2, with_replacement=True, seed=0)) == 2

    # on a series of length 5, you cannot sample more than 5 items
    with pytest.raises(ShapeError):
        s.sample(n=10, with_replacement=False, seed=0)
    # unless you use with_replacement=True
    assert len(s.sample(n=10, with_replacement=True, seed=0)) == 10


def test_shuffle_expr() -> None:
    """Check that `Expr.shuffle` is reproducible after resetting the global seed."""
    # pl.set_random_seed should lead to reproducible results.
    s = pl.Series("a", range(20))

    pl.set_random_seed(1)
    result1 = pl.select(pl.lit(s).shuffle()).to_series()

    pl.set_random_seed(1)
    result2 = pl.select(pl.lit(s).shuffle()).to_series()
    assert_series_equal(result1, result2)


def test_shuffle_series() -> None:
    """Check that `Series.shuffle` and `Expr.shuffle` give the same permutation for one seed."""
    a = pl.Series("a", [1, 2, 3])
    out = a.shuffle(1)
    expected = pl.Series("a", [2, 3, 1])
    assert_series_equal(out, expected)

    out = pl.select(pl.lit(a).shuffle(1)).to_series()
    assert_series_equal(out, expected)


def test_sample_16232() -> None:
    """Check `list.sample` with a per-row `n`, including `n == 0` on empty lists.

    The name suggests a regression test for issue #16232 (presumably the
    polars GitHub tracker). With k=2 and p=0 the frame is
    a=[0, 0, 1], b=[[], [], [1]].
    """
    k = 2
    p = 0

    df = pl.DataFrame({"a": [p] * k + [1 + p], "b": [[1] * p] * k + [range(1, p + 2)]})
    assert df.select(pl.col("b").list.sample(n=pl.col("a"), seed=0)).to_dict(
        as_series=False
    ) == {"b": [[], [], [1]]}