Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/functions/ext/asgi.py
798 views
1
#!/usr/bin/env python3
2
"""
3
Web application for SingleStoreDB external functions.
4
5
This module supplies a function that can create web apps intended for use
6
with the external function feature of SingleStoreDB. The application
7
function is a standard ASGI <https://asgi.readthedocs.io/en/latest/index.html>
8
request handler for use with servers such as Uvicorn <https://www.uvicorn.org>.
9
10
An external function web application can be created using the `create_app`
11
function. By default, the exported Python functions are specified by
12
environment variables starting with SINGLESTOREDB_EXT_FUNCTIONS. See the
13
documentation in `create_app` for the full syntax. If the application is
14
created in Python code rather than from the command-line, exported
15
functions can be specified in the parameters.
16
17
An example of starting a server is shown below.
18
19
Example
20
-------
21
> SINGLESTOREDB_EXT_FUNCTIONS='myfuncs.[percentile_90,percentile_95]' \
22
python3 -m singlestoredb.functions.ext.asgi
23
24
"""
25
import argparse
26
import asyncio
27
import contextvars
28
import dataclasses
29
import datetime
30
import functools
31
import importlib.util
32
import inspect
33
import io
34
import itertools
35
import json
36
import logging
37
import os
38
import re
39
import secrets
40
import sys
41
import tempfile
42
import textwrap
43
import threading
44
import time
45
import traceback
46
import typing
47
import urllib
48
import uuid
49
import zipfile
50
import zipimport
51
from collections.abc import Awaitable
52
from collections.abc import Iterable
53
from collections.abc import Sequence
54
from types import ModuleType
55
from typing import Any
56
from typing import Callable
57
from typing import Dict
58
from typing import List
59
from typing import Optional
60
from typing import Set
61
from typing import Tuple
62
from typing import Union
63
64
from . import arrow
65
from . import json as jdata
66
from . import rowdat_1
67
from . import utils
68
from ... import connection
69
from ... import manage_workspaces
70
from ...config import get_option
71
from ...mysql.constants import FIELD_TYPE as ft
72
from ..signature import get_signature
73
from ..signature import signature_to_sql
74
from ..typing import Masked
75
from ..typing import Table
76
from .timer import Timer
77
from singlestoredb.docstring.parser import parse
78
from singlestoredb.functions.dtypes import escape_name
79
80
# Optional dependency: set a flag indicating whether cloudpickle is
# importable; callers elsewhere in the package check `has_cloudpickle`
# before using it.
try:
    import cloudpickle
    has_cloudpickle = True
except ImportError:
    has_cloudpickle = False

# Optional dependency: pydantic models are recognized in helpers below
# (`as_tuple`, `as_list_of_tuples`) only when `has_pydantic` is True.
try:
    from pydantic import BaseModel
    has_pydantic = True
except ImportError:
    has_pydantic = False
91
92
93
# Module-level logger; Application instances create their own named child loggers
logger = utils.get_logger('singlestoredb.functions.ext.asgi')

# If a number of processes is specified, create a pool of workers
num_processes = max(0, int(os.environ.get('SINGLESTOREDB_EXT_NUM_PROCESSES', 0)))
if num_processes > 1:
    # Prefer Ray's drop-in multiprocessing Pool when installed; fall back
    # to the standard library pool otherwise.
    try:
        from ray.util.multiprocessing import Pool
    except ImportError:
        from multiprocessing import Pool
    # NOTE(review): the Pool is created at import time and never closed or
    # terminated explicitly — worker processes live for the life of the
    # interpreter. Confirm this is intended.
    func_map = Pool(num_processes).starmap
else:
    # Serial fallback with the same starmap calling convention
    func_map = itertools.starmap
105
106
107
async def to_thread(
    func: Any, /, *args: Any, **kwargs: Dict[str, Any],
) -> Any:
    """
    Run *func* in the default executor, preserving the caller's context.

    A copy of the current ``contextvars`` context is captured so that the
    function observes the same context variables as the caller.
    """
    call = functools.partial(
        contextvars.copy_context().run, func, *args, **kwargs,
    )
    return await asyncio.get_running_loop().run_in_executor(None, call)
114
115
116
# Mapping of signature dtype names to MySQL field type codes used when
# serializing/deserializing ROWDAT_1 payloads (see `make_func`).
# Use negative values to indicate unsigned ints / binary data / usec time precision
rowdat_1_type_map = {
    # Booleans and all signed integer widths travel as LONGLONG
    'bool': ft.LONGLONG,
    'int8': ft.LONGLONG,
    'int16': ft.LONGLONG,
    'int32': ft.LONGLONG,
    'int64': ft.LONGLONG,
    # Unsigned integers: negated code marks the unsigned variant
    'uint8': -ft.LONGLONG,
    'uint16': -ft.LONGLONG,
    'uint32': -ft.LONGLONG,
    'uint64': -ft.LONGLONG,
    # All floats travel as DOUBLE
    'float32': ft.DOUBLE,
    'float64': ft.DOUBLE,
    # Text vs. binary: negated STRING code marks binary data
    'str': ft.STRING,
    'bytes': -ft.STRING,
}
132
133
134
def get_func_names(funcs: str) -> List[Tuple[str, str]]:
    """
    Parse all function names from string.

    Parameters
    ----------
    funcs : str
        String containing one or more function names. The syntax is
        as follows: [func-name-1@func-alias-1,func-name-2@func-alias-2,...].
        The optional '@name' portion is an alias if you want the function
        to be renamed.

    Returns
    -------
    List[Tuple[str, str]] : a list of tuples containing the name and alias
        of each function.

    """
    # Bracketed form holds a comma-separated list of names
    if funcs.startswith('['):
        func_names = funcs.replace('[', '').replace(']', '').split(',')
        func_names = [x.strip() for x in func_names]
    else:
        # Single name; strip stray whitespace for robustness
        func_names = [funcs.strip()]

    out = []
    for name in func_names:
        # 'name@alias' renames the function; the alias defaults to the name
        alias = name
        if '@' in name:
            name, alias = name.split('@', 1)
        out.append((name, alias))

    return out
166
167
168
def as_tuple(x: Any) -> Any:
    """Convert object to tuple."""
    # Pydantic models and dataclasses flatten to their field values;
    # dicts flatten to their values; anything else must itself be iterable.
    if has_pydantic and isinstance(x, BaseModel):
        return tuple(x.model_dump().values())
    if dataclasses.is_dataclass(x):
        return dataclasses.astuple(x)  # type: ignore
    return tuple(x.values()) if isinstance(x, dict) else tuple(x)
177
178
179
def as_list_of_tuples(x: Any) -> Any:
    """Convert object to a list of tuples."""
    if isinstance(x, Table):
        x = x[0]
    # Non-sequences and empty sequences pass through untouched
    if not isinstance(x, (list, tuple)) or not x:
        return x
    head = x[0]
    if isinstance(head, (list, tuple)):
        return x
    if has_pydantic and isinstance(head, BaseModel):
        return [tuple(item.model_dump().values()) for item in x]
    if dataclasses.is_dataclass(head):
        return [dataclasses.astuple(item) for item in x]
    if isinstance(head, dict):
        return [tuple(item.values()) for item in x]
    # Scalars become 1-tuples
    return [(item,) for item in x]
194
195
196
def get_dataframe_columns(df: Any) -> List[Any]:
    """Return columns of data from a dataframe/table."""
    # A Table wrapping a single object is unwrapped and inspected below;
    # a multi-member Table is returned as its list of members.
    if isinstance(df, Table):
        if len(df) == 1:
            df = df[0]
        else:
            return list(df)

    # A masked vector is treated as a single column
    if isinstance(df, Masked):
        return [df]

    if isinstance(df, tuple):
        return list(df)

    # Duck-typed dispatch on the class name so pandas / polars / pyarrow
    # objects are recognized without importing those packages here.
    rtype = str(type(df)).lower()
    if 'dataframe' in rtype:
        return [df[x] for x in df.columns]
    elif 'table' in rtype:
        return df.columns
    elif 'series' in rtype:
        return [df]
    elif 'array' in rtype:
        return [df]
    elif 'tuple' in rtype:
        return list(df)

    raise TypeError(
        'Unsupported data type for dataframe columns: '
        f'{rtype}',
    )
226
227
228
def get_array_class(data_format: str) -> Callable[..., Any]:
    """
    Get the array constructor for the given data format.

    Imports are done lazily so that only the package backing the
    requested format needs to be installed. Any unrecognized format
    falls back to numpy.
    """
    if data_format == 'polars':
        import polars as pl
        return pl.Series
    if data_format == 'arrow':
        import pyarrow as pa
        return pa.array
    if data_format == 'pandas':
        import pandas as pd
        return pd.Series
    import numpy as np
    return np.array
246
247
248
def get_masked_params(func: Callable[..., Any]) -> List[bool]:
    """
    Get the list of masked parameters for the function.

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint

    Returns
    -------
    List[bool]
        One flag per parameter, True when the parameter is annotated
        with the ``Masked`` generic type

    """
    return [
        typing.get_origin(p.annotation) is Masked
        for p in inspect.signature(func).parameters.values()
    ]
265
266
267
def build_tuple(x: Any) -> Any:
    """Convert object to a (data, mask) tuple."""
    # A Masked value already carries its own mask; plain values get None.
    if isinstance(x, Masked):
        return tuple(x)
    return (x, None)
270
271
272
def cancel_on_event(
273
cancel_event: threading.Event,
274
) -> None:
275
"""
276
Cancel the function call if the cancel event is set.
277
278
Parameters
279
----------
280
cancel_event : threading.Event
281
The event to check for cancellation
282
283
Raises
284
------
285
asyncio.CancelledError
286
If the cancel event is set
287
288
"""
289
if cancel_event.is_set():
290
task = asyncio.current_task()
291
if task is not None:
292
task.cancel()
293
raise asyncio.CancelledError(
294
'Function call was cancelled by client',
295
)
296
297
298
def build_udf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Build a UDF endpoint for scalar / list types (row-based).

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values

    Returns
    -------
    Callable
        The function endpoint

    """
    # Vector formats (pandas / numpy / polars / arrow) are column-based
    if returns_data_format not in ('scalar', 'list'):
        return build_vector_udf_endpoint(func, returns_data_format)

    is_async = asyncio.iscoroutinefunction(func)

    async def do_func(
        cancel_event: threading.Event,
        timer: Timer,
        row_ids: Sequence[int],
        rows: Sequence[Sequence[Any]],
    ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
        '''Call function on given rows of data.'''
        results = []
        async with timer('call_function'):
            for row in rows:
                # Bail out promptly if the client cancelled the request
                cancel_on_event(cancel_event)
                value = await func(*row) if is_async else func(*row)
                results.append(value)
        # Wrap each scalar result in a 1-tuple
        return row_ids, list(zip(results))

    return do_func
342
343
344
def build_vector_udf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Build a UDF endpoint for vector formats (column-based).

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values

    Returns
    -------
    Callable
        The function endpoint

    """
    masks = get_masked_params(func)
    array_cls = get_array_class(returns_data_format)
    is_async = asyncio.iscoroutinefunction(func)

    async def do_func(
        cancel_event: threading.Event,
        timer: Timer,
        row_ids: Sequence[int],
        cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ) -> Tuple[
        Sequence[int],
        List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ]:
        '''Call function on given columns of data.'''
        row_ids = array_cls(row_ids)

        # Call the function with `cols` as the function parameters.
        # Masked parameters receive the (data, mask) pair; unmasked ones
        # receive just the data vector.
        async with timer('call_function'):
            if cols and cols[0]:
                args = [c if m else c[0] for c, m in zip(cols, masks)]
            else:
                args = []
            out = await func(*args) if is_async else func(*args)

        cancel_on_event(cancel_event)

        # Single masked value
        if isinstance(out, Masked):
            return row_ids, [tuple(out)]

        # Multiple return values
        if isinstance(out, tuple):
            return row_ids, [build_tuple(item) for item in out]

        # Single return value
        return row_ids, [(out, None)]

    return do_func
407
408
409
def build_tvf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Build a TVF endpoint for scalar / list types (row-based).

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values

    Returns
    -------
    Callable
        The function endpoint

    """
    # Vector formats (pandas / numpy / polars / arrow) are column-based
    if returns_data_format not in ('scalar', 'list'):
        return build_vector_tvf_endpoint(func, returns_data_format)

    is_async = asyncio.iscoroutinefunction(func)

    async def do_func(
        cancel_event: threading.Event,
        timer: Timer,
        row_ids: Sequence[int],
        rows: Sequence[Sequence[Any]],
    ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
        '''Call function on given rows of data.'''
        out_ids: List[int] = []
        out = []
        # Call function on each row of data
        async with timer('call_function'):
            for row_id, row in zip(row_ids, rows):
                cancel_on_event(cancel_event)
                if is_async:
                    res = await func(*row)
                else:
                    res = func(*row)
                out.extend(as_list_of_tuples(res))
                # Tag every result row produced by this input row with the
                # input row's ID. BUGFIX: previously `row_ids` was indexed
                # by the row ID *value*, which is only correct when the IDs
                # happen to be 0..n-1; arbitrary IDs could raise IndexError
                # or assign wrong IDs.
                out_ids.extend([row_id] * (len(out) - len(out_ids)))
        return out_ids, out

    return do_func
457
458
459
def build_vector_tvf_endpoint(
    func: Callable[..., Any],
    returns_data_format: str,
) -> Callable[..., Any]:
    """
    Build a TVF endpoint for vector formats (column-based).

    Parameters
    ----------
    func : Callable
        The function to call as the endpoint
    returns_data_format : str
        The format of the return values

    Returns
    -------
    Callable
        The function endpoint

    """
    masks = get_masked_params(func)
    array_cls = get_array_class(returns_data_format)
    # Computed once at build time rather than on every request, consistent
    # with `build_vector_udf_endpoint`
    is_async = asyncio.iscoroutinefunction(func)

    async def do_func(
        cancel_event: threading.Event,
        timer: Timer,
        row_ids: Sequence[int],
        cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ) -> Tuple[
        Sequence[int],
        List[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
    ]:
        '''Call function on given columns of data.'''
        # NOTE: There is no way to determine which row ID belongs to
        #       each result row, so we just have to use the same
        #       row ID for all rows in the result.

        # Call function on each column of data. Masked parameters receive
        # the (data, mask) pair; unmasked ones just the data vector.
        async with timer('call_function'):
            if cols and cols[0]:
                args = [x if m else x[0] for x, m in zip(cols, masks)]
                func_res = await func(*args) if is_async else func(*args)
            else:
                func_res = await func() if is_async else func()

        res = get_dataframe_columns(func_res)

        cancel_on_event(cancel_event)

        # Generate row IDs: replicate the first input row ID across all
        # result rows (length taken from the first result column)
        if isinstance(res[0], Masked):
            row_ids = array_cls([row_ids[0]] * len(res[0][0]))
        else:
            row_ids = array_cls([row_ids[0]] * len(res[0]))

        return row_ids, [build_tuple(x) for x in res]

    return do_func
528
529
530
def make_func(
    name: str,
    func: Callable[..., Any],
) -> Tuple[Callable[..., Any], Dict[str, Any]]:
    """
    Make a function endpoint.

    Parameters
    ----------
    name : str
        Name of the function to create
    func : Callable
        The function to call as the endpoint

    Returns
    -------
    (Callable, Dict[str, Any])
        The endpoint coroutine and a dictionary of function metadata

    """
    def _map_types(fields: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        # Map signature dtype names to rowdat_1 field type codes
        out = []
        for x in fields:
            dtype = x['dtype'].replace('?', '')
            if dtype not in rowdat_1_type_map:
                raise TypeError(f'no data type mapping for {dtype}')
            out.append((x['name'], rowdat_1_type_map[dtype]))
        return out

    sig = get_signature(func, func_name=name)

    function_type = sig.get('function_type', 'udf')
    args_data_format = sig.get('args_data_format', 'scalar')
    returns_data_format = sig.get('returns_data_format', 'scalar')
    # Per-function timeout falls back to the package option
    timeout = (
        func._singlestoredb_attrs.get('timeout') or  # type: ignore
        get_option('external_function.timeout')
    )

    if function_type == 'tvf':
        do_func = build_tvf_endpoint(func, returns_data_format)
    else:
        do_func = build_udf_endpoint(func, returns_data_format)

    do_func.__name__ = name
    do_func.__doc__ = func.__doc__

    info: Dict[str, Any] = {
        # Signature for generating CREATE FUNCTION calls
        'signature': sig,
        # Data formats for arguments and return values
        'args_data_format': args_data_format,
        'returns_data_format': returns_data_format,
        # Function type: 'udf' or 'tvf'
        'function_type': function_type,
        # Enforce a minimum timeout of 1 second
        'timeout': max(timeout, 1),
        # Async flag
        'is_async': asyncio.iscoroutinefunction(func),
        # Argument / return types for the rowdat_1 parser
        'colspec': _map_types(sig['args']),
        'returns': _map_types(sig['returns']),
    }

    return do_func, info
606
607
608
async def cancel_on_timeout(timeout: int) -> None:
    """Sleep for *timeout* seconds, then raise a cancellation error."""
    # Intended to be raced against the real work: whichever finishes
    # first wins.
    await asyncio.sleep(timeout)
    raise asyncio.CancelledError('Function call was cancelled due to timeout')
614
615
616
async def cancel_on_disconnect(
    receive: Callable[..., Awaitable[Any]],
) -> None:
    """Watch the ASGI receive channel and cancel when the client disconnects."""
    # Drain messages until a disconnect arrives, then raise.
    while True:
        message = await receive()
        if message.get('type', '') == 'http.disconnect':
            break
    raise asyncio.CancelledError('Function call was cancelled by client')
626
627
628
async def cancel_all_tasks(tasks: Iterable[asyncio.Task[Any]]) -> None:
    """
    Cancel all tasks and wait for them to finish.

    Exceptions raised by the tasks (including CancelledError) are
    collected by ``gather`` rather than propagated.
    """
    # Materialize first: a one-shot iterator would be exhausted by the
    # cancel loop, leaving `gather` with nothing to wait on.
    pending = list(tasks)
    for task in pending:
        task.cancel()
    await asyncio.gather(*pending, return_exceptions=True)
633
634
635
def start_counter() -> float:
    """Begin timing and return the current performance-counter value."""
    return time.perf_counter()
638
639
640
def end_counter(start: float) -> float:
    """Return the seconds elapsed since *start* (a perf-counter value)."""
    return time.perf_counter() - start
643
644
645
class Application(object):
    """
    Create an external function application.

    If `functions` is None, the environment is searched for function
    specifications in variables starting with `SINGLESTOREDB_EXT_FUNCTIONS`.
    Any number of environment variables can be specified as long as they
    have this prefix. The format of the environment variable value is the
    same as for the `functions` parameter.

    Parameters
    ----------
    functions : str or Iterable[str], optional
        Python functions are specified using a string format as follows:
        * Single function : <pkg1>.<func1>
        * Multiple functions : <pkg1>.[<func1-name>,<func2-name>,...]
        * Function aliases : <pkg1>.[<func1>@<alias1>,<func2>@<alias2>,...]
        * Multiple packages : <pkg1>.<func1>:<pkg2>.<func2>
    app_mode : str, optional
        The mode of operation for the application: remote, managed, or collocated
    url : str, optional
        The URL of the function API
    data_format : str, optional
        The format of the data rows: 'rowdat_1' or 'json'
    data_version : str, optional
        The version of the call format to expect: '1.0'
    link_name : str, optional
        The link name to use for the external function application. This is
        only for pre-existing links, and can only be used without
        ``link_config`` and ``link_credentials``.
    link_config : Dict[str, Any], optional
        The CONFIG section of a LINK definition. This dictionary gets
        converted to JSON for the CREATE LINK call.
    link_credentials : Dict[str, Any], optional
        The CREDENTIALS section of a LINK definition. This dictionary gets
        converted to JSON for the CREATE LINK call.
    name_prefix : str, optional
        Prefix to add to function names when registering with the database
    name_suffix : str, optional
        Suffix to add to function names when registering with the database
    function_database : str, optional
        The database to use for external function definitions.
    log_file : str, optional
        File path to write logs to instead of console. If None, logs are
        written to console. When specified, application logger handlers
        are replaced with a file handler.
    log_level : str, optional
        Logging level for the application logger. Valid values are 'info',
        'debug', 'warning', 'error'. Defaults to 'info'.
    disable_metrics : bool, optional
        Disable logging of function call metrics. Defaults to False.
    app_name : str, optional
        Name for the application instance. Used to create a logger-specific
        name. If not provided, a random name will be generated.

    """
    # ASGI 'http.response.start' message templates shared by all requests.
    # Plain text response start
    text_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=200,
        headers=[(b'content-type', b'text/plain')],
    )

    # Error response start
    error_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=500,
        headers=[(b'content-type', b'text/plain')],
    )

    # Timeout response start
    timeout_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=504,
        headers=[(b'content-type', b'text/plain')],
    )

    # Cancel response start
    cancel_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=503,
        headers=[(b'content-type', b'text/plain')],
    )

    # JSON response start
    json_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=200,
        headers=[(b'content-type', b'application/json')],
    )

    # ROWDAT_1 response start
    rowdat_1_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=200,
        headers=[(b'content-type', b'x-application/rowdat_1')],
    )

    # Apache Arrow response start
    arrow_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=200,
        headers=[(b'content-type', b'application/vnd.apache.arrow.file')],
    )

    # Path not found response start
    path_not_found_response_dict: Dict[str, Any] = dict(
        type='http.response.start',
        status=404,
    )

    # Response body template
    body_response_dict: Dict[str, Any] = dict(
        type='http.response.body',
    )

    # Data format + version handlers, keyed by
    # (content-type header, data version header, data format of the function).
    # Each entry supplies the matching load/dump pair and response-start
    # message template.
    handlers = {
        (b'application/octet-stream', b'1.0', 'scalar'): dict(
            load=rowdat_1.load,
            dump=rowdat_1.dump,
            response=rowdat_1_response_dict,
        ),
        (b'application/octet-stream', b'1.0', 'list'): dict(
            load=rowdat_1.load,
            dump=rowdat_1.dump,
            response=rowdat_1_response_dict,
        ),
        (b'application/octet-stream', b'1.0', 'pandas'): dict(
            load=rowdat_1.load_pandas,
            dump=rowdat_1.dump_pandas,
            response=rowdat_1_response_dict,
        ),
        (b'application/octet-stream', b'1.0', 'numpy'): dict(
            load=rowdat_1.load_numpy,
            dump=rowdat_1.dump_numpy,
            response=rowdat_1_response_dict,
        ),
        (b'application/octet-stream', b'1.0', 'polars'): dict(
            load=rowdat_1.load_polars,
            dump=rowdat_1.dump_polars,
            response=rowdat_1_response_dict,
        ),
        (b'application/octet-stream', b'1.0', 'arrow'): dict(
            load=rowdat_1.load_arrow,
            dump=rowdat_1.dump_arrow,
            response=rowdat_1_response_dict,
        ),
        (b'application/json', b'1.0', 'scalar'): dict(
            load=jdata.load,
            dump=jdata.dump,
            response=json_response_dict,
        ),
        (b'application/json', b'1.0', 'list'): dict(
            load=jdata.load,
            dump=jdata.dump,
            response=json_response_dict,
        ),
        (b'application/json', b'1.0', 'pandas'): dict(
            load=jdata.load_pandas,
            dump=jdata.dump_pandas,
            response=json_response_dict,
        ),
        (b'application/json', b'1.0', 'numpy'): dict(
            load=jdata.load_numpy,
            dump=jdata.dump_numpy,
            response=json_response_dict,
        ),
        (b'application/json', b'1.0', 'polars'): dict(
            load=jdata.load_polars,
            dump=jdata.dump_polars,
            response=json_response_dict,
        ),
        (b'application/json', b'1.0', 'arrow'): dict(
            load=jdata.load_arrow,
            dump=jdata.dump_arrow,
            response=json_response_dict,
        ),
        (b'application/vnd.apache.arrow.file', b'1.0', 'scalar'): dict(
            load=arrow.load,
            dump=arrow.dump,
            response=arrow_response_dict,
        ),
        (b'application/vnd.apache.arrow.file', b'1.0', 'pandas'): dict(
            load=arrow.load_pandas,
            dump=arrow.dump_pandas,
            response=arrow_response_dict,
        ),
        (b'application/vnd.apache.arrow.file', b'1.0', 'numpy'): dict(
            load=arrow.load_numpy,
            dump=arrow.dump_numpy,
            response=arrow_response_dict,
        ),
        (b'application/vnd.apache.arrow.file', b'1.0', 'polars'): dict(
            load=arrow.load_polars,
            dump=arrow.dump_polars,
            response=arrow_response_dict,
        ),
        (b'application/vnd.apache.arrow.file', b'1.0', 'arrow'): dict(
            load=arrow.load_arrow,
            dump=arrow.dump_arrow,
            response=arrow_response_dict,
        ),
    }

    # Valid URL paths (each as a tuple of path segments)
    invoke_path = ('invoke',)
    show_create_function_path = ('show', 'create_function')
    show_function_info_path = ('show', 'function_info')
    status = ('status',)
def __init__(
858
self,
859
functions: Optional[
860
Union[
861
str,
862
Iterable[str],
863
Callable[..., Any],
864
Iterable[Callable[..., Any]],
865
ModuleType,
866
Iterable[ModuleType],
867
]
868
] = None,
869
app_mode: str = get_option('external_function.app_mode'),
870
url: str = get_option('external_function.url'),
871
data_format: str = get_option('external_function.data_format'),
872
data_version: str = get_option('external_function.data_version'),
873
link_name: Optional[str] = get_option('external_function.link_name'),
874
link_config: Optional[Dict[str, Any]] = None,
875
link_credentials: Optional[Dict[str, Any]] = None,
876
name_prefix: str = get_option('external_function.name_prefix'),
877
name_suffix: str = get_option('external_function.name_suffix'),
878
function_database: Optional[str] = None,
879
log_file: Optional[str] = get_option('external_function.log_file'),
880
log_level: str = get_option('external_function.log_level'),
881
disable_metrics: bool = get_option('external_function.disable_metrics'),
882
app_name: Optional[str] = get_option('external_function.app_name'),
883
) -> None:
884
if link_name and (link_config or link_credentials):
885
raise ValueError(
886
'`link_name` can not be used with `link_config` or `link_credentials`',
887
)
888
889
if link_config is None:
890
link_config = json.loads(
891
get_option('external_function.link_config') or '{}',
892
) or None
893
894
if link_credentials is None:
895
link_credentials = json.loads(
896
get_option('external_function.link_credentials') or '{}',
897
) or None
898
899
# Generate application name if not provided
900
if app_name is None:
901
app_name = f'udf_app_{secrets.token_hex(4)}'
902
903
self.name = app_name
904
905
# Create logger instance specific to this application
906
self.logger = utils.get_logger(f'singlestoredb.functions.ext.asgi.{self.name}')
907
908
# List of functions specs
909
specs: List[Union[str, Callable[..., Any], ModuleType]] = []
910
911
# Look up Python function specifications
912
if functions is None:
913
env_vars = [
914
x for x in os.environ.keys()
915
if x.startswith('SINGLESTOREDB_EXT_FUNCTIONS')
916
]
917
if env_vars:
918
specs = [os.environ[x] for x in env_vars]
919
else:
920
import __main__
921
specs = [__main__]
922
923
elif isinstance(functions, ModuleType):
924
specs = [functions]
925
926
elif isinstance(functions, str):
927
specs = [functions]
928
929
elif callable(functions):
930
specs = [functions]
931
932
else:
933
specs = list(functions)
934
935
# Add functions to application
936
endpoints = dict()
937
external_functions = dict()
938
for funcs in itertools.chain(specs):
939
940
if isinstance(funcs, str):
941
# Module name
942
if importlib.util.find_spec(funcs) is not None:
943
items = importlib.import_module(funcs)
944
for x in vars(items).values():
945
if not hasattr(x, '_singlestoredb_attrs'):
946
continue
947
name = x._singlestoredb_attrs.get('name', x.__name__)
948
name = f'{name_prefix}{name}{name_suffix}'
949
external_functions[x.__name__] = x
950
func, info = make_func(name, x)
951
endpoints[name.encode('utf-8')] = func, info
952
953
# Fully qualified function name
954
elif '.' in funcs:
955
pkg_path, func_names = funcs.rsplit('.', 1)
956
pkg = importlib.import_module(pkg_path)
957
958
if pkg is None:
959
raise RuntimeError(f'Could not locate module: {pkg}')
960
961
# Add endpoint for each exported function
962
for name, alias in get_func_names(func_names):
963
item = getattr(pkg, name)
964
alias = f'{name_prefix}{name}{name_suffix}'
965
external_functions[name] = item
966
func, info = make_func(alias, item)
967
endpoints[alias.encode('utf-8')] = func, info
968
969
else:
970
raise RuntimeError(f'Could not locate module: {funcs}')
971
972
elif isinstance(funcs, ModuleType):
973
for x in vars(funcs).values():
974
if not hasattr(x, '_singlestoredb_attrs'):
975
continue
976
name = x._singlestoredb_attrs.get('name', x.__name__)
977
name = f'{name_prefix}{name}{name_suffix}'
978
external_functions[x.__name__] = x
979
func, info = make_func(name, x)
980
endpoints[name.encode('utf-8')] = func, info
981
982
else:
983
alias = funcs.__name__
984
external_functions[funcs.__name__] = funcs
985
alias = f'{name_prefix}{alias}{name_suffix}'
986
func, info = make_func(alias, funcs)
987
endpoints[alias.encode('utf-8')] = func, info
988
989
self.app_mode = app_mode
990
self.url = url
991
self.data_format = data_format
992
self.data_version = data_version
993
self.link_name = link_name
994
self.link_config = link_config
995
self.link_credentials = link_credentials
996
self.endpoints = endpoints
997
self.external_functions = external_functions
998
self.function_database = function_database
999
self.log_file = log_file
1000
self.log_level = log_level
1001
self.disable_metrics = disable_metrics
1002
1003
# Configure logging
1004
self._configure_logging()
1005
1006
def _configure_logging(self) -> None:
1007
"""Configure logging based on the log_file settings."""
1008
# Set logger level
1009
self.logger.setLevel(getattr(logging, self.log_level.upper()))
1010
1011
# Remove all existing handlers to ensure clean configuration
1012
self.logger.handlers.clear()
1013
1014
# Configure log file if specified
1015
if self.log_file:
1016
# Create file handler
1017
file_handler = logging.FileHandler(self.log_file)
1018
file_handler.setLevel(getattr(logging, self.log_level.upper()))
1019
1020
# Use JSON formatter for file logging
1021
formatter = utils.JSONFormatter()
1022
file_handler.setFormatter(formatter)
1023
1024
# Add the handler to the logger
1025
self.logger.addHandler(file_handler)
1026
else:
1027
# For console logging, create a new stream handler with JSON formatter
1028
console_handler = logging.StreamHandler()
1029
console_handler.setLevel(getattr(logging, self.log_level.upper()))
1030
console_handler.setFormatter(utils.JSONFormatter())
1031
self.logger.addHandler(console_handler)
1032
1033
# Prevent propagation to avoid duplicate or differently formatted messages
1034
self.logger.propagate = False
1035
1036
def get_uvicorn_log_config(self) -> Dict[str, Any]:
1037
"""
1038
Create uvicorn log config that matches the Application's logging format.
1039
1040
This method returns the log configuration used by uvicorn, allowing external
1041
users to match the logging format of the Application class.
1042
1043
Returns
1044
-------
1045
Dict[str, Any]
1046
Log configuration dictionary compatible with uvicorn's log_config parameter
1047
1048
"""
1049
log_config = {
1050
'version': 1,
1051
'disable_existing_loggers': False,
1052
'formatters': {
1053
'json': {
1054
'()': 'singlestoredb.functions.ext.utils.JSONFormatter',
1055
},
1056
},
1057
'handlers': {
1058
'default': {
1059
'class': (
1060
'logging.FileHandler' if self.log_file
1061
else 'logging.StreamHandler'
1062
),
1063
'formatter': 'json',
1064
},
1065
},
1066
'loggers': {
1067
'uvicorn': {
1068
'handlers': ['default'],
1069
'level': self.log_level.upper(),
1070
'propagate': False,
1071
},
1072
'uvicorn.error': {
1073
'handlers': ['default'],
1074
'level': self.log_level.upper(),
1075
'propagate': False,
1076
},
1077
'uvicorn.access': {
1078
'handlers': ['default'],
1079
'level': self.log_level.upper(),
1080
'propagate': False,
1081
},
1082
},
1083
}
1084
1085
# Add filename to file handler if log file is specified
1086
if self.log_file:
1087
log_config['handlers']['default']['filename'] = self.log_file # type: ignore
1088
1089
return log_config
1090
1091
async def __call__(
    self,
    scope: Dict[str, Any],
    receive: Callable[..., Awaitable[Any]],
    send: Callable[..., Awaitable[Any]],
) -> None:
    '''
    Application request handler.

    Routes HTTP requests to function invocation, SQL reflection,
    function info, or status endpoints based on method and path.

    Parameters
    ----------
    scope : dict
        ASGI request scope
    receive : Callable
        Function to receive request information
    send : Callable
        Function to send response information

    '''
    # Unique ID correlating all log records emitted for this request.
    request_id = str(uuid.uuid4())

    # Request-level timer; a second timer is handed to the function call
    # itself so per-call metrics can be merged back in at the end.
    timer = Timer(
        app_name=self.name,
        id=request_id,
        timestamp=datetime.datetime.now(
            datetime.timezone.utc,
        ).strftime('%Y-%m-%dT%H:%M:%S.%fZ'),
    )
    call_timer = Timer(
        app_name=self.name,
        id=request_id,
        timestamp=datetime.datetime.now(
            datetime.timezone.utc,
        ).strftime('%Y-%m-%dT%H:%M:%S.%fZ'),
    )

    if scope['type'] != 'http':
        raise ValueError(f"Expected HTTP scope, got {scope['type']}")

    method = scope['method']
    path = tuple(x for x in scope['path'].split('/') if x)
    headers = dict(scope['headers'])

    content_type = headers.get(
        b'content-type',
        b'application/octet-stream',
    )
    accepts = headers.get(b'accepts', content_type)
    func_name = headers.get(b's2-ef-name', b'')
    func_endpoint = self.endpoints.get(func_name)
    ignore_cancel = headers.get(b's2-ef-ignore-cancel', b'false') == b'true'

    timer.metadata['function'] = func_name.decode('utf-8') if func_name else ''
    call_timer.metadata['function'] = timer.metadata['function']

    func = None
    func_info: Dict[str, Any] = {}
    if func_endpoint is not None:
        func, func_info = func_endpoint

    # Call the endpoint
    if method == 'POST' and func is not None and path == self.invoke_path:

        self.logger.info(
            'Function call initiated',
            extra={
                'app_name': self.name,
                'request_id': request_id,
                'function_name': func_name.decode('utf-8'),
                'content_type': content_type.decode('utf-8'),
                'accepts': accepts.decode('utf-8'),
            },
        )

        args_data_format = func_info['args_data_format']
        returns_data_format = func_info['returns_data_format']
        data = []
        more_body = True
        # Accumulate the full request body; it may arrive in chunks.
        with timer('receive_data'):
            while more_body:
                request = await receive()
                if request.get('type', '') == 'http.disconnect':
                    raise RuntimeError('client disconnected')
                data.append(request['body'])
                more_body = request.get('more_body', False)

        # Serializers are keyed by (content type, format version, row format).
        data_version = headers.get(b's2-ef-version', b'')
        input_handler = self.handlers[(content_type, data_version, args_data_format)]
        output_handler = self.handlers[(accepts, data_version, returns_data_format)]

        try:
            all_tasks = []
            result = []

            # Cross-thread cancellation signal for the running function.
            cancel_event = threading.Event()

            with timer('parse_input'):
                inputs = input_handler['load'](  # type: ignore
                    func_info['colspec'], b''.join(data),
                )

            # Race the function against client-disconnect and timeout
            # watchers. Sync functions run in a worker thread with their
            # own event loop.
            func_task = asyncio.create_task(
                func(cancel_event, call_timer, *inputs)
                if func_info['is_async']
                else to_thread(
                    lambda: asyncio.run(
                        func(cancel_event, call_timer, *inputs),
                    ),
                ),
            )
            disconnect_task = asyncio.create_task(
                asyncio.sleep(int(1e9))
                if ignore_cancel else cancel_on_disconnect(receive),
            )
            timeout_task = asyncio.create_task(
                cancel_on_timeout(func_info['timeout']),
            )

            all_tasks += [func_task, disconnect_task, timeout_task]

            async with timer('function_wrapper'):
                done, pending = await asyncio.wait(
                    all_tasks, return_when=asyncio.FIRST_COMPLETED,
                )

                await cancel_all_tasks(pending)

                for task in done:
                    if task is disconnect_task:
                        cancel_event.set()
                        raise asyncio.CancelledError(
                            'Function call was cancelled by client disconnect',
                        )

                    elif task is timeout_task:
                        cancel_event.set()
                        raise asyncio.TimeoutError(
                            'Function call was cancelled due to timeout',
                        )

                    elif task is func_task:
                        result.extend(task.result())

            with timer('format_output'):
                body = output_handler['dump'](
                    [x[1] for x in func_info['returns']], *result,  # type: ignore
                )

            await send(output_handler['response'])

        except asyncio.TimeoutError:
            self.logger.exception(
                'Function call timeout',
                extra={
                    'app_name': self.name,
                    'request_id': request_id,
                    'function_name': func_name.decode('utf-8'),
                    'timeout': func_info['timeout'],
                },
            )
            body = (
                'TimeoutError: Function call timed out after ' +
                str(func_info['timeout']) +
                ' seconds'
            ).encode('utf-8')
            await send(self.timeout_response_dict)

        except asyncio.CancelledError:
            self.logger.exception(
                'Function call cancelled',
                extra={
                    'app_name': self.name,
                    'request_id': request_id,
                    'function_name': func_name.decode('utf-8'),
                },
            )
            body = b'CancelledError: Function call was cancelled'
            await send(self.cancel_response_dict)

        except Exception as e:
            self.logger.exception(
                'Function call error',
                extra={
                    'app_name': self.name,
                    'request_id': request_id,
                    'function_name': func_name.decode('utf-8'),
                    'exception_type': type(e).__name__,
                },
            )
            # Reduce the traceback to its final file reference; IPython
            # temp-file paths are shortened to just the line number.
            msg = traceback.format_exc().strip().split(' File ')[-1]
            if msg.startswith('"/tmp/ipykernel_'):
                msg = 'Line ' + msg.split(', line ')[-1]
            else:
                msg = 'File ' + msg
            body = msg.encode('utf-8')
            await send(self.error_response_dict)

        finally:
            await cancel_all_tasks(all_tasks)

    # Handle api reflection
    elif method == 'GET' and path == self.show_create_function_path:
        host = headers.get(b'host', b'localhost:80')
        reflected_url = f'{scope["scheme"]}://{host.decode("utf-8")}/invoke'

        syntax = []
        for key, (endpoint, endpoint_info) in self.endpoints.items():
            if not func_name or key == func_name:
                syntax.append(
                    signature_to_sql(
                        endpoint_info['signature'],
                        url=self.url or reflected_url,
                        data_format=self.data_format,
                        database=self.function_database or None,
                    ),
                )
        body = '\n'.join(syntax).encode('utf-8')

        await send(self.text_response_dict)

    # Return function info
    elif method == 'GET' and (path == self.show_function_info_path or not path):
        functions = self.get_function_info()
        body = json.dumps(dict(functions=functions)).encode('utf-8')
        await send(self.text_response_dict)

    # Return status
    elif method == 'GET' and path == self.status:
        body = json.dumps(dict(status='ok')).encode('utf-8')
        await send(self.text_response_dict)

    # Path not found
    else:
        body = b''
        await send(self.path_not_found_response_dict)

    # Send body
    with timer('send_response'):
        out = self.body_response_dict.copy()
        out['body'] = body
        await send(out)

    # Fold per-call metrics into the request-level timer before reporting.
    for k, v in call_timer.metrics.items():
        timer.metrics[k] = v

    if not self.disable_metrics:
        metrics = timer.finish()
        self.logger.info(
            'Function call metrics',
            extra={
                'app_name': self.name,
                'request_id': request_id,
                'function_name': timer.metadata.get('function', ''),
                'metrics': metrics,
            },
        )
def _create_link(
1349
self,
1350
config: Optional[Dict[str, Any]],
1351
credentials: Optional[Dict[str, Any]],
1352
) -> Tuple[str, str]:
1353
"""Generate CREATE LINK command."""
1354
if self.link_name:
1355
return self.link_name, ''
1356
1357
if not config and not credentials:
1358
return '', ''
1359
1360
link_name = f'py_ext_func_link_{secrets.token_hex(14)}'
1361
out = [f'CREATE LINK {link_name} AS HTTP']
1362
1363
if config:
1364
out.append(f"CONFIG '{json.dumps(config)}'")
1365
1366
if credentials:
1367
out.append(f"CREDENTIALS '{json.dumps(credentials)}'")
1368
1369
return link_name, ' '.join(out) + ';'
1370
1371
def _locate_app_functions(self, cur: Any) -> Tuple[Set[str], Set[str]]:
1372
"""Locate all current functions and links belonging to this app."""
1373
funcs, links = set(), set()
1374
if self.function_database:
1375
database_prefix = escape_name(self.function_database) + '.'
1376
cur.execute(f'SHOW FUNCTIONS IN {escape_name(self.function_database)}')
1377
else:
1378
database_prefix = ''
1379
cur.execute('SHOW FUNCTIONS')
1380
1381
for row in list(cur):
1382
name, ftype, link = row[0], row[1], row[-1]
1383
# Only look at external functions
1384
if 'external' not in ftype.lower():
1385
continue
1386
# See if function URL matches url
1387
cur.execute(f'SHOW CREATE FUNCTION {database_prefix}{escape_name(name)}')
1388
for fname, _, code, *_ in list(cur):
1389
m = re.search(r" (?:\w+) (?:SERVICE|MANAGED) '([^']+)'", code)
1390
if m and m.group(1) == self.url:
1391
funcs.add(f'{database_prefix}{escape_name(fname)}')
1392
if link and re.match(r'^py_ext_func_link_\S{14}$', link):
1393
links.add(link)
1394
1395
return funcs, links
1396
1397
def get_function_info(
    self,
    func_name: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Return the functions and function signature information.

    Information is assembled from each function's signature plus
    whatever can be parsed from its docstring (summary, parameter
    descriptions, return description, and examples).

    Parameters
    ----------
    func_name : str, optional
        If given, restrict the output to the matching endpoint key.

    Returns
    -------
    Dict[str, Any]

    """
    functions = {}
    # Sentinel distinguishing "no default" from an explicit None default.
    no_default = object()

    # Generate CREATE FUNCTION SQL for each function using get_create_functions
    create_sqls = self.get_create_functions(replace=True)
    sql_map = {}
    for (_, info), sql in zip(self.endpoints.values(), create_sqls):
        sig = info['signature']
        sql_map[sig['name']] = sql

    for key, (func, info) in self.endpoints.items():
        # Get info from docstring
        doc_summary = ''
        doc_long_description = ''
        doc_params = {}
        doc_returns = None
        doc_examples = []
        if func.__doc__:
            try:
                docs = parse(func.__doc__)
                doc_params = {p.arg_name: p for p in docs.params}
                doc_returns = docs.returns
                # If there is no short description, promote the long
                # description to the summary instead of duplicating it.
                if not docs.short_description and docs.long_description:
                    doc_summary = docs.long_description or ''
                else:
                    doc_summary = docs.short_description or ''
                    doc_long_description = docs.long_description or ''
                for ex in docs.examples:
                    ex_dict: Dict[str, Any] = {
                        'description': None,
                        'code': None,
                        'output': None,
                    }
                    if ex.description:
                        ex_dict['description'] = ex.description
                    if ex.snippet:
                        # Split snippet lines into prompt/code lines
                        # (e.g. '>>>', '...') and expected output lines.
                        code, output = [], []
                        for line in ex.snippet.split('\n'):
                            line = line.rstrip()
                            if re.match(r'^(\w+>|>>>|\.\.\.)', line):
                                code.append(line)
                            else:
                                output.append(line)
                        ex_dict['code'] = '\n'.join(code) or None
                        ex_dict['output'] = '\n'.join(output) or None
                    if ex.post_snippet:
                        ex_dict['postscript'] = ex.post_snippet
                    doc_examples.append(ex_dict)

            except Exception as e:
                # A bad docstring should not prevent the function from
                # being reported; fall back to empty documentation.
                self.logger.warning(
                    'Could not parse docstring for function',
                    extra={
                        'app_name': self.name,
                        'function_name': key.decode('utf-8'),
                        'error': str(e),
                    },
                )

        # NOTE(review): endpoint keys look like bytes (key.decode above),
        # while func_name is annotated str — presumably callers pass a
        # bytes key or None; confirm a str argument ever matches.
        if not func_name or key == func_name:
            sig = info['signature']
            args = []

            # Function arguments
            for i, a in enumerate(sig.get('args', [])):
                name = a['name']
                dtype = a['dtype']
                # A '?' suffix in the dtype marks a nullable column.
                nullable = '?' in dtype
                args.append(
                    dict(
                        name=name,
                        dtype=dtype.replace('?', ''),
                        nullable=nullable,
                        description=(doc_params[name].description or '')
                        if name in doc_params else '',
                    ),
                )
                if a.get('default', no_default) is not no_default:
                    args[-1]['default'] = a['default']

            # Return values
            ret = sig.get('returns', [])
            returns = []

            for a in ret:
                dtype = a['dtype']
                nullable = '?' in dtype
                returns.append(
                    dict(
                        dtype=dtype.replace('?', ''),
                        nullable=nullable,
                        description=doc_returns.description
                        if doc_returns else '',
                    ),
                )
                if a.get('name', None):
                    returns[-1]['name'] = a['name']
                if a.get('default', no_default) is not no_default:
                    returns[-1]['default'] = a['default']

            sql = sql_map.get(sig['name'], '')
            functions[sig['name']] = dict(
                args=args,
                returns=returns,
                function_type=info['function_type'],
                sql_statement=sql,
                summary=doc_summary,
                long_description=doc_long_description,
                examples=doc_examples,
            )

    return functions
def get_create_functions(
    self,
    replace: bool = False,
) -> List[str]:
    """
    Generate CREATE FUNCTION code for all functions.

    Parameters
    ----------
    replace : bool, optional
        Should existing functions be replaced?

    Returns
    -------
    List[str]

    """
    if not self.endpoints:
        return []

    statements: List[str] = []

    # Remote mode may need a CREATE LINK statement ahead of the functions.
    link_name = ''
    if self.app_mode.lower() == 'remote':
        link_name, link_sql = self._create_link(
            self.link_config, self.link_credentials,
        )
        if link_name and link_sql:
            statements.append(link_sql)

    statements.extend(
        signature_to_sql(
            endpoint_info['signature'],
            url=self.url,
            data_format=self.data_format,
            app_mode=self.app_mode,
            replace=replace,
            link=link_name or None,
            database=self.function_database or None,
        )
        for _, endpoint_info in self.endpoints.values()
    )

    return statements
def register_functions(
    self,
    *connection_args: Any,
    replace: bool = False,
    **connection_kwargs: Any,
) -> None:
    """
    Register functions with the database.

    Parameters
    ----------
    *connection_args : Any
        Database connection parameters
    replace : bool, optional
        Should existing functions be replaced?
    **connection_kwargs : Any
        Database connection parameters

    """
    with connection.connect(*connection_args, **connection_kwargs) as conn:
        with conn.cursor() as cur:
            # When replacing, first remove anything this app previously
            # registered (functions and auto-generated links).
            if replace:
                existing_funcs, existing_links = self._locate_app_functions(cur)
                for existing_func in existing_funcs:
                    cur.execute(f'DROP FUNCTION IF EXISTS {existing_func}')
                for existing_link in existing_links:
                    cur.execute(f'DROP LINK {existing_link}')

            # Create each function (and link, if any) from scratch.
            for statement in self.get_create_functions(replace=replace):
                cur.execute(statement)
def drop_functions(
    self,
    *connection_args: Any,
    **connection_kwargs: Any,
) -> None:
    """
    Drop registered functions from database.

    Parameters
    ----------
    *connection_args : Any
        Database connection parameters
    **connection_kwargs : Any
        Database connection parameters

    """
    with connection.connect(*connection_args, **connection_kwargs) as conn:
        with conn.cursor() as cur:
            # Remove every function and auto-generated link that points
            # at this application's URL.
            app_funcs, app_links = self._locate_app_functions(cur)
            for app_func in app_funcs:
                cur.execute(f'DROP FUNCTION IF EXISTS {app_func}')
            for app_link in app_links:
                cur.execute(f'DROP LINK {app_link}')
async def call(
1619
self,
1620
name: str,
1621
data_in: io.BytesIO,
1622
data_out: io.BytesIO,
1623
data_format: Optional[str] = None,
1624
data_version: Optional[str] = None,
1625
) -> None:
1626
"""
1627
Call a function in the application.
1628
1629
Parameters
1630
----------
1631
name : str
1632
Name of the function to call
1633
data_in : io.BytesIO
1634
The input data rows
1635
data_out : io.BytesIO
1636
The output data rows
1637
data_format : str, optional
1638
The format of the input and output data
1639
data_version : str, optional
1640
The version of the data format
1641
1642
"""
1643
data_format = data_format or self.data_format
1644
data_version = data_version or self.data_version
1645
1646
async def receive() -> Dict[str, Any]:
1647
return dict(body=data_in.read())
1648
1649
async def send(content: Dict[str, Any]) -> None:
1650
status = content.get('status', 200)
1651
if status != 200:
1652
raise KeyError(f'error occurred when calling `{name}`: {status}')
1653
data_out.write(content.get('body', b''))
1654
1655
accepts = dict(
1656
json=b'application/json',
1657
rowdat_1=b'application/octet-stream',
1658
arrow=b'application/vnd.apache.arrow.file',
1659
)
1660
1661
# Mock an ASGI scope
1662
scope = dict(
1663
type='http',
1664
path='invoke',
1665
method='POST',
1666
headers={
1667
b'content-type': accepts[data_format.lower()],
1668
b'accepts': accepts[data_format.lower()],
1669
b's2-ef-name': name.encode('utf-8'),
1670
b's2-ef-version': data_version.encode('utf-8'),
1671
b's2-ef-ignore-cancel': b'true',
1672
},
1673
)
1674
1675
await self(scope, receive, send)
1676
1677
def to_environment(
    self,
    name: str,
    destination: str = '.',
    version: Optional[str] = None,
    dependencies: Optional[List[str]] = None,
    authors: Optional[List[Dict[str, str]]] = None,
    maintainers: Optional[List[Dict[str, str]]] = None,
    description: Optional[str] = None,
    container_service: Optional[Dict[str, Any]] = None,
    external_function: Optional[Dict[str, Any]] = None,
    external_function_remote: Optional[Dict[str, Any]] = None,
    external_function_collocated: Optional[Dict[str, Any]] = None,
    overwrite: bool = False,
) -> None:
    """
    Convert application to an environment file.

    Parameters
    ----------
    name : str
        Name of the output environment
    destination : str, optional
        Location of the output file; either a local directory or a
        ``stage://`` URL
    version : str, optional
        Version of the package
    dependencies : List[str], optional
        List of dependency specifications like in a requirements.txt file
    authors : List[Dict[str, Any]], optional
        Dictionaries of author information. Keys may include: email, name
    maintainers : List[Dict[str, Any]], optional
        Dictionaries of maintainer information. Keys may include: email, name
    description : str, optional
        Description of package
    container_service : Dict[str, Any], optional
        Container service specifications
    external_function : Dict[str, Any], optional
        External function specifications (applies to both remote and collocated)
    external_function_remote : Dict[str, Any], optional
        Remote external function specifications
    external_function_collocated : Dict[str, Any], optional
        Collocated external function specifications
    overwrite : bool, optional
        Should destination file be overwritten if it exists?

    """
    if not has_cloudpickle:
        # BUGFIX: corrected typo in error message ('cloudpicke').
        raise RuntimeError('the cloudpickle package is required for this operation')

    # Write to temporary location if a remote destination is specified
    tmpdir = None
    if destination.startswith('stage://'):
        tmpdir = tempfile.TemporaryDirectory()
        local_path = os.path.join(tmpdir.name, f'{name}.env')
    else:
        local_path = os.path.join(destination, f'{name}.env')
        if not overwrite and os.path.exists(local_path):
            raise OSError(f'path already exists: {local_path}')

    with zipfile.ZipFile(local_path, mode='w') as z:
        # Write metadata
        z.writestr(
            'pyproject.toml', utils.to_toml({
                'project': dict(
                    name=name,
                    version=version,
                    dependencies=dependencies,
                    # Pin to the exact interpreter version used to pickle
                    # the functions so they unpickle cleanly.
                    requires_python='== ' +
                    '.'.join(str(x) for x in sys.version_info[:3]),
                    authors=authors,
                    maintainers=maintainers,
                    description=description,
                ),
                'tool.container-service': container_service,
                'tool.external-function': external_function,
                'tool.external-function.remote': external_function_remote,
                'tool.external-function.collocated': external_function_collocated,
            }),
        )

        # Write Python package: a module that unpickles the exported
        # functions into its own namespace at import time.
        z.writestr(
            f'{name}/__init__.py',
            textwrap.dedent(f'''
                import pickle as _pkl
                globals().update(
                    _pkl.loads({cloudpickle.dumps(self.external_functions)}),
                )
                __all__ = {list(self.external_functions.keys())}''').strip(),
        )

    # Upload to Stage as needed
    if destination.startswith('stage://'):
        url = urllib.parse.urlparse(re.sub(r'/+$', r'', destination) + '/')
        if not url.path or url.path == '/':
            raise ValueError(f'no stage path was specified: {destination}')

        mgr = manage_workspaces()
        if url.hostname:
            wsg = mgr.get_workspace_group(url.hostname)
        elif os.environ.get('SINGLESTOREDB_WORKSPACE_GROUP'):
            wsg = mgr.get_workspace_group(
                os.environ['SINGLESTOREDB_WORKSPACE_GROUP'],
            )
        else:
            raise ValueError(f'no workspace group specified: {destination}')

        # Make intermediate directories
        if url.path.count('/') > 1:
            wsg.stage.mkdirs(os.path.dirname(url.path))

        wsg.stage.upload_file(
            local_path, url.path + f'{name}.env',
            overwrite=overwrite,
        )
        os.remove(local_path)
1795
def main(argv: Optional[List[str]] = None) -> None:
    """
    Main program for HTTP-based Python UDFs

    Parses command-line options, optionally loads an environment (.env
    zip) file, optionally registers the functions with a database, then
    serves the application with uvicorn.

    Parameters
    ----------
    argv : List[str], optional
        List of command-line parameters

    """
    try:
        import uvicorn
    except ImportError:
        raise ImportError('the uvicorn package is required to run this command')

    # Should we run in embedded mode (typically for Jupyter)
    try:
        asyncio.get_running_loop()
        use_async = True
    except RuntimeError:
        use_async = False

    # Temporary directory for Stage environment files
    tmpdir = None

    # Depending on whether we find an environment file specified, we
    # may have to process the command line twice: the first pass may
    # discover an environment file whose pyproject.toml supplies new
    # defaults, in which case arguments are re-parsed once more.
    functions = []
    defaults: Dict[str, Any] = {}
    for i in range(2):

        parser = argparse.ArgumentParser(
            prog='python -m singlestoredb.functions.ext.asgi',
            description='Run an HTTP-based Python UDF server',
        )
        parser.add_argument(
            '--url', metavar='url',
            default=defaults.get(
                'url',
                get_option('external_function.url'),
            ),
            help='URL of the UDF server endpoint',
        )
        parser.add_argument(
            '--host', metavar='host',
            default=defaults.get(
                'host',
                get_option('external_function.host'),
            ),
            help='bind socket to this host',
        )
        parser.add_argument(
            '--port', metavar='port', type=int,
            default=defaults.get(
                'port',
                get_option('external_function.port'),
            ),
            help='bind socket to this port',
        )
        parser.add_argument(
            '--db', metavar='conn-str',
            default=defaults.get(
                'connection',
                get_option('external_function.connection'),
            ),
            help='connection string to use for registering functions',
        )
        parser.add_argument(
            '--replace-existing', action='store_true',
            help='should existing functions of the same name '
            'in the database be replaced?',
        )
        parser.add_argument(
            '--data-format', metavar='format',
            default=defaults.get(
                'data_format',
                get_option('external_function.data_format'),
            ),
            choices=['rowdat_1', 'json'],
            help='format of the data rows',
        )
        parser.add_argument(
            '--data-version', metavar='version',
            default=defaults.get(
                'data_version',
                get_option('external_function.data_version'),
            ),
            help='version of the data row format',
        )
        parser.add_argument(
            '--link-name', metavar='name',
            default=defaults.get(
                'link_name',
                get_option('external_function.link_name'),
            ) or '',
            help='name of the link to use for connections',
        )
        parser.add_argument(
            '--link-config', metavar='json',
            default=str(
                defaults.get(
                    'link_config',
                    get_option('external_function.link_config'),
                ) or '{}',
            ),
            help='link config in JSON format',
        )
        parser.add_argument(
            '--link-credentials', metavar='json',
            default=str(
                defaults.get(
                    'link_credentials',
                    get_option('external_function.link_credentials'),
                ) or '{}',
            ),
            help='link credentials in JSON format',
        )
        parser.add_argument(
            '--log-level', metavar='[info|debug|warning|error]',
            default=defaults.get(
                'log_level',
                get_option('external_function.log_level'),
            ),
            help='logging level',
        )
        parser.add_argument(
            '--log-file', metavar='filepath',
            default=defaults.get(
                'log_file',
                get_option('external_function.log_file'),
            ),
            help='File path to write logs to instead of console',
        )
        parser.add_argument(
            '--disable-metrics', action='store_true',
            default=defaults.get(
                'disable_metrics',
                get_option('external_function.disable_metrics'),
            ),
            help='Disable logging of function call metrics',
        )
        parser.add_argument(
            '--name-prefix', metavar='name_prefix',
            default=defaults.get(
                'name_prefix',
                get_option('external_function.name_prefix'),
            ),
            help='Prefix to add to function names',
        )
        parser.add_argument(
            '--name-suffix', metavar='name_suffix',
            default=defaults.get(
                'name_suffix',
                get_option('external_function.name_suffix'),
            ),
            help='Suffix to add to function names',
        )
        parser.add_argument(
            '--function-database', metavar='function_database',
            default=defaults.get(
                'function_database',
                get_option('external_function.function_database'),
            ),
            help='Database to use for the function definition',
        )
        parser.add_argument(
            '--app-name', metavar='app_name',
            default=defaults.get(
                'app_name',
                get_option('external_function.app_name'),
            ),
            help='Name for the application instance',
        )
        parser.add_argument(
            'functions', metavar='module.or.func.path', nargs='*',
            help='functions or modules to export in UDF server',
        )

        args = parser.parse_args(argv)

        # Second pass exists only to re-apply defaults; stop here.
        if i > 0:
            break

        # Download Stage files as needed
        # NOTE: the loop variable below shadows the outer `i`; this is
        # harmless because the outer `for` rebinds `i` each iteration.
        for i, f in enumerate(args.functions):
            if f.startswith('stage://'):
                url = urllib.parse.urlparse(f)
                if not url.path or url.path == '/':
                    raise ValueError(f'no stage path was specified: {f}')
                if url.path.endswith('/'):
                    raise ValueError(f'an environment file must be specified: {f}')

                mgr = manage_workspaces()
                if url.hostname:
                    wsg = mgr.get_workspace_group(url.hostname)
                elif os.environ.get('SINGLESTOREDB_WORKSPACE_GROUP'):
                    wsg = mgr.get_workspace_group(
                        os.environ['SINGLESTOREDB_WORKSPACE_GROUP'],
                    )
                else:
                    raise ValueError(f'no workspace group specified: {f}')

                if tmpdir is None:
                    tmpdir = tempfile.TemporaryDirectory()

                local_path = os.path.join(tmpdir.name, url.path.split('/')[-1])
                wsg.stage.download_file(url.path, local_path)
                args.functions[i] = local_path

            elif f.startswith('http://') or f.startswith('https://'):
                if tmpdir is None:
                    tmpdir = tempfile.TemporaryDirectory()

                local_path = os.path.join(tmpdir.name, f.split('/')[-1])
                urllib.request.urlretrieve(f, local_path)
                args.functions[i] = local_path

        # See if any of the args are zip files (assume they are environment files)
        modules = [(x, zipfile.is_zipfile(x)) for x in args.functions]
        envs = [x[0] for x in modules if x[1]]
        others = [x[0] for x in modules if not x[1]]

        if envs and len(envs) > 1:
            raise RuntimeError('only one environment file may be specified')

        if envs and others:
            raise RuntimeError('environment files and other modules can not be mixed.')

        # See if an environment file was specified. If so, use those settings
        # as the defaults and reprocess command line.
        if envs:
            # Add pyproject.toml variables and redo command-line processing
            defaults = utils.read_config(
                envs[0],
                ['tool.external-function', 'tool.external-function.remote'],
            )

            # Load zip file as a module
            modname = os.path.splitext(os.path.basename(envs[0]))[0]
            zi = zipimport.zipimporter(envs[0])
            mod = zi.load_module(modname)
            if mod is None:
                raise RuntimeError(f'environment file could not be imported: {envs[0]}')
            functions = [mod]

        # New defaults discovered: go around once more to re-parse args.
        if defaults:
            continue

    args.functions = functions or args.functions or None
    args.replace_existing = args.replace_existing \
        or defaults.get('replace_existing') \
        or get_option('external_function.replace_existing')

    # Substitute in host / port if specified
    if args.host != defaults.get('host') or args.port != defaults.get('port'):
        u = urllib.parse.urlparse(args.url)
        args.url = u._replace(netloc=f'{args.host}:{args.port}').geturl()

    # Create application from functions / module
    app = Application(
        functions=args.functions,
        url=args.url,
        data_format=args.data_format,
        data_version=args.data_version,
        link_name=args.link_name or None,
        link_config=json.loads(args.link_config) or None,
        link_credentials=json.loads(args.link_credentials) or None,
        app_mode='remote',
        name_prefix=args.name_prefix,
        name_suffix=args.name_suffix,
        function_database=args.function_database or None,
        log_file=args.log_file,
        log_level=args.log_level,
        disable_metrics=args.disable_metrics,
        app_name=args.app_name,
    )

    funcs = app.get_create_functions(replace=args.replace_existing)
    if not funcs:
        raise RuntimeError('no functions specified')

    # Log the CREATE FUNCTION statements that will be used.
    for f in funcs:
        app.logger.info(f)

    try:
        if args.db:
            app.logger.info('Registering functions with database')
            app.register_functions(
                args.db,
                replace=args.replace_existing,
            )

        app_args = {
            k: v for k, v in dict(
                host=args.host or None,
                port=args.port or None,
                log_level=args.log_level,
                lifespan='off',
            ).items() if v is not None
        }

        # Configure uvicorn logging to use JSON format matching Application's format
        app_args['log_config'] = app.get_uvicorn_log_config()

        if use_async:
            # Inside an existing event loop (e.g. Jupyter): run as a task.
            asyncio.create_task(_run_uvicorn(uvicorn, app, app_args, db=args.db))
        else:
            uvicorn.run(app, **app_args)

    finally:
        # In async mode _run_uvicorn handles cleanup after shutdown.
        if not use_async and args.db:
            app.logger.info('Dropping functions from database')
            app.drop_functions(args.db)
2110
async def _run_uvicorn(
2111
uvicorn: Any,
2112
app: Any,
2113
app_args: Any,
2114
db: Optional[str] = None,
2115
) -> None:
2116
"""Run uvicorn server and clean up functions after shutdown."""
2117
await uvicorn.Server(uvicorn.Config(app, **app_args)).serve()
2118
if db:
2119
app.logger.info('Dropping functions from database')
2120
app.drop_functions(db)
2121
2122
2123
# Public alias: the ASGI application factory.
create_app = Application


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Normal interactive shutdown; exit quietly.
        pass
    except RuntimeError as exc:
        # Configuration / startup failures are reported and exit non-zero.
        logger.error(str(exc))
        sys.exit(1)