Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/functions/ext/arrow.py
469 views
1
#!/usr/bin/env python3
2
from io import BytesIO
3
from typing import Any
4
from typing import List
5
from typing import Optional
6
from typing import Tuple
7
8
try:
9
import numpy as np
10
has_numpy = True
11
except ImportError:
12
has_numpy = False
13
14
try:
15
import polars as pl
16
has_polars = True
17
except ImportError:
18
has_polars = False
19
20
try:
21
import pandas as pd
22
has_pandas = True
23
except ImportError:
24
has_pandas = False
25
26
try:
27
import pyarrow as pa
28
import pyarrow.feather
29
has_pyarrow = True
30
except ImportError:
31
has_pyarrow = False
32
33
34
def load(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[List[int], List[Any]]:
    '''
    Convert bytes in Apache Feather format into rows of data.

    The first column of the table supplies the row IDs; each remaining
    column contributes one value to every output row.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        Column specifications; not consulted by this function
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[List[int], List[Any]]
        The row IDs and the rows of data (one list per row)

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    table = pa.feather.read_table(BytesIO(data))

    # Columns are addressed by position and converted exactly once.
    # Building rows from ``Table.to_pylist()`` would key every value by
    # column name, so columns sharing a name would overwrite each other.
    row_ids = table.column(0).to_pylist()
    value_columns = [col.to_pylist() for col in table.columns[1:]]

    # Zipping against ``row_ids`` keeps one (possibly empty) row per ID
    # even when the table has no value columns.
    rows = [values for _, *values in zip(row_ids, *value_columns)]
    return row_ids, rows
62
63
64
def _load_vectors(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[
    'pa.Array[pa.int64]',
    List[Tuple['pa.Array[Any]', 'pa.Array[pa.bool_]']],
]:
    '''
    Convert bytes in Apache Feather format into columns of data.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        Column specifications; not consulted by this function
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[Any, List[Tuple[Any, Any]]]
        The first table column (row IDs), followed by one
        ``(column, column.is_null())`` pair per remaining column

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    table = pa.feather.read_table(BytesIO(data))

    # Column 0 carries the row IDs; everything after it is payload,
    # each paired with its null mask.
    pairs = [(column, column.is_null()) for column in table.columns[1:]]
    return table.column(0), pairs
95
96
97
def load_pandas(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[
    'pd.Series[np.int64]',
    List[Tuple['pd.Series[Any]', 'pd.Series[np.bool_]']],
]:
    '''
    Convert bytes in Apache Feather format into pandas Series.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        A list of two-item column specifications; only its length
        matters here (output is truncated to the shorter of the decoded
        columns and ``colspec``)
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[pd.Series[int], List[Tuple[pd.Series[Any], pd.Series[bool]]]]
        The row IDs, followed by one ``(values, null mask)`` pair per
        column, each reindexed by the row IDs

    Raises
    ------
    RuntimeError
        If pandas or numpy is not installed (or pyarrow, via the
        underlying vector loader)

    '''
    if not (has_pandas and has_numpy):
        raise RuntimeError('pandas must be installed for this operation')

    row_ids, cols = _load_vectors(colspec, data)
    index = row_ids.to_pandas()

    series_pairs = []
    for (values, nulls), (_name, _dtype) in zip(cols, colspec):
        series_pairs.append((
            values.to_pandas().reindex(index),
            nulls.to_pandas().reindex(index),
        ))
    return index, series_pairs
133
134
135
def load_polars(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[
    'pl.Series[pl.Int64]',
    List[Tuple['pl.Series[Any]', 'pl.Series[pl.Boolean]']],
]:
    '''
    Convert bytes in Apache Feather format into polars Series.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        A list of two-item column specifications; only its length
        matters here (output is truncated to the shorter of the decoded
        columns and ``colspec``)
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[polars.Series[int], List[Tuple[polars.Series[Any], polars.Series[bool]]]]
        The row IDs, followed by one ``(values, null mask)`` pair per
        column

    Raises
    ------
    RuntimeError
        If polars is not installed (or pyarrow, via the underlying
        vector loader)

    '''
    if not has_polars:
        raise RuntimeError('polars must be installed for this operation')

    row_ids, cols = _load_vectors(colspec, data)

    ids = pl.from_arrow(row_ids)  # type: ignore

    series_pairs = []
    for (values, nulls), (_name, _dtype) in zip(cols, colspec):
        converted = pl.from_arrow(values)  # type: ignore
        converted_mask = pl.from_arrow(nulls)  # type: ignore
        series_pairs.append((converted, converted_mask))
    return ids, series_pairs
172
173
174
def load_numpy(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[
    'np.typing.NDArray[np.int64]',
    List[Tuple['np.typing.NDArray[Any]', 'np.typing.NDArray[np.bool_]']],
]:
    '''
    Convert bytes in Apache Feather format into numpy arrays.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        A list of two-item column specifications; only its length
        matters here (output is truncated to the shorter of the decoded
        columns and ``colspec``)
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[np.ndarray[int], List[Tuple[np.ndarray[Any], np.ndarray[bool]]]]
        The row IDs, followed by one ``(values, null mask)`` pair per
        column

    Raises
    ------
    RuntimeError
        If numpy is not installed (or pyarrow, via the underlying
        vector loader)

    '''
    if not has_numpy:
        raise RuntimeError('numpy must be installed for this operation')

    row_ids, cols = _load_vectors(colspec, data)

    ids = row_ids.to_numpy()

    array_pairs = []
    for (values, nulls), (_name, _dtype) in zip(cols, colspec):
        array_pairs.append((values.to_numpy(), nulls.to_numpy()))
    return ids, array_pairs
209
210
211
def load_arrow(
    colspec: List[Tuple[str, int]],
    data: bytes,
) -> Tuple[
    'pa.Array[pa.int64()]',
    List[Tuple['pa.Array[Any]', 'pa.Array[pa.bool_()]']],
]:
    '''
    Convert bytes in Apache Feather format into pyarrow columns.

    Parameters
    ----------
    colspec : List[Tuple[str, int]]
        Column specifications, forwarded to the vector loader
    data : bytes
        The data in Apache Feather format

    Returns
    -------
    Tuple[pyarrow.Array[int], List[Tuple[pyarrow.Array[Any], pyarrow.Array[bool]]]]
        The row IDs, followed by one ``(values, null mask)`` pair per
        column

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if has_pyarrow:
        # The shared vector loader already yields pyarrow objects.
        return _load_vectors(colspec, data)

    raise RuntimeError('pyarrow must be installed for this operation')
237
238
239
def dump(
    returns: List[int],
    row_ids: List[int],
    rows: List[List[Any]],
) -> bytes:
    '''
    Serialize rows of data into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type; not consulted by this function
    row_ids : List[int]
        The row IDs, written as a leading ``__index__`` column
    rows : List[List[Any]]
        The rows of data to serialize; the first row determines how
        many ``colN`` columns are named

    Returns
    -------
    bytes
        The serialized table as returned by ``sink.getvalue()``, or an
        empty ``BytesIO`` buffer when there are no rows or no row IDs

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    # Nothing to serialize: hand back an empty buffer.
    if len(rows) == 0 or len(row_ids) == 0:
        return BytesIO().getbuffer()

    width = len(rows[0])
    labels = ['col{}'.format(pos) for pos in range(width)]

    # One name -> value mapping per row, then prepend the row IDs.
    records = [dict(zip(labels, record)) for record in rows]
    table = pa.Table.from_pylist(records)
    table = table.add_column(0, '__index__', pa.array(row_ids))

    chunks = table.to_batches()
    sink = pa.BufferOutputStream()
    with pa.ipc.new_file(sink, chunks[0].schema) as writer:
        for chunk in chunks:
            writer.write_batch(chunk)
    return sink.getvalue()
278
279
280
def _dump_vectors(
    returns: List[int],
    row_ids: 'pa.Array[pa.int64]',
    cols: List[Tuple['pa.Array[Any]', Optional['pa.Array[pa.bool_]']]],
) -> bytes:
    '''
    Serialize columns of data into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type; not consulted by this function
    row_ids : pyarrow.Array[int]
        The row IDs, written as a leading ``__index__`` column
    cols : List[Tuple[Any, Any]]
        ``(values, mask)`` pairs, one per output column; each pair is
        passed to ``pa.array(values, mask=mask)``

    Returns
    -------
    bytes
        The serialized table as returned by ``sink.getvalue()``, or an
        empty ``BytesIO`` buffer when there are no columns or no row IDs

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    # Nothing to serialize: hand back an empty buffer.
    if len(cols) == 0 or len(row_ids) == 0:
        return BytesIO().getbuffer()

    arrays = [pa.array(values, mask=nulls) for values, nulls in cols]
    labels = ['col{}'.format(pos) for pos in range(len(cols))]

    table = pa.Table.from_arrays(arrays, names=labels)
    table = table.add_column(0, '__index__', row_ids)

    chunks = table.to_batches()
    sink = pa.BufferOutputStream()
    with pa.ipc.new_file(sink, chunks[0].schema) as writer:
        for chunk in chunks:
            writer.write_batch(chunk)
    return sink.getvalue()
320
321
322
def dump_arrow(
    returns: List[int],
    row_ids: 'pa.Array[int]',
    cols: List[Tuple['pa.Array[Any]', 'pa.Array[bool]']],
) -> bytes:
    '''
    Serialize pyarrow columns into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type, forwarded to the vector serializer
    row_ids : pyarrow.Array[int]
        The row IDs
    cols : List[Tuple[pyarrow.Array[Any], pyarrow.Array[bool]]]
        ``(values, mask)`` pairs, one per output column

    Returns
    -------
    bytes

    Raises
    ------
    RuntimeError
        If pyarrow is not installed

    '''
    if has_pyarrow:
        # Inputs are already pyarrow objects; no conversion is needed.
        return _dump_vectors(returns, row_ids, cols)

    raise RuntimeError('pyarrow must be installed for this operation')
331
332
333
def dump_numpy(
    returns: List[int],
    row_ids: 'np.typing.NDArray[np.int64]',
    cols: List[Tuple['np.typing.NDArray[Any]', 'np.typing.NDArray[np.bool_]']],
) -> bytes:
    '''
    Serialize numpy arrays into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type, forwarded to the vector serializer
    row_ids : np.ndarray[int]
        The row IDs
    cols : List[Tuple[np.ndarray[Any], np.ndarray[bool]]]
        ``(values, mask)`` pairs, one per output column; a mask of
        ``None`` is passed through unchanged

    Returns
    -------
    bytes

    Raises
    ------
    RuntimeError
        If numpy or pyarrow is not installed

    '''
    if not has_numpy:
        raise RuntimeError('numpy must be installed for this operation')

    # ``pa`` is used directly below, so check for it here; otherwise a
    # missing pyarrow surfaces as a NameError rather than the module's
    # usual RuntimeError.
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    return _dump_vectors(
        returns,
        pa.array(row_ids),
        [(pa.array(x), pa.array(y) if y is not None else None) for x, y in cols],
    )
346
347
348
def dump_pandas(
    returns: List[int],
    row_ids: 'pd.Series[np.int64]',
    cols: List[Tuple['pd.Series[Any]', 'pd.Series[np.bool_]']],
) -> bytes:
    '''
    Serialize pandas Series into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type, forwarded to the vector serializer
    row_ids : pd.Series[int]
        The row IDs
    cols : List[Tuple[pd.Series[Any], pd.Series[bool]]]
        ``(values, mask)`` pairs, one per output column; a mask of
        ``None`` is passed through unchanged

    Returns
    -------
    bytes

    Raises
    ------
    RuntimeError
        If pandas, numpy or pyarrow is not installed

    '''
    if not has_pandas or not has_numpy:
        raise RuntimeError('pandas must be installed for this operation')

    # ``pa`` is used directly below, so check for it here; otherwise a
    # missing pyarrow surfaces as a NameError rather than the module's
    # usual RuntimeError.
    if not has_pyarrow:
        raise RuntimeError('pyarrow must be installed for this operation')

    return _dump_vectors(
        returns,
        pa.array(row_ids),
        [(pa.array(x), pa.array(y) if y is not None else None) for x, y in cols],
    )
361
362
363
def dump_polars(
    returns: List[int],
    row_ids: 'pl.Series[pl.Int64]',
    cols: List[Tuple['pl.Series[Any]', 'pl.Series[pl.Boolean]']],
) -> bytes:
    '''
    Serialize polars Series into Apache Feather format.

    Parameters
    ----------
    returns : List[int]
        The returned data type, forwarded to the vector serializer
    row_ids : polars.Series[int]
        The row IDs
    cols : List[Tuple[polars.Series[Any], polars.Series[bool]]]
        ``(values, mask)`` pairs, one per output column; a mask of
        ``None`` is passed through unchanged

    Returns
    -------
    bytes

    Raises
    ------
    RuntimeError
        If polars is not installed

    '''
    if not has_polars:
        raise RuntimeError('polars must be installed for this operation')

    # Convert everything to Arrow up front, row IDs first.
    arrow_ids = row_ids.to_arrow()

    arrow_cols = []
    for values, nulls in cols:
        arrow_values = values.to_arrow()
        arrow_mask = nulls.to_arrow() if nulls is not None else None
        arrow_cols.append((arrow_values, arrow_mask))

    return _dump_vectors(returns, arrow_ids, arrow_cols)
376
377