Path: blob/main/singlestoredb/functions/ext/arrow.py
469 views
#!/usr/bin/env python3
'''
Convert between Apache Feather (Arrow IPC file) payloads and Python data.

Loaders turn a Feather payload into row IDs plus either rows of Python
values or per-column ``(data, null-mask)`` pairs as pyarrow, pandas, polars
or numpy objects.  Dumpers perform the reverse conversion.  The first column
of every payload holds the row IDs.

'''
from io import BytesIO
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple

try:
    import numpy as np
    has_numpy = True
except ImportError:
    has_numpy = False

try:
    import polars as pl
    has_polars = True
except ImportError:
    has_polars = False

try:
    import pandas as pd
    has_pandas = True
except ImportError:
    has_pandas = False

try:
    import pyarrow as pa
    import pyarrow.feather
    has_pyarrow = True
except ImportError:
    has_pyarrow = False


def load(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[List[int], List[Any]]:
    '''
    Convert bytes in Apache Feather format into rows of data.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        A list of (column name, data type) pairs; not used by this loader
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[List[int], List[Any]]
        The row IDs (first column) and one list of values per row
        (remaining columns)

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    table = pa.feather.read_table(BytesIO(data))
    row_ids = table.column(0).to_pylist()
    # The first column holds the row IDs; everything after it is data.
    names = table.column_names[1:]
    rows = [[row[name] for name in names] for row in table.to_pylist()]
    return row_ids, rows


def _load_vectors(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[
    'pa.Array[pa.int64]',
    List[Tuple['pa.Array[Any]', 'pa.Array[pa.bool_]']],
]:
    '''
    Convert bytes in Apache Feather format into columns of data.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        A list of (column name, data type) pairs; not used by this loader
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[pyarrow.Array, List[Tuple[pyarrow.Array, pyarrow.Array]]]
        The row ID column and, for each remaining column, a pair of
        (column data, null mask)

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    table = pa.feather.read_table(BytesIO(data))
    row_ids = table.column(0)
    out = [(col, col.is_null()) for col in table.columns[1:]]
    return row_ids, out


def _with_row_labels(column: Any, labels: Any) -> Any:
    '''Convert an Arrow column to a pandas Series labelled by ``labels``.'''
    series = column.to_pandas()
    # Assign the labels positionally.  ``reindex`` would instead look the
    # labels up in the default 0..n-1 index and return NaN (or the wrong
    # element) whenever the row IDs are not exactly 0..n-1 in order.
    series.index = labels
    return series


def load_pandas(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[
    'pd.Series[np.int64]',
    List[Tuple['pd.Series[Any]', 'pd.Series[np.bool_]']],
]:
    '''
    Convert bytes in Apache Feather format into pandas Series.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        A list of (column name, data type) pairs; at most ``len(colspec)``
        data columns are returned
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[pd.Series[int], List[Tuple[pd.Series[Any], pd.Series[bool]]]]
        The row IDs and, for each column, a (data, null mask) pair whose
        index is the row IDs

    Raises
    ------
    RuntimeError
        If pandas, numpy or pyarrow is not installed

    '''
    if not has_pandas or not has_numpy:
        raise RuntimeError('pandas must be installed for this operation')

    row_ids, cols = _load_vectors(colspec, data)
    index = row_ids.to_pandas()
    labels = pd.Index(index)

    return index, \
        [
            (
                _with_row_labels(values, labels),
                _with_row_labels(mask, labels),
            )
            for (values, mask), _ in zip(cols, colspec)
        ]


def load_polars(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[
    'pl.Series[pl.Int64]',
    List[Tuple['pl.Series[Any]', 'pl.Series[pl.Boolean]']],
]:
    '''
    Convert bytes in Apache Feather format into polars Series.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        A list of (column name, data type) pairs; at most ``len(colspec)``
        data columns are returned
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[polars.Series[int], List[Tuple[polars.Series[Any], polars.Series[bool]]]]
        The row IDs and a (data, null mask) pair for each column

    Raises
    ------
    RuntimeError
        If polars or pyarrow is not installed

    '''
    if not has_polars:
        raise RuntimeError('polars must be installed for this operation')

    row_ids, cols = _load_vectors(colspec, data)

    return (
        pl.from_arrow(row_ids),  # type: ignore
        [
            (
                pl.from_arrow(values),  # type: ignore
                pl.from_arrow(mask),  # type: ignore
            )
            for (values, mask), _ in zip(cols, colspec)
        ],
    )


def load_numpy(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[
    'np.typing.NDArray[np.int64]',
    List[Tuple['np.typing.NDArray[Any]', 'np.typing.NDArray[np.bool_]']],
]:
    '''
    Convert bytes in Apache Feather format into numpy arrays.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        A list of (column name, data type) pairs; at most ``len(colspec)``
        data columns are returned
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[np.ndarray[int], List[Tuple[np.ndarray[Any], np.ndarray[bool]]]]
        The row IDs and a (data, null mask) pair for each column

    Raises
    ------
    RuntimeError
        If numpy or pyarrow is not installed

    '''
    if not has_numpy:
        raise RuntimeError('numpy must be installed for this operation')

    row_ids, cols = _load_vectors(colspec, data)

    return row_ids.to_numpy(), \
        [
            (
                values.to_numpy(),
                mask.to_numpy(),
            )
            for (values, mask), _ in zip(cols, colspec)
        ]


def load_arrow(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[
    'pa.Array[pa.int64()]',
    List[Tuple['pa.Array[Any]', 'pa.Array[pa.bool_()]']],
]:
    '''
    Convert bytes in Apache Feather format into pyarrow arrays.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        A list of (column name, data type) pairs; not used by this loader
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[pyarrow.Array[int], List[Tuple[pyarrow.Array[Any], pyarrow.Array[bool]]]]
        The row IDs and a (data, null mask) pair for each column

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    return _load_vectors(colspec, data)


def _write_table(tbl: 'pa.Table') -> Any:
    '''Serialize a pyarrow Table to the Arrow IPC file (Feather) format.'''
    sink = pa.BufferOutputStream()
    with pa.ipc.new_file(sink, tbl.schema) as writer:
        for batch in tbl.to_batches():
            writer.write_batch(batch)
    return sink.getvalue()


def dump(
    returns: List[int],
    row_ids: List[int],
    rows: List[List[Any]],
) -> bytes:
    '''
    Convert a list of lists of data into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type; not used by this serializer
    row_ids : List[int]
        The row IDs
    rows : List[List[Any]]
        The rows of data to serialize

    Returns
    -------
    bytes
        The serialized table (a bytes-like pyarrow buffer), or empty bytes
        if there are no rows or no row IDs

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    # len() rather than truthiness: array-like inputs have no unambiguous
    # truth value.
    if len(rows) == 0 or len(row_ids) == 0:
        return b''

    # The number of columns is taken from the first row.
    colnames = ['col{}'.format(x) for x in range(len(rows[0]))]

    tbl = pa.Table.from_pylist([dict(zip(colnames, row)) for row in rows])
    tbl = tbl.add_column(0, '__index__', pa.array(row_ids))

    return _write_table(tbl)


def _dump_vectors(
    returns: List[int],
    row_ids: 'pa.Array[pa.int64]',
    cols: List[Tuple['pa.Array[Any]', Optional['pa.Array[pa.bool_]']]],
) -> bytes:
    '''
    Convert a list of columns of data into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type; not used by this serializer
    row_ids : pyarrow.Array
        The row IDs
    cols : List[Tuple[Any, Any]]
        The (data, mask) pairs of the columns to serialize; a mask
        may be None

    Returns
    -------
    bytes
        The serialized table (a bytes-like pyarrow buffer), or empty bytes
        if there are no columns or no row IDs

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    # len() rather than truthiness: Arrow arrays have no unambiguous
    # truth value.
    if len(cols) == 0 or len(row_ids) == 0:
        return b''

    tbl = pa.Table.from_arrays(
        [pa.array(values, mask=mask) for values, mask in cols],
        names=['col{}'.format(x) for x in range(len(cols))],
    )
    tbl = tbl.add_column(0, '__index__', row_ids)

    return _write_table(tbl)


def dump_arrow(
    returns: List[int],
    row_ids: 'pa.Array[int]',
    cols: List[Tuple['pa.Array[Any]', 'pa.Array[bool]']],
) -> bytes:
    '''
    Convert pyarrow columns into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type; not used by this serializer
    row_ids : pyarrow.Array
        The row IDs
    cols : List[Tuple[pyarrow.Array, pyarrow.Array]]
        The (data, mask) pairs of the columns to serialize

    Returns
    -------
    bytes

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    return _dump_vectors(returns, row_ids, cols)


def dump_numpy(
    returns: List[int],
    row_ids: 'np.typing.NDArray[np.int64]',
    cols: List[Tuple['np.typing.NDArray[Any]', 'np.typing.NDArray[np.bool_]']],
) -> bytes:
    '''
    Convert numpy columns into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type; not used by this serializer
    row_ids : numpy.ndarray
        The row IDs
    cols : List[Tuple[numpy.ndarray, numpy.ndarray]]
        The (data, mask) pairs of the columns to serialize; a mask
        may be None

    Returns
    -------
    bytes

    Raises
    ------
    RuntimeError
        If numpy or pyarrow is not installed

    '''
    if not has_numpy:
        raise RuntimeError('numpy must be installed for this operation')
    # The conversions below use ``pa`` directly, so check for pyarrow here
    # instead of failing with a NameError.
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    return _dump_vectors(
        returns,
        pa.array(row_ids),
        [(pa.array(x), pa.array(y) if y is not None else None) for x, y in cols],
    )


def dump_pandas(
    returns: List[int],
    row_ids: 'pd.Series[np.int64]',
    cols: List[Tuple['pd.Series[Any]', 'pd.Series[np.bool_]']],
) -> bytes:
    '''
    Convert pandas columns into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type; not used by this serializer
    row_ids : pandas.Series
        The row IDs
    cols : List[Tuple[pandas.Series, pandas.Series]]
        The (data, mask) pairs of the columns to serialize; a mask
        may be None

    Returns
    -------
    bytes

    Raises
    ------
    RuntimeError
        If pandas, numpy or pyarrow is not installed

    '''
    if not has_pandas or not has_numpy:
        raise RuntimeError('pandas must be installed for this operation')
    # The conversions below use ``pa`` directly, so check for pyarrow here
    # instead of failing with a NameError.
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    return _dump_vectors(
        returns,
        pa.array(row_ids),
        [(pa.array(x), pa.array(y) if y is not None else None) for x, y in cols],
    )


def dump_polars(
    returns: List[int],
    row_ids: 'pl.Series[pl.Int64]',
    cols: List[Tuple['pl.Series[Any]', 'pl.Series[pl.Boolean]']],
) -> bytes:
    '''
    Convert polars columns into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type; not used by this serializer
    row_ids : polars.Series
        The row IDs
    cols : List[Tuple[polars.Series, polars.Series]]
        The (data, mask) pairs of the columns to serialize; a mask
        may be None

    Returns
    -------
    bytes

    Raises
    ------
    RuntimeError
        If polars or pyarrow is not installed

    '''
    if not has_polars:
        raise RuntimeError('polars must be installed for this operation')
    # Fail with the module's usual error before attempting any Arrow
    # conversion.
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    return _dump_vectors(
        returns,
        row_ids.to_arrow(),
        [(x.to_arrow(), y.to_arrow() if y is not None else None) for x, y in cols],
    )