Path: blob/main/singlestoredb/functions/utils.py
799 views
import dataclasses
import inspect
import struct
import sys
import types
import typing
from collections.abc import Iterable
from enum import Enum
from typing import Any
from typing import Dict
from typing import Tuple
from typing import Union

from .typing import Masked

if sys.version_info >= (3, 10):
    _UNION_TYPES = {typing.Union, types.UnionType}
else:
    _UNION_TYPES = {typing.Union}


is_dataclass = dataclasses.is_dataclass


def is_masked(obj: Any) -> bool:
    """Check if an object is a (subscripted) Masked type or subclass thereof."""
    origin = typing.get_origin(obj)
    if origin is not None:
        return origin is Masked or \
            (inspect.isclass(origin) and issubclass(origin, Masked))
    return False


def is_union(x: Any) -> bool:
    """Check if the object is a Union (``typing.Union`` or ``X | Y``)."""
    return typing.get_origin(x) in _UNION_TYPES


def get_annotations(obj: Any) -> Dict[str, Any]:
    """Get the resolved type hints of an object."""
    return typing.get_type_hints(obj)


def get_module(obj: Any) -> str:
    """
    Return the top-level package name of ``obj.__module__``.

    Returns an empty string when ``obj`` has no ``__module__`` or when it is
    not a string (e.g. ``None`` for some dynamically created objects).

    """
    module = getattr(obj, '__module__', None)
    if not isinstance(module, str):
        return ''
    return module.split('.')[0]


def get_type_name(obj: Any) -> str:
    """Get ``obj.__name__``, falling back to the name of its class."""
    if hasattr(obj, '__name__'):
        return obj.__name__
    if hasattr(obj, '__class__'):
        return obj.__class__.__name__
    return ''


def is_numpy(obj: Any) -> bool:
    """
    Check if an object is a numpy array.

    Matches ndarray instances, the ``numpy.ndarray`` class itself, and
    subscripted ndarray annotations (e.g. ``numpy.ndarray[Any, ...]``).

    """
    if str(obj).startswith('numpy.ndarray['):
        return True

    if inspect.isclass(obj):
        if get_module(obj) == 'numpy':
            return get_type_name(obj) == 'ndarray'

    origin = typing.get_origin(obj)
    if get_module(origin) == 'numpy':
        if get_type_name(origin) == 'ndarray':
            return True

    dtype = type(obj)
    if get_module(dtype) == 'numpy':
        return get_type_name(dtype) == 'ndarray'

    return False


def is_dataframe(obj: Any) -> bool:
    """Check if an object is a pandas/polars DataFrame or a pyarrow Table."""
    # Cheating here a bit so we don't have to import pandas / polars / pyarrow
    # unless we absolutely need to
    module = get_module(obj)
    if module in ('pandas', 'polars'):
        return get_type_name(obj) == 'DataFrame'
    if module == 'pyarrow':
        return get_type_name(obj) == 'Table'
    return False


def is_vector(obj: Any, include_masks: bool = False) -> bool:
    """Check if an object is a vector type."""
    # NOTE(review): ``include_masks`` is currently ignored -- Masked types
    # are always reported as vectors.  Kept as-is so existing callers see no
    # behavior change; confirm the intended semantics before honoring it.
    return is_pandas_series(obj) \
        or is_polars_series(obj) \
        or is_pyarrow_array(obj) \
        or is_numpy(obj) \
        or is_masked(obj)


def get_data_format(obj: Any) -> str:
    """Return the data format of the DataFrame / Table / vector."""
    # Cheating here a bit so we don't have to import pandas / polars / pyarrow
    # unless we absolutely need to
    module = get_module(obj)
    if module == 'pandas':
        return 'pandas'
    if module == 'polars':
        return 'polars'
    if module == 'pyarrow':
        return 'arrow'
    if module == 'numpy':
        return 'numpy'
    if isinstance(obj, list):
        return 'list'
    return 'scalar'


def is_pandas_series(obj: Any) -> bool:
    """Check if an object is a pandas Series (first member of a Union)."""
    if is_union(obj):
        obj = typing.get_args(obj)[0]
    return (
        get_module(obj) == 'pandas' and
        get_type_name(obj) == 'Series'
    )


def is_polars_series(obj: Any) -> bool:
    """Check if an object is a polars Series (first member of a Union)."""
    if is_union(obj):
        obj = typing.get_args(obj)[0]
    return (
        get_module(obj) == 'polars' and
        get_type_name(obj) == 'Series'
    )


def is_pyarrow_array(obj: Any) -> bool:
    """
    Check if an object is a pyarrow Array (first member of a Union).

    Concrete arrays are instances of typed subclasses (``DoubleArray``,
    ``BinaryArray``, ``ListArray``, ...), so the class hierarchy is searched
    for ``pyarrow``'s ``Array`` base rather than matching the exact name.

    """
    if is_union(obj):
        obj = typing.get_args(obj)[0]
    cls = obj if inspect.isclass(obj) else type(obj)
    return any(
        get_module(base) == 'pyarrow' and get_type_name(base) == 'Array'
        for base in inspect.getmro(cls)
    )


def is_typeddict(obj: Any) -> bool:
    """Check if an object is a TypedDict (always False before Python 3.10)."""
    if hasattr(typing, 'is_typeddict'):
        return typing.is_typeddict(obj)  # noqa: TYP006
    return False


def is_namedtuple(obj: Any) -> bool:
    """Check if an object is a named tuple class or instance."""
    if inspect.isclass(obj):
        return (
            issubclass(obj, tuple) and
            hasattr(obj, '_asdict') and
            hasattr(obj, '_fields')
        )
    return (
        isinstance(obj, tuple) and
        hasattr(obj, '_asdict') and
        hasattr(obj, '_fields')
    )


def is_pydantic(obj: Any) -> bool:
    """Check if an object is a pydantic model class."""
    if not inspect.isclass(obj):
        return False
    # We don't want to import pydantic here, so we scan the MRO for a class
    # named BaseModel that lives in the pydantic package.
    return any(
        get_module(base) == 'pydantic' and get_type_name(base) == 'BaseModel'
        for base in inspect.getmro(obj)
    )


class VectorTypes(str, Enum):
    """Enum for vector types."""
    F16 = 'f16'
    F32 = 'f32'
    F64 = 'f64'
    I8 = 'i8'
    I16 = 'i16'
    I32 = 'i32'
    I64 = 'i64'


# One row per element type:
# (vector type, numpy type code, struct format char, bytes per element).
# The struct chars are used with an explicit '<' (little-endian) prefix.
_VECTOR_TYPE_INFO: Tuple[Tuple[VectorTypes, str, str, int], ...] = (
    (VectorTypes.F16, 'f2', 'e', 2),
    (VectorTypes.F32, 'f4', 'f', 4),
    (VectorTypes.F64, 'f8', 'd', 8),
    (VectorTypes.I8, 'i1', 'b', 1),
    (VectorTypes.I16, 'i2', 'h', 2),
    (VectorTypes.I32, 'i4', 'i', 4),
    (VectorTypes.I64, 'i8', 'q', 8),
)


def _lookup_vector_type(
    vector_type: VectorTypes,
) -> Tuple[VectorTypes, str, str, int]:
    """Return the ``_VECTOR_TYPE_INFO`` row for ``vector_type``."""
    # Compare with == (not a dict lookup) so plain strings such as 'f32'
    # match the str-based enum members exactly as before.
    for entry in _VECTOR_TYPE_INFO:
        if vector_type == entry[0]:
            return entry
    raise ValueError(f'unsupported element type: {vector_type}')


def _vector_type_to_numpy_type(
    vector_type: VectorTypes,
) -> str:
    """Convert a vector type to a numpy type code (e.g. 'f4')."""
    return _lookup_vector_type(vector_type)[1]


def _vector_type_to_struct_format(
    vec: Any,
    vector_type: VectorTypes,
) -> str:
    """
    Convert a vector type to a little-endian struct format string.

    ``vec`` is either a sequence of values (its length is the element count)
    or packed ``bytes`` / ``bytearray`` (its length is divided by the element
    size to get the element count).

    """
    n = len(vec)
    _, _, code, itemsize = _lookup_vector_type(vector_type)
    if isinstance(vec, (bytes, bytearray)):
        n = n // itemsize
    return f'<{n}{code}'


def _container_kind(obj: Any) -> str:
    """
    Classify a column container.

    Returns 'numpy', 'pandas', 'polars', or 'arrow', or '' if unsupported.
    Checks run in that order.

    """
    if is_numpy(obj):
        return 'numpy'
    if is_pandas_series(obj):
        return 'pandas'
    if is_polars_series(obj):
        return 'polars'
    if is_pyarrow_array(obj):
        return 'arrow'
    return ''


def unpack_vector(
    obj: Union[bytes, bytearray],
    vec_type: VectorTypes = VectorTypes.F32,
) -> Tuple[Any, ...]:
    """
    Unpack a vector from bytes.

    Parameters
    ----------
    obj : bytes or bytearray
        The object to unpack (little-endian packed elements).
    vec_type : VectorTypes
        The type of the elements in the vector.
        Can be one of 'f16', 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
        Default is 'f32'.

    Returns
    -------
    Tuple[Any, ...]
        The unpacked vector.

    Raises
    ------
    ValueError
        If ``vec_type`` is not supported.

    """
    return struct.unpack(_vector_type_to_struct_format(obj, vec_type), obj)


def pack_vector(
    obj: Any,
    vec_type: VectorTypes = VectorTypes.F32,
) -> bytes:
    """
    Pack a vector into bytes.

    Parameters
    ----------
    obj : Any
        The object to pack: a list / tuple, numpy array, pandas Series,
        polars Series, or pyarrow Array of numbers.
    vec_type : VectorTypes
        The type of the elements in the vector.
        Can be one of 'f16', 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
        Default is 'f32'.

    Returns
    -------
    bytes
        The packed vector (little-endian, ``vec_type`` elements).

    Raises
    ------
    ValueError
        If the object type or ``vec_type`` is not supported.

    """
    if isinstance(obj, (list, tuple)):
        return struct.pack(_vector_type_to_struct_format(obj, vec_type), *obj)

    kind = _container_kind(obj)
    if not kind:
        raise ValueError(
            f'unsupported object type: {type(obj)}',
        )

    import numpy as np

    if kind == 'numpy':
        values = obj
    elif kind == 'arrow':
        values = obj.to_numpy(zero_copy_only=False)
    else:  # pandas / polars Series
        values = obj.to_numpy()

    # Cast to the requested element type so the bytes agree with
    # unpack_vector(); previously the input's own dtype leaked through
    # (e.g. float64 data packed as 'f32' produced 8-byte floats).
    # NOTE(review): numpy casting is unchecked (float -> int truncates),
    # whereas the struct-based list path rejects such values.
    dtype = np.dtype('<' + _vector_type_to_numpy_type(vec_type))
    return np.asarray(values, dtype=dtype).tobytes()


def unpack_vectors(
    arr_of_vec: Any,
    vec_type: VectorTypes = VectorTypes.F32,
) -> Iterable[Any]:
    """
    Unpack every packed vector in a column of bytes values.

    Parameters
    ----------
    arr_of_vec : Iterable[Any]
        The column of packed vectors: a list / tuple, numpy array,
        pandas Series, polars Series, or pyarrow Array of bytes.
    vec_type : VectorTypes
        The type of the elements in the vector.
        Can be one of 'f16', 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
        Default is 'f32'.

    Returns
    -------
    Iterable[Any]
        A list of tuples for list / tuple input; a 2-D numpy array for numpy
        input; otherwise a container of the same library holding one
        unpacked vector per row.

    Raises
    ------
    ValueError
        If the object type or ``vec_type`` is not supported.

    """
    if isinstance(arr_of_vec, (list, tuple)):
        return [unpack_vector(x, vec_type) for x in arr_of_vec]

    kind = _container_kind(arr_of_vec)
    if not kind:
        raise ValueError(
            f'unsupported object type: {type(arr_of_vec)}',
        )

    import numpy as np

    dtype = np.dtype('<' + _vector_type_to_numpy_type(vec_type))

    # Iterating a pyarrow array yields Scalars, not buffers; convert first.
    rows = arr_of_vec.to_pylist() if kind == 'arrow' else arr_of_vec

    np_arr = np.array(
        [np.frombuffer(x, dtype=dtype) for x in rows],
        dtype=dtype,
    )

    if kind == 'numpy':
        return np_arr

    if kind == 'pandas':
        import pandas as pd
        # pandas Series must be 1-D, so store one array per row.
        return pd.Series(list(np_arr), dtype=object)

    if kind == 'polars':
        import polars as pl
        return pl.Series(np_arr)

    import pyarrow as pa
    # pa.array() rejects 2-D numpy input; build a list array per row instead.
    return pa.array(
        list(np_arr),
        type=pa.list_(pa.from_numpy_dtype(dtype)),
    )


def pack_vectors(
    arr_of_arr: Iterable[Any],
    vec_type: VectorTypes = VectorTypes.F32,
) -> Iterable[Any]:
    """
    Pack every vector in a column into bytes.

    Parameters
    ----------
    arr_of_arr : Iterable[Any]
        The column of vectors to pack: a list / tuple, 2-D numpy array,
        pandas Series, polars Series, or pyarrow Array whose rows are
        sequences of numbers.
    vec_type : VectorTypes
        The type of the elements in the vector.
        Can be one of 'f16', 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
        Default is 'f32'.

    Returns
    -------
    Iterable[Any]
        The packed vectors: a list of bytes for list / tuple input,
        otherwise a container of the same library holding bytes values.

    Raises
    ------
    ValueError
        If the object type or ``vec_type`` is not supported.

    """
    if isinstance(arr_of_arr, (list, tuple)):
        # Pack row by row so each vector gets its own length in the format
        # (the old code also dropped the '*' and always raised struct.error).
        return [pack_vector(x, vec_type) for x in arr_of_arr]

    kind = _container_kind(arr_of_arr)
    if not kind:
        raise ValueError(
            f'unsupported object type: {type(arr_of_arr)}',
        )

    import numpy as np

    dtype = np.dtype('<' + _vector_type_to_numpy_type(vec_type))

    # Iterating a pyarrow array yields Scalars; convert to Python lists first.
    rows = arr_of_arr.to_pylist() if kind == 'arrow' else arr_of_arr

    # Use object type because numpy truncates nulls at the end of fixed binary
    np_arr = np.array(
        [np.asarray(x, dtype=dtype).tobytes() for x in rows],
        dtype=np.object_,
    )

    if kind == 'numpy':
        return np_arr

    if kind == 'pandas':
        import pandas as pd
        return pd.Series(np_arr)

    if kind == 'polars':
        import polars as pl
        return pl.Series(np_arr)

    import pyarrow as pa
    return pa.array(np_arr)