Path: blob/main/singlestoredb/functions/utils.py
469 views
import dataclasses
import inspect
import struct
import sys
import types
import typing
from enum import Enum
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Tuple
from typing import Union

from .typing import Masked

# ``X | Y`` unions (types.UnionType) only exist on Python 3.10+.
if sys.version_info >= (3, 10):
    _UNION_TYPES = {typing.Union, types.UnionType}
else:
    _UNION_TYPES = {typing.Union}


is_dataclass = dataclasses.is_dataclass


def is_masked(obj: Any) -> bool:
    """
    Check if an object is a parameterized Masked type.

    Returns True when the typing origin of `obj` is ``Masked`` or a
    subclass of it. Objects for which ``typing.get_origin`` returns None
    yield False.

    """
    origin = typing.get_origin(obj)
    if origin is not None:
        return origin is Masked or \
            (inspect.isclass(origin) and issubclass(origin, Masked))
    return False


def is_union(x: Any) -> bool:
    """Check if the object is a Union (``typing.Union`` or ``X | Y``)."""
    return typing.get_origin(x) in _UNION_TYPES


def get_annotations(obj: Any) -> Dict[str, Any]:
    """Get the resolved type hints of an object (typing.get_type_hints)."""
    return typing.get_type_hints(obj)


def get_module(obj: Any) -> str:
    """
    Get the top-level package name of the module that defines an object.

    For example, ``'pandas.core.frame'`` becomes ``'pandas'``. An empty
    string is returned when the object has no ``__module__`` or when its
    ``__module__`` is not a string (e.g. ``None``).

    """
    module = getattr(obj, '__module__', None)
    if not isinstance(module, str):
        return ''
    return module.split('.', 1)[0]


def get_type_name(obj: Any) -> str:
    """
    Get the type name of an object.

    Uses ``obj.__name__`` when present (classes, functions); otherwise the
    name of the object's class.

    """
    if hasattr(obj, '__name__'):
        return obj.__name__
    if hasattr(obj, '__class__'):
        return obj.__class__.__name__
    return ''


def is_numpy(obj: Any) -> bool:
    """
    Check if an object is a numpy array or numpy array type.

    Accepts the ``numpy.ndarray`` class, parameterized ``ndarray[...]``
    annotations, and ndarray instances.

    """
    # Presumably catches parameterized annotations whose str() form
    # begins with 'numpy.ndarray['.
    if str(obj).startswith('numpy.ndarray['):
        return True

    # A class defined in numpy: only ndarray itself qualifies.
    if inspect.isclass(obj):
        if get_module(obj) == 'numpy':
            return get_type_name(obj) == 'ndarray'

    # Generic aliases whose origin is ndarray.
    origin = typing.get_origin(obj)
    if get_module(origin) == 'numpy':
        if get_type_name(origin) == 'ndarray':
            return True

    # ndarray instances.
    obj_type = type(obj)
    if get_module(obj_type) == 'numpy':
        return get_type_name(obj_type) == 'ndarray'

    return False


def is_dataframe(obj: Any) -> bool:
    """Check if an object is a pandas / polars DataFrame or pyarrow Table."""
    # Cheating here a bit so we don't have to import pandas / polars / pyarrow
    # unless we absolutely need to
    if get_module(obj) == 'pandas':
        return get_type_name(obj) == 'DataFrame'
    if get_module(obj) == 'polars':
        return get_type_name(obj) == 'DataFrame'
    if get_module(obj) == 'pyarrow':
        return get_type_name(obj) == 'Table'
    return False


def is_vector(obj: Any, include_masks: bool = False) -> bool:
    """
    Check if an object is a vector type.

    Vector types are pandas / polars Series, pyarrow Arrays, numpy arrays,
    and Masked types.

    Note: `include_masks` is currently not consulted; Masked types are
    always treated as vectors.

    """
    return is_pandas_series(obj) \
        or is_polars_series(obj) \
        or is_pyarrow_array(obj) \
        or is_numpy(obj) \
        or is_masked(obj)


def get_data_format(obj: Any) -> str:
    """
    Return the data format of the DataFrame / Table / vector.

    One of 'pandas', 'polars', 'arrow', 'numpy', 'list', or 'scalar'
    (the fallback for anything else).

    """
    # Cheating here a bit so we don't have to import pandas / polars / pyarrow
    # unless we absolutely need to
    if get_module(obj) == 'pandas':
        return 'pandas'
    if get_module(obj) == 'polars':
        return 'polars'
    if get_module(obj) == 'pyarrow':
        return 'arrow'
    if get_module(obj) == 'numpy':
        return 'numpy'
    if isinstance(obj, list):
        return 'list'
    return 'scalar'


def is_pandas_series(obj: Any) -> bool:
    """
    Check if an object is a pandas Series (class or instance).

    For a Union, only its first member is inspected.

    """
    if is_union(obj):
        obj = typing.get_args(obj)[0]
    return (
        get_module(obj) == 'pandas' and
        get_type_name(obj) == 'Series'
    )


def is_polars_series(obj: Any) -> bool:
    """
    Check if an object is a polars Series (class or instance).

    For a Union, only its first member is inspected.

    """
    if is_union(obj):
        obj = typing.get_args(obj)[0]
    return (
        get_module(obj) == 'polars' and
        get_type_name(obj) == 'Series'
    )


def is_pyarrow_array(obj: Any) -> bool:
    """
    Check if an object is a pyarrow Array.

    For a Union, only its first member is inspected.

    """
    # NOTE(review): this matches on the exact type name 'Array'. Concrete
    # pyarrow array instances are typically of subclasses (e.g.
    # DoubleArray), which this check may not match — confirm.
    if is_union(obj):
        obj = typing.get_args(obj)[0]
    return (
        get_module(obj) == 'pyarrow' and
        get_type_name(obj) == 'Array'
    )


def is_typeddict(obj: Any) -> bool:
    """
    Check if an object is a TypedDict.

    Always False when ``typing.is_typeddict`` is unavailable.

    """
    if hasattr(typing, 'is_typeddict'):
        return typing.is_typeddict(obj)  # noqa: TYP006
    return False


def is_namedtuple(obj: Any) -> bool:
    """Check if an object is a named tuple class or instance."""
    if inspect.isclass(obj):
        return (
            issubclass(obj, tuple) and
            hasattr(obj, '_asdict') and
            hasattr(obj, '_fields')
        )
    return (
        isinstance(obj, tuple) and
        hasattr(obj, '_asdict') and
        hasattr(obj, '_fields')
    )


def is_pydantic(obj: Any) -> bool:
    """Check if an object is a pydantic model class."""
    if not inspect.isclass(obj):
        return False
    # We don't want to import pydantic here, so we look for a class named
    # BaseModel from the pydantic package anywhere in the MRO.
    return any(
        get_module(base) == 'pydantic'
        and get_type_name(base) == 'BaseModel'
        for base in inspect.getmro(obj)
    )


class VectorTypes(str, Enum):
    """Enum for vector element types."""
    F16 = 'f16'
    F32 = 'f32'
    F64 = 'f64'
    I8 = 'i8'
    I16 = 'i16'
    I32 = 'i32'
    I64 = 'i64'


class _ElementInfo(typing.NamedTuple):
    """Encoding details for one vector element type."""

    # Single-character struct format code for one element.
    struct_code: str
    # numpy type code (no byte-order prefix, i.e. native byte order).
    numpy_type: str
    # Size of one element in bytes.
    itemsize: int


# Lookup table for the _vector_type_to_* helpers. The struct codes are
# combined with a '<' (little-endian, standard size) prefix.
_VECTOR_TYPE_INFO: Dict[VectorTypes, _ElementInfo] = {
    VectorTypes.F16: _ElementInfo('e', 'f2', 2),
    VectorTypes.F32: _ElementInfo('f', 'f4', 4),
    VectorTypes.F64: _ElementInfo('d', 'f8', 8),
    VectorTypes.I8: _ElementInfo('b', 'i1', 1),
    VectorTypes.I16: _ElementInfo('h', 'i2', 2),
    VectorTypes.I32: _ElementInfo('i', 'i4', 4),
    VectorTypes.I64: _ElementInfo('q', 'i8', 8),
}


def _vector_type_info(vector_type: VectorTypes) -> _ElementInfo:
    """
    Look up the encoding details for a vector element type.

    `vector_type` may be a VectorTypes member or its string value
    (e.g. ``'f32'``); members are ``str`` subclasses and compare equal to
    their values.

    Raises
    ------
    ValueError
        If `vector_type` is not a supported element type.

    """
    # Match with ``==`` rather than a dict lookup so that both enum members
    # and plain strings are accepted.
    for member, info in _VECTOR_TYPE_INFO.items():
        if vector_type == member:
            return info
    raise ValueError(f'unsupported element type: {vector_type}')


def _vector_type_to_numpy_type(
    vector_type: VectorTypes,
) -> str:
    """Convert a vector type to a numpy type code (e.g. 'f32' -> 'f4')."""
    return _vector_type_info(vector_type).numpy_type


def _vector_type_to_struct_format(
    vec: Any,
    vector_type: VectorTypes,
) -> str:
    """
    Build a little-endian struct format string for `vec`.

    When `vec` is bytes / bytearray, the element count is its byte length
    divided by the element size. Otherwise it is ``len(vec)``.

    """
    n = len(vec)
    info = _vector_type_info(vector_type)
    if isinstance(vec, (bytes, bytearray)):
        n //= info.itemsize
    return f'<{n}{info.struct_code}'


def unpack_vector(
    obj: Union[bytes, bytearray],
    vec_type: VectorTypes = VectorTypes.F32,
) -> Tuple[Any, ...]:
    """
    Unpack a vector from bytes.

    Parameters
    ----------
    obj : bytes or bytearray
        The object to unpack.
    vec_type : VectorTypes
        The type of the elements in the vector.
        Can be one of 'f16', 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
        Default is 'f32'.

    Returns
    -------
    Tuple[Any, ...]
        The unpacked vector.

    Raises
    ------
    ValueError
        If `vec_type` is not supported.

    """
    return struct.unpack(_vector_type_to_struct_format(obj, vec_type), obj)


def pack_vector(
    obj: Any,
    vec_type: VectorTypes = VectorTypes.F32,
) -> bytes:
    """
    Pack a vector into bytes.

    Parameters
    ----------
    obj : Any
        The object to pack: a list / tuple of numbers, a numpy array, or a
        pandas / polars Series or pyarrow Array.
    vec_type : VectorTypes
        The type of the elements in the vector.
        Can be one of 'f16', 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
        Default is 'f32'.

    Returns
    -------
    bytes
        The packed vector.

    Raises
    ------
    ValueError
        If the object type or `vec_type` is not supported.

    Notes
    -----
    `vec_type` is only applied to list / tuple inputs. For numpy, pandas,
    polars, and pyarrow inputs the bytes come from the data's own dtype.

    """
    if isinstance(obj, (list, tuple)):
        return struct.pack(_vector_type_to_struct_format(obj, vec_type), *obj)

    if is_numpy(obj):
        return obj.tobytes()

    if is_pandas_series(obj):
        import pandas as pd
        return pd.Series(obj).to_numpy().tobytes()

    if is_polars_series(obj):
        import polars as pl
        return pl.Series(obj).to_numpy().tobytes()

    if is_pyarrow_array(obj):
        import pyarrow as pa
        return pa.array(obj).to_numpy().tobytes()

    raise ValueError(
        f'unsupported object type: {type(obj)}',
    )


def _container_kind(obj: Any) -> str:
    """
    Classify an array-like container of vectors.

    Returns one of 'numpy', 'pandas', 'polars', or 'pyarrow'.

    Raises
    ------
    ValueError
        If `obj` is none of the supported container types.

    """
    if is_numpy(obj):
        return 'numpy'
    if is_pandas_series(obj):
        return 'pandas'
    if is_polars_series(obj):
        return 'polars'
    if is_pyarrow_array(obj):
        return 'pyarrow'
    raise ValueError(
        f'unsupported object type: {type(obj)}',
    )


def _wrap_like(np_arr: Any, kind: str) -> Any:
    """Wrap a numpy array in the container named by `kind`."""
    if kind == 'pandas':
        import pandas as pd
        return pd.Series(np_arr)
    if kind == 'polars':
        import polars as pl
        return pl.Series(np_arr)
    if kind == 'pyarrow':
        import pyarrow as pa
        return pa.array(np_arr)
    return np_arr


def unpack_vectors(
    arr_of_vec: Any,
    vec_type: VectorTypes = VectorTypes.F32,
) -> Iterable[Any]:
    """
    Unpack vectors from an array of bytes.

    Parameters
    ----------
    arr_of_vec : Iterable[Any]
        The array of bytes to unpack: a list / tuple, numpy array, or
        pandas / polars Series or pyarrow Array.
    vec_type : VectorTypes
        The type of the elements in the vector.
        Can be one of 'f16', 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
        Default is 'f32'.

    Returns
    -------
    Iterable[Any]
        The unpacked vectors. For list / tuple input this is a list of
        tuples; otherwise a numpy array wrapped in the input's container
        type.

    Raises
    ------
    ValueError
        If the container type or `vec_type` is not supported.

    """
    if isinstance(arr_of_vec, (list, tuple)):
        return [unpack_vector(x, vec_type) for x in arr_of_vec]

    # Validate the container before converting anything, so unsupported
    # inputs consistently raise ValueError.
    kind = _container_kind(arr_of_vec)

    import numpy as np

    dtype = _vector_type_to_numpy_type(vec_type)

    np_arr = np.array(
        [np.frombuffer(x, dtype=dtype) for x in arr_of_vec],
        dtype=dtype,
    )

    return _wrap_like(np_arr, kind)


def pack_vectors(
    arr_of_arr: Iterable[Any],
    vec_type: VectorTypes = VectorTypes.F32,
) -> Iterable[Any]:
    """
    Pack vectors into an array of bytes.

    Parameters
    ----------
    arr_of_arr : Iterable[Any]
        The vectors to pack: a list / tuple of number sequences, or a
        numpy array, pandas / polars Series, or pyarrow Array whose
        elements provide ``tobytes()``.
    vec_type : VectorTypes
        The type of the elements in the vector.
        Can be one of 'f16', 'f32', 'f64', 'i8', 'i16', 'i32', or 'i64'.
        Default is 'f32'. Only applied to list / tuple input.

    Returns
    -------
    Iterable[Any]
        The array of packed vectors. For list / tuple input this is a list
        of bytes; otherwise an object-dtype numpy array of bytes wrapped in
        the input's container type.

    Raises
    ------
    ValueError
        If the container type or `vec_type` is not supported.

    """
    if isinstance(arr_of_arr, (list, tuple)):
        # Each vector's elements are passed individually to struct.pack, and
        # the format is derived from that vector's own length.
        return [
            struct.pack(_vector_type_to_struct_format(x, vec_type), *x)
            for x in arr_of_arr
        ]

    # Validate the container before converting anything, so unsupported
    # inputs consistently raise ValueError.
    kind = _container_kind(arr_of_arr)

    import numpy as np

    # Use object type because numpy truncates nulls at the end of fixed binary
    np_arr = np.array([x.tobytes() for x in arr_of_arr], dtype=np.object_)

    return _wrap_like(np_arr, kind)