Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/index_of.rs
6939 views
1
use arrow::array::{BinaryArray, BinaryViewArray, PrimitiveArray};
2
use polars_core::downcast_as_macro_arg_physical;
3
use polars_core::prelude::*;
4
use polars_utils::total_ord::TotalEq;
5
use row_encode::encode_rows_unordered;
6
7
/// Find the index of the value, or ``None`` if it can't be found.
8
fn index_of_value<'a, DT, AR>(ca: &'a ChunkedArray<DT>, value: AR::ValueT<'a>) -> Option<usize>
9
where
10
DT: PolarsDataType<Array = AR>,
11
AR: StaticArray,
12
AR::ValueT<'a>: TotalEq,
13
{
14
let req_value = &value;
15
let mut index = 0;
16
for chunk in ca.downcast_iter() {
17
if chunk.validity().is_some() {
18
for maybe_value in chunk.iter() {
19
if maybe_value.map(|v| v.tot_eq(req_value)) == Some(true) {
20
return Some(index);
21
} else {
22
index += 1;
23
}
24
}
25
} else {
26
// A lack of a validity bitmap means there are no nulls, so we
27
// can simplify our logic and use a faster code path:
28
for value in chunk.values_iter() {
29
if value.tot_eq(req_value) {
30
return Some(index);
31
} else {
32
index += 1;
33
}
34
}
35
}
36
}
37
None
38
}
39
40
fn index_of_numeric_value<T>(ca: &ChunkedArray<T>, value: T::Native) -> Option<usize>
41
where
42
T: PolarsNumericType,
43
{
44
index_of_value::<_, PrimitiveArray<T::Native>>(ca, value)
45
}
46
47
/// Convert the needle scalar to the physical native type of the numeric
/// `ChunkedArray`, then call `index_of_numeric_value()`.
macro_rules! try_index_of_numeric_ca {
    ($ca:expr, $value:expr) => {{
        let haystack = $ca;
        let scalar = $value;
        // `extract()` returns `None` if the value cannot be cast to the
        // target native type; the `unwrap()` turns that into a panic, so the
        // caller has to pass a needle compatible with the array's dtype.
        // Nulls should have been handled earlier.
        let native = scalar.into_value().to_physical().extract().unwrap();
        index_of_numeric_value(haystack, native)
    }};
}
60
61
/// Find the index of a given value (the first and only entry in `value_series`)
62
/// within the series.
63
pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult<Option<usize>> {
64
polars_ensure!(
65
series.dtype() == needle.dtype(),
66
InvalidOperation: "Cannot perform index_of with mismatching datatypes: {:?} and {:?}",
67
series.dtype(),
68
needle.dtype(),
69
);
70
71
if series.is_empty() {
72
return Ok(None);
73
}
74
75
// Series is not null, and the value is null:
76
if needle.is_null() {
77
let null_count = series.null_count();
78
if null_count == 0 {
79
return Ok(None);
80
} else if null_count == series.len() {
81
return Ok(Some(0));
82
}
83
84
let mut offset = 0;
85
for chunk in series.chunks() {
86
let length = chunk.len();
87
if let Some(bitmap) = chunk.validity() {
88
let leading_ones = bitmap.leading_ones();
89
if leading_ones < length {
90
return Ok(Some(offset + leading_ones));
91
}
92
}
93
offset += length;
94
}
95
return Ok(None);
96
}
97
98
use DataType as DT;
99
match series.dtype().to_physical() {
100
DT::Null => unreachable!("handled above"),
101
DT::Boolean => Ok(if needle.value().extract_bool().unwrap() {
102
series.bool().unwrap().first_true_idx()
103
} else {
104
series.bool().unwrap().first_false_idx()
105
}),
106
dt if dt.is_primitive_numeric() => {
107
let series = series.to_physical_repr();
108
Ok(downcast_as_macro_arg_physical!(
109
series,
110
try_index_of_numeric_ca,
111
needle
112
))
113
},
114
DT::String => Ok(index_of_value::<_, BinaryViewArray>(
115
&series.str()?.as_binary(),
116
needle.value().extract_str().unwrap().as_bytes(),
117
)),
118
DT::Binary => Ok(index_of_value::<_, BinaryViewArray>(
119
series.binary()?,
120
needle.value().extract_bytes().unwrap(),
121
)),
122
DT::BinaryOffset => Ok(index_of_value::<_, BinaryArray<i64>>(
123
series.binary_offset()?,
124
needle.value().extract_bytes().unwrap(),
125
)),
126
DT::Array(_, _) | DT::List(_) | DT::Struct(_) => {
127
// For non-numeric dtypes, we convert to row-encoding, which essentially has
128
// us searching the physical representation of the data as a series of
129
// bytes.
130
let value_as_column = Column::new_scalar(PlSmallStr::EMPTY, needle, 1);
131
let value_as_row_encoded_ca = encode_rows_unordered(&[value_as_column])?;
132
let value = value_as_row_encoded_ca
133
.first()
134
.expect("Shouldn't have nulls in a row-encoded result");
135
let ca = encode_rows_unordered(&[series.clone().into_column()])?;
136
Ok(index_of_value::<_, BinaryArray<i64>>(&ca, value))
137
},
138
139
DT::UInt8
140
| DT::UInt16
141
| DT::UInt32
142
| DT::UInt64
143
| DT::Int8
144
| DT::Int16
145
| DT::Int32
146
| DT::Int64
147
| DT::Int128
148
| DT::Float32
149
| DT::Float64 => unreachable!("primitive numeric"),
150
151
// to_physical
152
#[cfg(feature = "dtype-decimal")]
153
DT::Decimal(..) => unreachable!(),
154
#[cfg(feature = "dtype-categorical")]
155
DT::Categorical(..) | DT::Enum(..) => unreachable!(),
156
DT::Date | DT::Datetime(..) | DT::Duration(..) | DT::Time => unreachable!(),
157
158
#[cfg(feature = "object")]
159
DT::Object(_) => polars_bail!(op = "index_of", series.dtype()),
160
161
DT::Unknown(_) => polars_bail!(op = "index_of", series.dtype()),
162
}
163
}
164
165