Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/py-polars/tests/unit/operations/test_random.py
6939 views
1
from __future__ import annotations
2
3
import pytest
4
5
import polars as pl
6
from polars.exceptions import ShapeError
7
from polars.testing import assert_frame_equal, assert_series_equal
8
9
10
def test_shuffle_group_by_reseed() -> None:
    """Check per-group ``shuffle``: unseeded varies across groups, seeded does not."""

    def unique_shuffle_groups(n: int, seed: int | None) -> int:
        """Return how many distinct shuffled lists appear across ``n`` groups.

        Every group holds the same values ``[1, 2, 3]``, so the count of
        distinct aggregated lists equals the count of distinct permutations
        the per-group shuffle produced.
        """
        values = [1, 2, 3] * n  # 1, 2, 3, 1, 2, 3, ...
        group_ids = sorted(list(range(n)) * 3)  # 0, 0, 0, 1, 1, 1, ...
        frame = pl.DataFrame({"l": values, "group": group_ids})

        per_group = frame.group_by("group", maintain_order=True).agg(
            pl.col("l").shuffle(seed)
        )
        # Grouping on the shuffled list column collapses identical permutations,
        # so the number of resulting rows is the number of distinct permutations.
        distinct = per_group.group_by("l").agg(pl.lit(0)).select(pl.len())
        return int(distinct[0, 0])

    # Unseeded: all 50 groups landing on the same permutation is astronomically unlikely.
    assert unique_shuffle_groups(50, None) > 1
    # Fixed seed: should always be the same permutation for every group.
    assert unique_shuffle_groups(50, 0xDEADBEEF) == 1
25
26
27
def test_sample_expr() -> None:
    """``Expr.sample`` draws distinct, unsorted values from the source series.

    Sampling is checked both by ``fraction`` and by ``n``, each without
    replacement. The test also checks that ``pl.set_random_seed`` makes
    unseeded sampling reproducible.
    """
    a = pl.Series("a", range(20))

    # Sample half of the values by fraction.
    out = pl.select(
        pl.lit(a).sample(fraction=0.5, with_replacement=False, seed=1)
    ).to_series()
    assert out.shape == (10,)
    # With this seed the sample is not returned in sorted order.
    assert out.to_list() != out.sort().to_list()
    # Without replacement every drawn value is unique ...
    assert out.unique().shape == (10,)
    # ... and is taken from the source series.
    assert set(out).issubset(set(a))

    # The same guarantees must hold when sampling by an explicit count.
    out = pl.select(pl.lit(a).sample(n=10, with_replacement=False, seed=1)).to_series()
    assert out.shape == (10,)
    assert out.to_list() != out.sort().to_list()
    assert out.unique().shape == (10,)
    assert set(out).issubset(set(a))

    # pl.set_random_seed should lead to reproducible results.
    # NOTE: this mutates polars' global random seed for the rest of the session.
    pl.set_random_seed(1)
    result1 = pl.select(pl.lit(a).sample(n=10)).to_series()
    pl.set_random_seed(1)
    result2 = pl.select(pl.lit(a).sample(n=10)).to_series()
    assert_series_equal(result1, result2)
49
50
51
def test_sample_df() -> None:
    """Check result shapes of ``DataFrame.sample`` and ``Expr.sample``.

    ``n`` and ``fraction`` may be given as scalars or as single-element Series.
    Passing both must raise.
    """
    df = pl.DataFrame({"foo": [1, 2, 3], "bar": [6, 7, 8], "ham": ["a", "b", "c"]})

    # (keyword arguments, expected shape) for frame-level sampling; with no
    # arguments a single row is drawn.
    frame_cases = [
        ({}, (1, 3)),
        ({"n": 2, "seed": 0}, (2, 3)),
        ({"fraction": 0.4, "seed": 0}, (1, 3)),
        ({"n": pl.Series([2]), "seed": 0}, (2, 3)),
        ({"fraction": pl.Series([0.4]), "seed": 0}, (1, 3)),
    ]
    for kwargs, expected_shape in frame_cases:
        assert df.sample(**kwargs).shape == expected_shape

    # Expression-level sampling with Series-valued ``n`` / ``fraction``.
    expr_cases = [
        ({"n": pl.Series([2]), "seed": 0}, (2, 1)),
        ({"fraction": pl.Series([0.4]), "seed": 0}, (1, 1)),
    ]
    for kwargs, expected_shape in expr_cases:
        assert df.select(pl.col("foo").sample(**kwargs)).shape == expected_shape

    # ``n`` and ``fraction`` are mutually exclusive.
    with pytest.raises(ValueError, match="cannot specify both `n` and `fraction`"):
        df.sample(n=2, fraction=0.4)
66
67
68
def test_sample_n_expr() -> None:
    """Seeded sampling with ``n`` given as a Series or an expression is deterministic."""
    df = pl.DataFrame({"group": [1, 1, 1, 2, 2, 2], "val": [1, 2, 3, 2, 1, 1]})

    # ``n`` given as a single-element Series on the whole frame.
    assert_frame_equal(
        df.sample(pl.Series([3]), seed=0),
        pl.DataFrame({"group": [2, 1, 1], "val": [1, 2, 3]}),
    )

    # ``n`` given as an expression: inside ``agg`` it is evaluated per group
    # (max 3 in group 1, max 2 in group 2).
    sample_by_max = pl.col("val").sample(pl.col("val").max(), seed=0)
    assert_frame_equal(
        df.group_by("group", maintain_order=True).agg(sample_by_max),
        pl.DataFrame({"group": [1, 2], "val": [[1, 2, 3], [2, 1]]}),
    )

    # The same expression in a plain select uses the frame-wide max (3).
    assert_frame_equal(
        df.select(sample_by_max),
        pl.DataFrame({"val": [1, 2, 3]}),
    )
89
90
91
def test_sample_empty_df() -> None:
    """Sample from an empty frame with and without replacement.

    With replacement, the result is always empty. Without replacement, a fixed
    ``n`` larger than the frame raises ``ShapeError``. A ``fraction`` of an
    empty frame simply yields zero rows.
    """
    df = pl.DataFrame({"foo": []})

    # With replacement: both n- and fraction-based sampling return an empty frame.
    assert df.sample(n=3, with_replacement=True).shape == (0, 1)
    assert df.sample(fraction=0.4, with_replacement=True).shape == (0, 1)

    # Without replacement: requesting n rows from an empty frame is a shape
    # error, whereas fraction-based sampling still returns an empty frame.
    with pytest.raises(ShapeError):
        df.sample(n=3, with_replacement=False)
    assert df.sample(fraction=0.4, with_replacement=False).shape == (0, 1)
102
103
104
def test_sample_series() -> None:
    """``Series.sample`` honours ``n``/``fraction`` and the replacement rule."""
    s = pl.Series("a", [1, 2, 3, 4, 5])

    by_count = s.sample(n=2, seed=0)
    assert len(by_count) == 2
    by_fraction = s.sample(fraction=0.4, seed=0)  # 40% of 5 items -> 2
    assert len(by_fraction) == 2

    with_repl = s.sample(n=2, with_replacement=True, seed=0)
    assert len(with_repl) == 2

    # Without replacement you cannot draw more items than the series holds ...
    with pytest.raises(ShapeError):
        s.sample(n=10, with_replacement=False, seed=0)
    # ... but with replacement oversampling is allowed.
    oversampled = s.sample(n=10, with_replacement=True, seed=0)
    assert len(oversampled) == 10
117
118
119
def test_shuffle_expr() -> None:
    """Reseeding via ``pl.set_random_seed`` makes unseeded ``Expr.shuffle`` reproducible."""
    s = pl.Series("a", range(20))

    def shuffled_after_reseed() -> pl.Series:
        # Reset the global seed right before shuffling so each call starts
        # from the same random state.
        pl.set_random_seed(1)
        return pl.select(pl.lit(s).shuffle()).to_series()

    first = shuffled_after_reseed()
    second = shuffled_after_reseed()
    assert_series_equal(first, second)
129
130
131
def test_shuffle_series() -> None:
    """A seeded shuffle gives the same order via ``Series.shuffle`` and ``Expr.shuffle``."""
    a = pl.Series("a", [1, 2, 3])
    expected = pl.Series("a", [2, 3, 1])

    assert_series_equal(a.shuffle(1), expected)
    assert_series_equal(pl.select(pl.lit(a).shuffle(1)).to_series(), expected)
139
140
141
def test_sample_16232() -> None:
    """Regression test (presumably polars issue #16232) for ``list.sample`` with per-row ``n``.

    Rows whose ``n`` is 0 sample from empty lists and must yield ``[]``. The
    last row samples one element from ``[1]``.
    """
    repeats = 2
    base = 0

    df = pl.DataFrame(
        {
            "a": [base] * repeats + [base + 1],
            "b": [[1] * base] * repeats + [range(1, base + 2)],
        }
    )
    result = df.select(pl.col("b").list.sample(n=pl.col("a"), seed=0))
    assert result.to_dict(as_series=False) == {"b": [[], [], [1]]}
149
150