// Path: blob/main/pyo3-polars/example/derive_expression/expression_lib/src/distances.rs
// 7884 views
use std::hash::Hash;12use arrow::array::PrimitiveArray;3use num::Float;4use polars::prelude::*;5use pyo3_polars::export::polars_core::utils::arrow::types::NativeType;6use pyo3_polars::export::polars_core::with_match_physical_integer_type;78#[allow(clippy::all)]9pub(super) fn naive_hamming_dist(a: &str, b: &str) -> u32 {10let x = a.as_bytes();11let y = b.as_bytes();12x.iter()13.zip(y)14.fold(0, |a, (b, c)| a + (*b ^ *c).count_ones() as u32)15}1617fn jacc_helper<T: NativeType + Hash + Eq>(a: &PrimitiveArray<T>, b: &PrimitiveArray<T>) -> f64 {18// convert to hashsets over Option<T>19let s1 = a.into_iter().collect::<PlHashSet<_>>();20let s2 = b.into_iter().collect::<PlHashSet<_>>();2122// count the number of intersections23let s3_len = s1.intersection(&s2).count();24// return similarity25s3_len as f64 / (s1.len() + s2.len() - s3_len) as f6426}2728#[allow(unexpected_cfgs)]29pub(super) fn naive_jaccard_sim(a: &ListChunked, b: &ListChunked) -> PolarsResult<Float64Chunked> {30polars_ensure!(31a.inner_dtype() == b.inner_dtype(),32ComputeError: "inner data types don't match"33);34polars_ensure!(35a.inner_dtype().is_integer(),36ComputeError: "inner data types must be integer"37);38Ok(with_match_physical_integer_type!(a.inner_dtype(), |$T| {39polars::prelude::arity::binary_elementwise(a, b, |a, b| {40match (a, b) {41(Some(a), Some(b)) => {42let a = a.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();43let b = b.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();44Some(jacc_helper(a, b))45},46_ => None47}48})49}))50}5152fn haversine_elementwise<T: Float>(start_lat: T, start_long: T, end_lat: T, end_long: T) -> T {53let r_in_km = T::from(6371.0).unwrap();54let two = T::from(2.0).unwrap();55let one = T::one();5657let d_lat = (end_lat - start_lat).to_radians();58let d_lon = (end_long - start_long).to_radians();59let lat1 = (start_lat).to_radians();60let lat2 = (end_lat).to_radians();6162let a = ((d_lat / two).sin()) * ((d_lat / two).sin())63+ ((d_lon / two).sin()) * 
((d_lon / two).sin()) * (lat1.cos()) * (lat2.cos());64let c = two * ((a.sqrt()).atan2((one - a).sqrt()));65r_in_km * c66}6768pub(super) fn naive_haversine<T>(69start_lat: &ChunkedArray<T>,70start_long: &ChunkedArray<T>,71end_lat: &ChunkedArray<T>,72end_long: &ChunkedArray<T>,73) -> PolarsResult<ChunkedArray<T>>74where75T: PolarsFloatType,76T::Native: Float,77{78let out: ChunkedArray<T> = start_lat79.iter()80.zip(start_long.iter())81.zip(end_lat.iter())82.zip(end_long.iter())83.map(|(((start_lat, start_long), end_lat), end_long)| {84let start_lat = start_lat?;85let start_long = start_long?;86let end_lat = end_lat?;87let end_long = end_long?;88Some(haversine_elementwise(89start_lat, start_long, end_lat, end_long,90))91})92.collect();9394Ok(out.with_name(start_lat.name().clone()))95}969798