Path: blob/main/crates/polars-expr/src/dispatch/range/linear_space.rs
7884 views
use arrow::temporal_conversions::MICROSECONDS_IN_DAY;1use polars_core::prelude::*;2use polars_ops::series::{ClosedInterval, new_linear_space_f32, new_linear_space_f64};34use super::utils::{build_nulls, ensure_items_contain_exactly_one_value};56const CAPACITY_FACTOR: usize = 5;78pub(super) fn linear_space(s: &[Column], closed: ClosedInterval) -> PolarsResult<Column> {9let start = &s[0];10let end = &s[1];11let num_samples = &s[2];12let name = start.name();1314ensure_items_contain_exactly_one_value(&[start, end], &["start", "end"])?;15polars_ensure!(16num_samples.len() == 1,17ComputeError: "`num_samples` must contain exactly one value, got {} values", num_samples.len()18);1920let start = start.get(0).unwrap();21let end = end.get(0).unwrap();22let num_samples = num_samples.get(0).unwrap();23let num_samples = num_samples24.extract::<u64>()25.ok_or(PolarsError::ComputeError(26format!("'num_samples' must be non-negative integer, got {num_samples}").into(),27))?;2829match (start.dtype(), end.dtype()) {30(DataType::Float32, DataType::Float32) => new_linear_space_f32(31start.extract::<f32>().unwrap(),32end.extract::<f32>().unwrap(),33num_samples,34closed,35name.clone(),36)37.map(|s| s.into_column()),38(mut dt, dt2) if dt.is_temporal() && dt == dt2 => {39let mut start = start.extract::<i64>().unwrap();40let mut end = end.extract::<i64>().unwrap();4142// A linear space of a Date produces a sequence of Datetimes, so we must upcast.43if dt == DataType::Date {44start *= MICROSECONDS_IN_DAY;45end *= MICROSECONDS_IN_DAY;46dt = DataType::Datetime(TimeUnit::Microseconds, None);47}48new_linear_space_f64(start as f64, end as f64, num_samples, closed, name.clone())49.map(|s| s.cast(&dt).unwrap().into_column())50},51(dt1, dt2) if !dt1.is_primitive_numeric() || !dt2.is_primitive_numeric() => {52Err(PolarsError::ComputeError(53format!("'start' and 'end' have incompatible dtypes, got {dt1:?} and {dt2:?}")54.into(),55))56},57(_, _) => new_linear_space_f64(58start.extract::<f64>().unwrap(),59end.extract::<f64>().unwrap(),60num_samples,61closed,62name.clone(),63)64.map(|s| s.into_column()),65}66}6768pub(super) fn linear_spaces(69s: &[Column],70closed: ClosedInterval,71array_width: Option<usize>,72) -> PolarsResult<Column> {73let start = &s[0];74let end = &s[1];7576let (num_samples, capacity_factor) = match array_width {77Some(ns) => {78// An array width is provided instead of a column of `num_sample`s.79let scalar = Scalar::new(DataType::UInt64, AnyValue::UInt64(ns as u64));80(&Column::new_scalar(PlSmallStr::EMPTY, scalar, 1), ns)81},82None => (&s[2], CAPACITY_FACTOR),83};84let name = start.name().clone();8586let num_samples = num_samples.strict_cast(&DataType::UInt64)?;87let num_samples = num_samples.u64()?;88let len = start.len().max(end.len()).max(num_samples.len());8990match (start.dtype(), end.dtype()) {91(DataType::Float32, DataType::Float32) => {92let mut builder = ListPrimitiveChunkedBuilder::<Float32Type>::new(93name,94len,95len * capacity_factor,96DataType::Float32,97);9899let linspace_impl =100|start,101end,102num_samples,103builder: &mut ListPrimitiveChunkedBuilder<Float32Type>| {104let ls =105new_linear_space_f32(start, end, num_samples, closed, PlSmallStr::EMPTY)?;106builder.append_slice(ls.cont_slice().unwrap());107Ok(())108};109110let start = start.f32()?;111let end = end.f32()?;112let out =113linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?;114115let to_type = array_width.map_or_else(116|| DataType::List(Box::new(DataType::Float32)),117|width| DataType::Array(Box::new(DataType::Float32), width),118);119out.cast(&to_type)120},121(mut dt, dt2) if dt.is_temporal() && dt == dt2 => {122let mut start = start.to_physical_repr();123let mut end = end.to_physical_repr();124125// A linear space of a Date produces a sequence of Datetimes, so we must upcast.126if dt == &DataType::Date {127start = start.cast(&DataType::Int64)? * MICROSECONDS_IN_DAY;128end = end.cast(&DataType::Int64)? * MICROSECONDS_IN_DAY;129dt = &DataType::Datetime(TimeUnit::Microseconds, None);130}131132let start = start.cast(&DataType::Float64)?;133let start = start.f64()?;134let end = end.cast(&DataType::Float64)?;135let end = end.f64()?;136137let mut builder = ListPrimitiveChunkedBuilder::<Float64Type>::new(138name,139len,140len * capacity_factor,141DataType::Float64,142);143144let linspace_impl =145|start,146end,147num_samples,148builder: &mut ListPrimitiveChunkedBuilder<Float64Type>| {149let ls =150new_linear_space_f64(start, end, num_samples, closed, PlSmallStr::EMPTY)?;151builder.append_slice(ls.cont_slice().unwrap());152Ok(())153};154let out =155linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?;156157let to_type = array_width.map_or_else(158|| DataType::List(Box::new(dt.clone())),159|width| DataType::Array(Box::new(dt.clone()), width),160);161out.cast(&to_type)162},163(dt1, dt2) if !dt1.is_primitive_numeric() || !dt2.is_primitive_numeric() => {164Err(PolarsError::ComputeError(165format!("'start' and 'end' have incompatible dtypes, got {dt1:?} and {dt2:?}")166.into(),167))168},169(_, _) => {170let start = start.strict_cast(&DataType::Float64)?;171let end = end.strict_cast(&DataType::Float64)?;172let start = start.f64()?;173let end = end.f64()?;174175let mut builder = ListPrimitiveChunkedBuilder::<Float64Type>::new(176name,177len,178len * capacity_factor,179DataType::Float64,180);181182let linspace_impl =183|start,184end,185num_samples,186builder: &mut ListPrimitiveChunkedBuilder<Float64Type>| {187let ls =188new_linear_space_f64(start, end, num_samples, closed, PlSmallStr::EMPTY)?;189builder.append_slice(ls.cont_slice().unwrap());190Ok(())191};192let out =193linear_spaces_impl_broadcast(start, end, num_samples, linspace_impl, &mut builder)?;194195let to_type = array_width.map_or_else(196|| DataType::List(Box::new(DataType::Float64)),197|width| DataType::Array(Box::new(DataType::Float64), width),198);199out.cast(&to_type)200},201}202}203204/// Create a ranges column from the given start/end columns and a range function.205pub(super) fn linear_spaces_impl_broadcast<T, F>(206start: &ChunkedArray<T>,207end: &ChunkedArray<T>,208num_samples: &UInt64Chunked,209linear_space_impl: F,210builder: &mut ListPrimitiveChunkedBuilder<T>,211) -> PolarsResult<Column>212where213T: PolarsFloatType,214F: Fn(T::Native, T::Native, u64, &mut ListPrimitiveChunkedBuilder<T>) -> PolarsResult<()>,215ListPrimitiveChunkedBuilder<T>: ListBuilderTrait,216{217match (start.len(), end.len(), num_samples.len()) {218(len_start, len_end, len_samples) if len_start == len_end && len_start == len_samples => {219// (n, n, n)220build_linear_spaces::<_, _, _, T, F>(221start.iter(),222end.iter(),223num_samples.iter(),224linear_space_impl,225builder,226)?;227},228// (1, n, n)229(1, len_end, len_samples) if len_end == len_samples => {230let start_value = start.get(0);231if start_value.is_some() {232build_linear_spaces::<_, _, _, T, F>(233std::iter::repeat(start_value),234end.iter(),235num_samples.iter(),236linear_space_impl,237builder,238)?239} else {240build_nulls(builder, len_end)241}242},243// (n, 1, n)244(len_start, 1, len_samples) if len_start == len_samples => {245let end_value = end.get(0);246if end_value.is_some() {247build_linear_spaces::<_, _, _, T, F>(248start.iter(),249std::iter::repeat(end_value),250num_samples.iter(),251linear_space_impl,252builder,253)?254} else {255build_nulls(builder, len_start)256}257},258// (n, n, 1)259(len_start, len_end, 1) if len_start == len_end => {260let num_samples_value = num_samples.get(0);261if num_samples_value.is_some() {262build_linear_spaces::<_, _, _, T, F>(263start.iter(),264end.iter(),265std::iter::repeat(num_samples_value),266linear_space_impl,267builder,268)?269} else {270build_nulls(builder, len_start)271}272},273// (n, 1, 1)274(len_start, 1, 1) => {275let end_value = end.get(0);276let num_samples_value = num_samples.get(0);277match (end_value, num_samples_value) {278(Some(_), Some(_)) => build_linear_spaces::<_, _, _, T, F>(279start.iter(),280std::iter::repeat(end_value),281std::iter::repeat(num_samples_value),282linear_space_impl,283builder,284)?,285_ => build_nulls(builder, len_start),286}287},288// (1, n, 1)289(1, len_end, 1) => {290let start_value = start.get(0);291let num_samples_value = num_samples.get(0);292match (start_value, num_samples_value) {293(Some(_), Some(_)) => build_linear_spaces::<_, _, _, T, F>(294std::iter::repeat(start_value),295end.iter(),296std::iter::repeat(num_samples_value),297linear_space_impl,298builder,299)?,300_ => build_nulls(builder, len_end),301}302},303// (1, 1, n)304(1, 1, len_num_samples) => {305let start_value = start.get(0);306let end_value = end.get(0);307match (start_value, end_value) {308(Some(_), Some(_)) => build_linear_spaces::<_, _, _, T, F>(309std::iter::repeat(start_value),310std::iter::repeat(end_value),311num_samples.iter(),312linear_space_impl,313builder,314)?,315_ => build_nulls(builder, len_num_samples),316}317},318(len_start, len_end, len_num_samples) => {319polars_bail!(320ComputeError:321"lengths of `start` ({}), `end` ({}), and `num_samples` ({}) do not match",322len_start, len_end, len_num_samples323)324},325};326let out = builder.finish().into_column();327Ok(out)328}329330/// Iterate over a start and end column and create a range for each entry.331fn build_linear_spaces<I, J, K, T, F>(332start: I,333end: J,334num_samples: K,335linear_space_impl: F,336builder: &mut ListPrimitiveChunkedBuilder<T>,337) -> PolarsResult<()>338where339I: Iterator<Item = Option<T::Native>>,340J: Iterator<Item = Option<T::Native>>,341K: Iterator<Item = Option<u64>>,342T: PolarsFloatType,343F: Fn(T::Native, T::Native, u64, &mut ListPrimitiveChunkedBuilder<T>) -> PolarsResult<()>,344ListPrimitiveChunkedBuilder<T>: ListBuilderTrait,345{346for ((start, end), num_samples) in start.zip(end).zip(num_samples) {347match (start, end, num_samples) {348(Some(start), Some(end), Some(num_samples)) => {349linear_space_impl(start, end, num_samples, builder)?350},351_ => builder.append_null(),352}353}354Ok(())355}356357358