Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/gather/mod.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
// Licensed to the Apache Software Foundation (ASF) under one
3
// or more contributor license agreements. See the NOTICE file
4
// distributed with this work for additional information
5
// regarding copyright ownership. The ASF licenses this file
6
// to you under the Apache License, Version 2.0 (the
7
// "License"); you may not use this file except in compliance
8
// with the License. You may obtain a copy of the License at
9
//
10
// http://www.apache.org/licenses/LICENSE-2.0
11
//
12
// Unless required by applicable law or agreed to in writing,
13
// software distributed under the License is distributed on an
14
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
// KIND, either express or implied. See the License for the
16
// specific language governing permissions and limitations
17
// under the License.
18
19
//! Defines the take kernel for [`Array`].
20
21
use arrow::array::{
22
self, Array, ArrayCollectIterExt, ArrayFromIterDtype, BinaryViewArray, NullArray, StaticArray,
23
Utf8ViewArray, new_empty_array,
24
};
25
use arrow::datatypes::{ArrowDataType, IdxArr};
26
use arrow::types::Index;
27
28
pub mod binary;
29
pub mod binview;
30
pub mod bitmap;
31
pub mod boolean;
32
pub mod fixed_size_list;
33
pub mod generic_binary;
34
pub mod list;
35
pub mod primitive;
36
pub mod structure;
37
pub mod sublist;
38
39
use arrow::with_match_primitive_type_full;
40
41
/// Returns a new [`Array`] with only indices at `indices`. Null indices are taken as nulls.
42
/// The returned array has a length equal to `indices.len()`.
43
/// # Safety
44
/// Doesn't do bound checks
45
pub unsafe fn take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box<dyn Array> {
46
if indices.len() == 0 {
47
return new_empty_array(values.dtype().clone());
48
}
49
50
use arrow::datatypes::PhysicalType::*;
51
match values.dtype().to_physical_type() {
52
Null => Box::new(NullArray::new(values.dtype().clone(), indices.len())),
53
Boolean => {
54
let values = values.as_any().downcast_ref().unwrap();
55
Box::new(boolean::take_unchecked(values, indices))
56
},
57
Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
58
let values = values.as_any().downcast_ref().unwrap();
59
Box::new(primitive::take_primitive_unchecked::<$T>(&values, indices))
60
}),
61
LargeBinary => {
62
let values = values.as_any().downcast_ref().unwrap();
63
Box::new(binary::take_unchecked::<i64, _>(values, indices))
64
},
65
Struct => {
66
let array = values.as_any().downcast_ref().unwrap();
67
structure::take_unchecked(array, indices).boxed()
68
},
69
LargeList => {
70
let array = values.as_any().downcast_ref().unwrap();
71
Box::new(list::take_unchecked::<i64>(array, indices))
72
},
73
FixedSizeList => {
74
let array = values.as_any().downcast_ref().unwrap();
75
fixed_size_list::take_unchecked(array, indices)
76
},
77
BinaryView => {
78
let array: &BinaryViewArray = values.as_any().downcast_ref().unwrap();
79
binview::take_binview_unchecked(array, indices).boxed()
80
},
81
Utf8View => {
82
let array: &Utf8ViewArray = values.as_any().downcast_ref().unwrap();
83
binview::take_binview_unchecked(array, indices).boxed()
84
},
85
t => unimplemented!("Take not supported for data type {:?}", t),
86
}
87
}
88
89
/// Naive default implementation
90
unsafe fn take_unchecked_impl_generic<T>(
91
values: &T,
92
indices: &IdxArr,
93
new_null_func: &dyn Fn(ArrowDataType, usize) -> T,
94
) -> T
95
where
96
T: StaticArray + ArrayFromIterDtype<std::option::Option<Box<dyn array::Array>>>,
97
{
98
if values.null_count() == values.len() || indices.null_count() == indices.len() {
99
return new_null_func(values.dtype().clone(), indices.len());
100
}
101
102
match (indices.has_nulls(), values.has_nulls()) {
103
(true, true) => {
104
let values_validity = values.validity().unwrap();
105
106
indices
107
.iter()
108
.map(|i| {
109
if let Some(i) = i {
110
let i = *i as usize;
111
if values_validity.get_bit_unchecked(i) {
112
return Some(values.value_unchecked(i));
113
}
114
}
115
None
116
})
117
.collect_arr_trusted_with_dtype(values.dtype().clone())
118
},
119
(true, false) => indices
120
.iter()
121
.map(|i| {
122
if let Some(i) = i {
123
let i = *i as usize;
124
return Some(values.value_unchecked(i));
125
}
126
None
127
})
128
.collect_arr_trusted_with_dtype(values.dtype().clone()),
129
(false, true) => {
130
let values_validity = values.validity().unwrap();
131
132
indices
133
.values_iter()
134
.map(|i| {
135
let i = *i as usize;
136
if values_validity.get_bit_unchecked(i) {
137
return Some(values.value_unchecked(i));
138
}
139
None
140
})
141
.collect_arr_trusted_with_dtype(values.dtype().clone())
142
},
143
(false, false) => indices
144
.values_iter()
145
.map(|i| {
146
let i = *i as usize;
147
Some(values.value_unchecked(i))
148
})
149
.collect_arr_trusted_with_dtype(values.dtype().clone()),
150
}
151
}
152
153