Path: blob/main/crates/polars-core/src/chunked_array/random.rs
6940 views
use num_traits::{Float, NumCast};1use polars_error::to_compute_err;2use rand::distr::Bernoulli;3use rand::prelude::*;4use rand::seq::index::IndexVec;5use rand_distr::{Normal, StandardNormal, StandardUniform, Uniform};67use crate::prelude::DataType::Float64;8use crate::prelude::*;9use crate::random::get_global_random_u64;10use crate::utils::NoNull;1112fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {13if len == 0 {14return IdxCa::new_vec(PlSmallStr::EMPTY, vec![]);15}16let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));17let dist = Uniform::new(0, len as IdxSize).unwrap();18(0..n as IdxSize)19.map(move |_| dist.sample(&mut rng))20.collect_trusted::<NoNull<IdxCa>>()21.into_inner()22}2324fn create_rand_index_no_replacement(25n: usize,26len: usize,27seed: Option<u64>,28shuffle: bool,29) -> IdxCa {30let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));31let mut buf: Vec<IdxSize>;32if n == len {33buf = (0..len as IdxSize).collect();34if shuffle {35buf.shuffle(&mut rng)36}37} else {38// TODO: avoid extra potential copy by vendoring rand::seq::index::sample,39// or genericize take over slices over any unsigned type. The optimizer40// should get rid of the extra copy already if IdxSize matches the IndexVec41// size returned.42buf = match rand::seq::index::sample(&mut rng, len, n) {43IndexVec::U32(v) => v.into_iter().map(|x| x as IdxSize).collect(),44#[cfg(target_pointer_width = "64")]45IndexVec::U64(v) => v.into_iter().map(|x| x as IdxSize).collect(),46};47}48IdxCa::new_vec(PlSmallStr::EMPTY, buf)49}5051impl<T> ChunkedArray<T>52where53T: PolarsNumericType,54StandardUniform: Distribution<T::Native>,55{56pub fn init_rand(size: usize, null_density: f32, seed: Option<u64>) -> Self {57let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));58(0..size)59.map(|_| {60if rng.random::<f32>() < null_density {61None62} else {63Some(rng.random())64}65})66.collect()67}68}6970fn ensure_shape(n: usize, len: usize, with_replacement: bool) -> PolarsResult<()> {71polars_ensure!(72with_replacement || n <= len,73ShapeMismatch:74"cannot take a larger sample than the total population when `with_replacement=false`"75);76Ok(())77}7879impl Series {80pub fn sample_n(81&self,82n: usize,83with_replacement: bool,84shuffle: bool,85seed: Option<u64>,86) -> PolarsResult<Self> {87ensure_shape(n, self.len(), with_replacement)?;88if n == 0 {89return Ok(self.clear());90}91let len = self.len();9293match with_replacement {94true => {95let idx = create_rand_index_with_replacement(n, len, seed);96debug_assert_eq!(len, self.len());97// SAFETY: we know that we never go out of bounds.98unsafe { Ok(self.take_unchecked(&idx)) }99},100false => {101let idx = create_rand_index_no_replacement(n, len, seed, shuffle);102debug_assert_eq!(len, self.len());103// SAFETY: we know that we never go out of bounds.104unsafe { Ok(self.take_unchecked(&idx)) }105},106}107}108109/// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`].110pub fn sample_frac(111&self,112frac: f64,113with_replacement: bool,114shuffle: bool,115seed: Option<u64>,116) -> PolarsResult<Self> {117let n = (self.len() as f64 * frac) as usize;118self.sample_n(n, with_replacement, shuffle, seed)119}120121pub fn shuffle(&self, seed: Option<u64>) -> Self {122let len = self.len();123let n = len;124let idx = create_rand_index_no_replacement(n, len, seed, true);125debug_assert_eq!(len, self.len());126// SAFETY: we know that we never go out of bounds.127unsafe { self.take_unchecked(&idx) }128}129}130131impl<T> ChunkedArray<T>132where133T: PolarsDataType,134ChunkedArray<T>: ChunkTake<IdxCa>,135{136/// Sample n datapoints from this [`ChunkedArray`].137pub fn sample_n(138&self,139n: usize,140with_replacement: bool,141shuffle: bool,142seed: Option<u64>,143) -> PolarsResult<Self> {144ensure_shape(n, self.len(), with_replacement)?;145let len = self.len();146147match with_replacement {148true => {149let idx = create_rand_index_with_replacement(n, len, seed);150debug_assert_eq!(len, self.len());151// SAFETY: we know that we never go out of bounds.152unsafe { Ok(self.take_unchecked(&idx)) }153},154false => {155let idx = create_rand_index_no_replacement(n, len, seed, shuffle);156debug_assert_eq!(len, self.len());157// SAFETY: we know that we never go out of bounds.158unsafe { Ok(self.take_unchecked(&idx)) }159},160}161}162163/// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`].164pub fn sample_frac(165&self,166frac: f64,167with_replacement: bool,168shuffle: bool,169seed: Option<u64>,170) -> PolarsResult<Self> {171let n = (self.len() as f64 * frac) as usize;172self.sample_n(n, with_replacement, shuffle, seed)173}174}175176impl DataFrame {177/// Sample n datapoints from this [`DataFrame`].178pub fn sample_n(179&self,180n: &Series,181with_replacement: bool,182shuffle: bool,183seed: Option<u64>,184) -> PolarsResult<Self> {185polars_ensure!(186n.len() == 1,187ComputeError: "Sample size must be a single value."188);189190let n = n.cast(&IDX_DTYPE)?;191let n = n.idx()?;192193match n.get(0) {194Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed),195None => Ok(self.clear()),196}197}198199pub fn sample_n_literal(200&self,201n: usize,202with_replacement: bool,203shuffle: bool,204seed: Option<u64>,205) -> PolarsResult<Self> {206ensure_shape(n, self.height(), with_replacement)?;207// All columns should used the same indices. So we first create the indices.208let idx = match with_replacement {209true => create_rand_index_with_replacement(n, self.height(), seed),210false => create_rand_index_no_replacement(n, self.height(), seed, shuffle),211};212// SAFETY: the indices are within bounds.213Ok(unsafe { self.take_unchecked(&idx) })214}215216/// Sample a fraction between 0.0-1.0 of this [`DataFrame`].217pub fn sample_frac(218&self,219frac: &Series,220with_replacement: bool,221shuffle: bool,222seed: Option<u64>,223) -> PolarsResult<Self> {224polars_ensure!(225frac.len() == 1,226ComputeError: "Sample fraction must be a single value."227);228229let frac = frac.cast(&Float64)?;230let frac = frac.f64()?;231232match frac.get(0) {233Some(frac) => {234let n = (self.height() as f64 * frac) as usize;235self.sample_n_literal(n, with_replacement, shuffle, seed)236},237None => Ok(self.clear()),238}239}240}241242impl<T> ChunkedArray<T>243where244T: PolarsNumericType,245T::Native: Float,246{247/// Create [`ChunkedArray`] with samples from a Normal distribution.248pub fn rand_normal(249name: PlSmallStr,250length: usize,251mean: f64,252std_dev: f64,253) -> PolarsResult<Self> {254let normal = Normal::new(mean, std_dev).map_err(to_compute_err)?;255let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);256let mut rng = rand::rng();257for _ in 0..length {258let smpl = normal.sample(&mut rng);259let smpl = NumCast::from(smpl).unwrap();260builder.append_value(smpl)261}262Ok(builder.finish())263}264265/// Create [`ChunkedArray`] with samples from a Standard Normal distribution.266pub fn rand_standard_normal(name: PlSmallStr, length: usize) -> Self {267let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);268let mut rng = rand::rng();269for _ in 0..length {270let smpl: f64 = rng.sample(StandardNormal);271let smpl = NumCast::from(smpl).unwrap();272builder.append_value(smpl)273}274builder.finish()275}276277/// Create [`ChunkedArray`] with samples from a Uniform distribution.278pub fn rand_uniform(name: PlSmallStr, length: usize, low: f64, high: f64) -> Self {279let uniform = Uniform::new(low, high).unwrap();280let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);281let mut rng = rand::rng();282for _ in 0..length {283let smpl = uniform.sample(&mut rng);284let smpl = NumCast::from(smpl).unwrap();285builder.append_value(smpl)286}287builder.finish()288}289}290291impl BooleanChunked {292/// Create [`ChunkedArray`] with samples from a Bernoulli distribution.293pub fn rand_bernoulli(name: PlSmallStr, length: usize, p: f64) -> PolarsResult<Self> {294let dist = Bernoulli::new(p).map_err(to_compute_err)?;295let mut rng = rand::rng();296let mut builder = BooleanChunkedBuilder::new(name, length);297for _ in 0..length {298let smpl = dist.sample(&mut rng);299builder.append_value(smpl)300}301Ok(builder.finish())302}303}304305#[cfg(test)]306mod test {307use super::*;308309#[test]310fn test_sample() {311let df = df![312"foo" => &[1, 2, 3, 4, 5]313]314.unwrap();315316// Default samples are random and don't require seeds.317assert!(318df.sample_n(319&Series::new(PlSmallStr::from_static("s"), &[3]),320false,321false,322None323)324.is_ok()325);326assert!(327df.sample_frac(328&Series::new(PlSmallStr::from_static("frac"), &[0.4]),329false,330false,331None332)333.is_ok()334);335// With seeding.336assert!(337df.sample_n(338&Series::new(PlSmallStr::from_static("s"), &[3]),339false,340false,341Some(0)342)343.is_ok()344);345assert!(346df.sample_frac(347&Series::new(PlSmallStr::from_static("frac"), &[0.4]),348false,349false,350Some(0)351)352.is_ok()353);354// Without replacement can not sample more than 100%.355assert!(356df.sample_frac(357&Series::new(PlSmallStr::from_static("frac"), &[2.0]),358false,359false,360Some(0)361)362.is_err()363);364assert!(365df.sample_n(366&Series::new(PlSmallStr::from_static("s"), &[3]),367true,368false,369Some(0)370)371.is_ok()372);373assert!(374df.sample_frac(375&Series::new(PlSmallStr::from_static("frac"), &[0.4]),376true,377false,378Some(0)379)380.is_ok()381);382// With replacement can sample more than 100%.383assert!(384df.sample_frac(385&Series::new(PlSmallStr::from_static("frac"), &[2.0]),386true,387false,388Some(0)389)390.is_ok()391);392}393}394395396