GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/fusion/handler.py
#!/usr/bin/env python3
import abc
import functools
import os
import re
import sys
import textwrap
import warnings
from collections.abc import Iterable
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple

from parsimonious import Grammar
from parsimonious import ParseError
from parsimonious.nodes import Node
from parsimonious.nodes import NodeVisitor

from . import result
from ..connection import Connection
from ..warnings import PreviewFeatureWarning

CORE_GRAMMAR = r'''
    ws = ~r"(\s+|(\s*/\*.*\*/\s*)+)"
    qs = ~r"\"([^\"]*)\"|'([^\']*)'|([A-Za-z0-9_\-\.]+)|`([^\`]+)`" ws*
    number = ~r"[-+]?(\d*\.)?\d+(e[-+]?\d+)?"i ws*
    integer = ~r"-?\d+" ws*
    comma = ws* "," ws*
    eq = ws* "=" ws*
    open_paren = ws* "(" ws*
    close_paren = ws* ")" ws*
    open_repeats = ws* ~r"[\(\[\{]" ws*
    close_repeats = ws* ~r"[\)\]\}]" ws*
    statement = ~r"[\s\S]*" ws*
    table = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*
    column = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*
    link_name = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*
    catalog_name = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*

    json = ws* json_object ws*
    json_object = ~r"{\s*" json_members? ~r"\s*}"
    json_members = json_mapping (~r"\s*,\s*" json_mapping)*
    json_mapping = json_string ~r"\s*:\s*" json_value
    json_array = ~r"\[\s*" json_items? ~r"\s*\]"
    json_items = json_value (~r"\s*,\s*" json_value)*
    json_value = json_object / json_array / json_string / json_true_val / json_false_val / json_null_val / json_number
    json_true_val = "true"
    json_false_val = "false"
    json_null_val = "null"
    json_string = ~r"\"[ !#-\[\]-\U0010ffff]*(?:\\(?:[\"\\/bfnrt]|u[0-9A-Fa-f]{4})[ !#-\[\]-\U0010ffff]*)*\""
    json_number = ~r"-?(0|[1-9][0-9]*)(\.\d*)?([eE][-+]?\d+)?"
''' # noqa: E501

BUILTINS = {
    '<order-by>': r'''
        order_by = ORDER BY order_by_key_,...
        order_by_key_ = '<key>' [ ASC | DESC ]
    ''',
    '<like>': r'''
        like = LIKE '<pattern>'
    ''',
    '<extended>': r'''
        extended = EXTENDED
    ''',
    '<limit>': r'''
        limit = LIMIT <integer>
    ''',
    '<integer>': '',
    '<number>': '',
    '<json>': '',
    '<table>': '',
    '<column>': '',
    '<catalog-name>': '',
    '<link-name>': '',
    '<file-type>': r'''
        file_type = { FILE | FOLDER }
    ''',
    '<statement>': '',
}

BUILTIN_DEFAULTS = { # type: ignore
    'order_by': {'by': []},
    'like': None,
    'extended': False,
    'limit': None,
    'json': {},
}

# Match a single-character JSON escape or a four-hex-digit \uXXXX escape,
# mirroring the `json_string` rule in CORE_GRAMMAR.
_json_unesc_re = re.compile(r'\\(["/\\bfnrt]|u[0-9A-Fa-f]{4})')
_json_unesc_map = {
    '"': '"',
    '/': '/',
    '\\': '\\',
    'b': '\b',
    'f': '\f',
    'n': '\n',
    'r': '\r',
    't': '\t',
}


def _json_unescape(m: Any) -> str:
    c = m.group(1)
    if c[0] == 'u':
        return chr(int(c[1:], 16))
    c2 = _json_unesc_map.get(c)
    if not c2:
        raise ValueError(f'invalid escape sequence: {m.group(0)}')
    return c2


def json_unescape(s: str) -> str:
    """Remove the surrounding quotes and resolve escapes in a JSON string."""
    return _json_unesc_re.sub(_json_unescape, s[1:-1])
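# Illustrative examples (not part of the original module): ``json_unescape``
# strips the surrounding double quotes and resolves JSON backslash escapes.
#
#     json_unescape('"a\\tb"')      # -> 'a\tb' (a real tab character)
#     json_unescape('"\\u0041BC"')  # -> 'ABC'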


def get_keywords(grammar: str) -> Tuple[str, ...]:
    """Return all all-caps words from the beginning of the line."""
    m = re.match(r'^\s*((?:[@A-Z0-9_]+)(\s+|$|;))+', grammar)
    if not m:
        return tuple()
    return tuple(re.split(r'\s+', m.group(0).replace(';', '').strip()))


def is_bool(grammar: str) -> bool:
    """Determine if the rule is a boolean."""
    return bool(re.match(r'^[@A-Z0-9_\s*]+$', grammar))
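# Illustrative examples (not part of the original module), using a
# hypothetical primary rule and two of the builtin rule bodies:
#
#     get_keywords('SHOW STUFF [ <like> ];')  # -> ('SHOW', 'STUFF')
#     is_bool('EXTENDED')                     # -> True
#     is_bool("LIKE '<pattern>'")             # -> False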


def process_optional(m: Any) -> str:
    """Create options or groups of options."""
    sql = m.group(1).strip()
    if '|' in sql:
        return f'( {sql} )*'
    return f'( {sql} )?'


def process_alternates(m: Any) -> str:
    """Make alternates mandatory groups."""
    sql = m.group(1).strip()
    if '|' in sql:
        return f'( {sql} )'
    raise ValueError(f'alternates must contain "|": {sql}')


def process_repeats(m: Any) -> str:
    """Add repeated patterns."""
    sql = m.group(1).strip()
    return f'open_repeats? {sql} ws* ( comma {sql} ws* )* close_repeats?'
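# Illustrative examples (not part of the original module): these helpers are
# used as ``re.sub`` callbacks on rule text, e.g.
#
#     re.sub(r'\[([^\]]+)\]', process_optional, 'LIMIT [ <integer> ]')
#     # -> 'LIMIT ( <integer> )?'
#     re.sub(r'\{([^\}]+)\}', process_alternates, '{ FILE | FOLDER }')
#     # -> '( FILE | FOLDER )'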


def lower_and_regex(m: Any) -> str:
    """Lowercase and convert literal to regex."""
    start = m.group(1) or ''
    sql = m.group(2)
    return f'~"{start}{sql.lower()}"i'


def split_unions(grammar: str) -> str:
    """
    Convert grammar in the form '[ x | y ]' to '[ x ] [ y ]'.

    Alternates inside '{ ... }' groups are left untouched.

    Parameters
    ----------
    grammar : str
        SQL grammar

    Returns
    -------
    str

    """
    in_alternate = False
    out = []
    for c in grammar:
        if c == '{':
            in_alternate = True
            out.append(c)
        elif c == '}':
            in_alternate = False
            out.append(c)
        elif not in_alternate and c == '|':
            out.append(']')
            out.append(' ')
            out.append('[')
        else:
            out.append(c)
    return ''.join(out)
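# Illustrative examples (not part of the original module): top-level
# alternation is split into adjacent groups, while '|' inside '{ ... }'
# is preserved:
#
#     split_unions('[ <like> | <extended> ]')  # -> '[ <like> ] [ <extended> ]'
#     split_unions('{ FILE | FOLDER }')        # unchanged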


def expand_rules(rules: Dict[str, str], m: Any) -> str:
    """
    Return expanded grammar syntax for given rule.

    Parameters
    ----------
    rules : Dict[str, str]
        Dictionary of rules in grammar
    m : Any
        Regex match whose first group is the rule name to expand

    Returns
    -------
    str

    """
    txt = m.group(1)
    if txt in rules:
        return f' {rules[txt]} '
    return f' <{txt}> '


def build_cmd(grammar: str) -> str:
    """Pre-process grammar to construct top-level command."""
    if ';' not in grammar:
        raise ValueError('a semi-colon must exist at the end of the primary rule')

    # Pre-space
    m = re.match(r'^\s*', grammar)
    space = m.group(0) if m else ''

    # Split on ';' on a line by itself
    begin, end = grammar.split(';', 1)

    # Get statement keywords
    keywords = get_keywords(begin)
    cmd = '_'.join(x.lower() for x in keywords) + '_cmd'

    # Collapse multi-line to one
    begin = re.sub(r'\s+', r' ', begin)

    return f'{space}{cmd} ={begin}\n{end}'
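# Illustrative sketch (not part of the original module), using a hypothetical
# grammar: the primary rule up to the ';' is collapsed onto one line and named
# after its leading keywords, so
#
#     SHOW STUFF [ <like> ];
#     like = LIKE '<pattern>'
#
# becomes roughly
#
#     show_stuff_cmd = SHOW STUFF [ <like> ]
#     like = LIKE '<pattern>'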


def build_syntax(grammar: str) -> str:
    """Construct full syntax."""
    if ';' not in grammar:
        raise ValueError('a semi-colon must exist at the end of the primary rule')

    # Split on ';' on a line by itself
    cmd, end = grammar.split(';', 1)

    name = ''
    rules: Dict[str, Any] = {}
    for line in end.split('\n'):
        line = line.strip()
        if line.startswith('&'):
            rules[name] += '\n' + line
            continue
        if not line:
            continue
        name, value = line.split('=', 1)
        name = name.strip()
        value = value.strip()
        rules[name] = value

    while re.search(r' [a-z0-9_]+\b', cmd):
        cmd = re.sub(r' ([a-z0-9_]+)\b', functools.partial(expand_rules, rules), cmd)

    def add_indent(m: Any) -> str:
        return ' ' + (len(m.group(1)) * ' ')

    # Indent line-continuations
    cmd = re.sub(r'^(\&+)\s*', add_indent, cmd, flags=re.M)

    cmd = textwrap.dedent(cmd).rstrip() + ';'
    cmd = re.sub(r'(\S) +', r'\1 ', cmd)
    cmd = re.sub(r'<comma>', ',', cmd)
    cmd = re.sub(r'\s+,\s*\.\.\.', ',...', cmd)

    return cmd


def _format_examples(ex: str) -> str:
    """Convert examples into sections."""
    return re.sub(r'(^Example\s+\d+.*$)', r'### \1', ex, flags=re.M)


def _format_arguments(arg: str) -> str:
    """Format arguments as subsections."""
    out = []
    for line in arg.split('\n'):
        if line.startswith('<'):
            out.append(f'### {line.replace("<", "&lt;").replace(">", "&gt;")}')
            out.append('')
        else:
            out.append(line.strip())
    return '\n'.join(out)


def _to_markdown(txt: str) -> str:
    """Convert formatting to markdown."""
    txt = re.sub(r'`([^`]+)\s+\<([^\>]+)>`_', r'[\1](\2)', txt)
    txt = txt.replace('``', '`')

    # Format code blocks
    lines = re.split(r'\n', txt)
    out = []
    while lines:
        line = lines.pop(0)
        if line.endswith('::'):
            out.append(line[:-2] + '.')
            code = []
            while lines and (not lines[0].strip() or lines[0].startswith(' ')):
                code.append(lines.pop(0).rstrip())
            code_str = re.sub(r'^\s*\n', r'', '\n'.join(code).rstrip())
            out.extend([f'```sql\n{code_str}\n```\n'])
        else:
            out.append(line)

    return '\n'.join(out)
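# Illustrative example (not part of the original module): reST-style '::'
# blocks are rewritten as fenced markdown code blocks, e.g.
#
#     _to_markdown('Run the query::\n\n    SHOW TABLES')
#     # -> 'Run the query.\n```sql\n    SHOW TABLES\n```\n'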


def build_help(syntax: str, grammar: str) -> str:
    """Build full help text."""
    cmd = re.match(r'([A-Z0-9_ ]+)', syntax.strip())
    if not cmd:
        raise ValueError(f'no command found: {syntax}')

    out = [f'# {cmd.group(1)}\n\n']

    sections: Dict[str, str] = {}
    grammar = textwrap.dedent(grammar.rstrip())
    desc_re = re.compile(r'^([A-Z][\S ]+)\s*^\-\-\-\-+\s*$', flags=re.M)
    if desc_re.search(grammar):
        _, *txt = desc_re.split(grammar)
        txt = [x.strip() for x in txt]
        sections = {}
        while txt:
            key = txt.pop(0)
            value = txt.pop(0)
            sections[key.lower()] = _to_markdown(value).strip()

    if 'description' in sections:
        out.extend([sections['description'], '\n\n'])

    out.append(f'## Syntax\n\n```sql{syntax}\n```\n\n')

    if 'arguments' in sections:
        out.extend([
            '## Arguments\n\n',
            _format_arguments(sections['arguments']),
            '\n\n',
        ])
    if 'argument' in sections:
        out.extend([
            '## Argument\n\n',
            _format_arguments(sections['argument']),
            '\n\n',
        ])

    if 'remarks' in sections:
        out.extend(['## Remarks\n\n', sections['remarks'], '\n\n'])

    if 'examples' in sections:
        out.extend(['## Examples\n\n', _format_examples(sections['examples']), '\n\n'])
    elif 'example' in sections:
        out.extend(['## Example\n\n', _format_examples(sections['example']), '\n\n'])

    if 'see also' in sections:
        out.extend(['## See Also\n\n', sections['see also'], '\n\n'])

    return ''.join(out).rstrip() + '\n'


def strip_comments(grammar: str) -> str:
    """Strip comments from grammar."""
    desc_re = re.compile(r'(^\s*Description\s*^\s*-----------\s*$)', flags=re.M)
    grammar = desc_re.split(grammar, maxsplit=1)[0]
    return re.sub(r'^\s*#.*$', r'', grammar, flags=re.M)


def get_rule_info(grammar: str) -> Dict[str, Any]:
    """Compute metadata about a rule used in coalescing parsed output."""
    return dict(
        n_keywords=len(get_keywords(grammar)),
        repeats=',...' in grammar,
        default=False if is_bool(grammar) else [] if ',...' in grammar else None,
    )
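# Illustrative examples (not part of the original module), using two of the
# builtin rule bodies defined above:
#
#     get_rule_info('LIMIT <integer>')  # -> {'n_keywords': 1, 'repeats': False, 'default': None}
#     get_rule_info('EXTENDED')         # -> {'n_keywords': 1, 'repeats': False, 'default': False}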


def inject_builtins(grammar: str) -> str:
    """Inject complex builtin rules."""
    for k, v in BUILTINS.items():
        if re.search(k, grammar):
            grammar = re.sub(
                k,
                k.replace('<', '').replace('>', '').replace('-', '_'),
                grammar,
            )
            grammar += v
    return grammar
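# Illustrative sketch (not part of the original module): for a hypothetical
# grammar 'SHOW STUFF [ <like> ];', the '<like>' placeholder is rewritten to
# the rule name 'like' and the builtin definition from BUILTINS is appended,
# yielding roughly
#
#     SHOW STUFF [ like ];
#     like = LIKE '<pattern>'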


def process_grammar(
    grammar: str,
) -> Tuple[Grammar, Tuple[str, ...], Dict[str, Any], str, str]:
    """
    Convert SQL grammar to a Parsimonious grammar.

    Parameters
    ----------
    grammar : str
        The SQL grammar

    Returns
    -------
    (Grammar, Tuple[str, ...], Dict[str, Any], str, str) - Grammar is the
        parsimonious grammar object. The tuple is a series of the keywords
        that start the command. The dictionary is a set of metadata about
        each rule. The final two strings are a human-readable version of the
        grammar for documentation and errors, and the full help text for
        the command.

    """
    out = []
    rules = {}
    rule_info = {}

    full_grammar = grammar
    grammar = strip_comments(grammar)
    grammar = inject_builtins(grammar)
    command_key = get_keywords(grammar)
    syntax_txt = build_syntax(grammar)
    help_txt = build_help(syntax_txt, full_grammar)
    grammar = build_cmd(grammar)

    # Remove line-continuations
    grammar = re.sub(r'\n\s*&+', r'', grammar)

    # Make sure grouping characters all have whitespace around them
    grammar = re.sub(r' *(\[|\{|\||\}|\]) *', r' \1 ', grammar)

    grammar = re.sub(r'\(', r' open_paren ', grammar)
    grammar = re.sub(r'\)', r' close_paren ', grammar)

    for line in grammar.split('\n'):
        if not line.strip():
            continue

        op, sql = line.split('=', 1)
        op = op.strip()
        sql = sql.strip()
        sql = split_unions(sql)

        rules[op] = sql
        rule_info[op] = get_rule_info(sql)

        # Convert consecutive optionals to a union
        sql = re.sub(r'\]\s+\[', r' | ', sql)

        # Lower-case keywords and make them case-insensitive
        sql = re.sub(r'(\b|@+)([A-Z0-9_]+)\b', lower_and_regex, sql)

        # Convert literal strings to 'qs'
        sql = re.sub(r"'[^']+'", r'qs', sql)

        # Convert special characters to literal tokens
        sql = re.sub(r'([=]) ', r' eq ', sql)

        # Convert [...] groups to (...)*
        sql = re.sub(r'\[([^\]]+)\]', process_optional, sql)

        # Convert {...} groups to (...)
        sql = re.sub(r'\{([^\}]+)\}', process_alternates, sql)

        # Convert <...> to ... (<...> is the form for core types)
        sql = re.sub(r'<([a-z0-9_]+)>', r'\1', sql)

        # Insert ws between every token to allow for whitespace and comments
        sql = ' ws '.join(re.split(r'\s+', sql)) + ' ws'

        # Remove ws in optional groupings
        sql = sql.replace('( ws', '(')
        sql = sql.replace('| ws', '|')

        # Convert | to /
        sql = sql.replace('|', '/')

        # Remove ws after operation names, all operations contain ws at the end
        sql = re.sub(r'(\s+[a-z0-9_]+)\s+ws\b', r'\1', sql)

        # Convert foo,... to foo ("," foo)*
        sql = re.sub(r'(\S+),...', process_repeats, sql)

        # Remove ws before / and )
        sql = re.sub(r'(\s*\S+\s+)ws\s+/', r'\1/', sql)
        sql = re.sub(r'(\s*\S+\s+)ws\s+\)', r'\1)', sql)

        # Make sure every operation ends with ws
        sql = re.sub(r'\s+ws\s+ws$', r' ws', sql + ' ws')
        sql = re.sub(r'(\s+ws)*\s+ws\*$', r' ws*', sql)
        sql = re.sub(r'\s+ws$', r' ws*', sql)
        sql = re.sub(r'\s+ws\s+\(', r' ws* (', sql)
        sql = re.sub(r'\)\s+ws\s+', r') ws* ', sql)
        sql = re.sub(r'\s+ws\s+', r' ws* ', sql)
        sql = re.sub(r'\?\s+ws\+', r'? ws*', sql)

        # Remove extra ws around eq
        sql = re.sub(r'ws\+\s*eq\b', r'eq', sql)

        # Remove optional groupings when mandatory groupings are specified
        sql = re.sub(r'open_paren\s+ws\*\s+open_repeats\?', r'open_paren', sql)
        sql = re.sub(r'close_repeats\?\s+ws\*\s+close_paren', r'close_paren', sql)
        sql = re.sub(r'open_paren\s+open_repeats\?', r'open_paren', sql)
        sql = re.sub(r'close_repeats\?\s+close_paren', r'close_paren', sql)

        out.append(f'{op} = {sql}')

    for k, v in list(rules.items()):
        while re.search(r' ([a-z0-9_]+) ', v):
            v = re.sub(r' ([a-z0-9_]+) ', functools.partial(expand_rules, rules), v)
        rules[k] = v

    for k, v in list(rules.items()):
        while re.search(r' <([a-z0-9_]+)> ', v):
            v = re.sub(r' <([a-z0-9_]+)> ', r' \1 ', v)
        rules[k] = v

    cmds = ' / '.join(x for x in rules if x.endswith('_cmd'))
    cmds = f'init = ws* ( {cmds} ) ws* ";"? ws*\n'

    grammar = cmds + CORE_GRAMMAR + '\n'.join(out)

    try:
        return (
            Grammar(grammar), command_key,
            rule_info, syntax_txt, help_txt,
        )
    except ParseError:
        print(grammar, file=sys.stderr)
        raise
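# Illustrative note (not part of the original module): callers unpack the five
# return values; ``SQLHandler.compile`` below does exactly this:
#
#     grammar, command_key, rule_info, syntax, help_txt = process_grammar(sql_grammar)
#
# where ``sql_grammar`` is typically the handler class docstring.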


def flatten(items: Iterable[Any]) -> List[Any]:
    """Flatten a list of iterables."""
    out = []
    for x in items:
        if isinstance(x, (str, bytes, dict)):
            out.append(x)
        elif isinstance(x, Iterable):
            for sub_x in flatten(x):
                if sub_x is not None:
                    out.append(sub_x)
        elif x is not None:
            out.append(x)
    return out


def merge_dicts(items: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge list of dictionaries together."""
    out: Dict[str, Any] = {}
    for x in items:
        if isinstance(x, dict):
            same = list(set(x.keys()).intersection(set(out.keys())))
            if same:
                raise ValueError(f"found duplicate rules for '{same[0]}'")
            out.update(x)
    return out
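# Illustrative examples (not part of the original module):
#
#     flatten([['a', None, ['b']], {'x': 1}])        # -> ['a', 'b', {'x': 1}]
#     merge_dicts([{'like': '%x%'}, {'limit': 10}])  # -> {'like': '%x%', 'limit': 10}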


class SQLHandler(NodeVisitor):
    """Base class for all SQL handler classes."""

    #: Parsimonious grammar object
    grammar: Grammar = Grammar(CORE_GRAMMAR)

    #: SQL keywords that start the command
    command_key: Tuple[str, ...] = ()

    #: Metadata about the parse rules
    rule_info: Dict[str, Any] = {}

    #: Syntax string for use in error messages
    syntax: str = ''

    #: Full help for the command
    help: str = ''

    #: Rule validation functions
    validators: Dict[str, Callable[..., Any]] = {}

    _grammar: str = CORE_GRAMMAR
    _is_compiled: bool = False
    _enabled: bool = True
    _preview: bool = False

    def __init__(self, connection: Connection):
        self.connection = connection
        self._handled: Set[str] = set()

    @classmethod
    def compile(cls, grammar: str = '') -> None:
        """
        Compile the grammar held in the docstring.

        This method modifies attributes on the class: ``grammar``,
        ``command_key``, ``rule_info``, ``syntax``, and ``help``.

        Parameters
        ----------
        grammar : str, optional
            Grammar to use instead of docstring

        """
        if cls._is_compiled:
            return

        cls.grammar, cls.command_key, cls.rule_info, cls.syntax, cls.help = \
            process_grammar(grammar or cls.__doc__ or '')

        cls._grammar = grammar or cls.__doc__ or ''
        cls._is_compiled = True

    @classmethod
    def register(cls, overwrite: bool = False) -> None:
        """
        Register the handler class.

        Parameters
        ----------
        overwrite : bool, optional
            Overwrite an existing command with the same name?

        """
        if not cls._enabled and \
                os.environ.get('SINGLESTOREDB_FUSION_ENABLE_HIDDEN', '0').lower() not in \
                ['1', 't', 'true', 'y', 'yes']:
            return

        from . import registry
        cls.compile()
        registry.register_handler(cls, overwrite=overwrite)

    def create_result(self) -> result.FusionSQLResult:
        """
        Create a new result object.

        Returns
        -------
        FusionSQLResult
            A new result object for this handler

        """
        return result.FusionSQLResult()

    def execute(self, sql: str) -> result.FusionSQLResult:
        """
        Parse the SQL and invoke the handler method.

        Parameters
        ----------
        sql : str
            SQL statement to execute

        Returns
        -------
        FusionSQLResult

        """
        if type(self)._preview:
            warnings.warn(
                'This is a preview Fusion SQL command. '
                'The options and syntax may change in the future.',
                PreviewFeatureWarning, stacklevel=2,
            )

        type(self).compile()
        self._handled = set()
        try:
            params = self.visit(type(self).grammar.parse(sql))
            for k, v in params.items():
                params[k] = self.validate_rule(k, v)

            res = self.run(params)

            self._handled = set()

            if res is not None:
                res.format_results(self.connection)
                return res

            res = result.FusionSQLResult()
            res.set_rows([])
            res.format_results(self.connection)
            return res

        except ParseError as exc:
            s = str(exc)
            msg = ''
            m = re.search(r'(The non-matching portion.*$)', s)
            if m:
                msg = ' ' + m.group(1)
            m = re.search(r"(Rule) '.+?'( didn't match at.*$)", s)
            if m:
                msg = ' ' + m.group(1) + m.group(2)
            raise ValueError(
                f'Could not parse statement.{msg} '
                'Expecting:\n' + textwrap.indent(type(self).syntax, ' '),
            )

    @abc.abstractmethod
    def run(self, params: Dict[str, Any]) -> Optional[result.FusionSQLResult]:
        """
        Run the handler command.

        Parameters
        ----------
        params : Dict[str, Any]
            Values parsed from the SQL query. Each rule in the grammar
            results in a key/value pair in the ``params`` dictionary.

        Returns
        -------
        Optional[FusionSQLResult] - result containing the column definitions
            and rows of data, or None if the command produces no result set

        """
        raise NotImplementedError

    def visit_qs(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Quoted strings."""
        if node is None:
            return None
        return flatten(visited_children)[0]

    def visit_compound(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Compound name."""
        print(visited_children)
        return flatten(visited_children)[0]

    def visit_number(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Numeric value."""
        return float(flatten(visited_children)[0])

    def visit_integer(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Integer value."""
        return int(flatten(visited_children)[0])

    def visit_ws(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Whitespace and comments."""
        return

    def visit_eq(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Equals sign."""
        return

    def visit_comma(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Single comma."""
        return

    def visit_open_paren(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Open parenthesis."""
        return

    def visit_close_paren(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Close parenthesis."""
        return

    def visit_open_repeats(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Open repeat grouping."""
        return

    def visit_close_repeats(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Close repeat grouping."""
        return

    def visit_init(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Entry point of the grammar."""
        _, out, *_ = visited_children
        return out

    def visit_statement(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Pass through the raw SQL statement text."""
        out = ' '.join(flatten(visited_children)).strip()
        return {'statement': out}

    def visit_order_by(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """Handle ORDER BY."""
        by = []
        ascending = []
        data = [x for x in flatten(visited_children)[2:] if x]
        for item in data:
            value = item.popitem()[-1]
            if not isinstance(value, list):
                value = [value]
                value.append('A')
            by.append(value[0])
            ascending.append(value[1].upper().startswith('A'))
        return {'order_by': {'by': by, 'ascending': ascending}}

    def _delimited(self, node: Node, children: Iterable[Any]) -> Any:
        """Collect a delimiter-separated list of values."""
        children = list(children)
        items = [children[0]]
        items.extend(item for _, item in children[1])
        return items

    def _atomic(self, node: Node, children: Iterable[Any]) -> Any:
        """Return the single child value."""
        return list(children)[0]

    # Visitor aliases for simple JSON rules
    visit_json_value = _atomic
    visit_json_members = visit_json_items = _delimited

    def visit_json_object(self, node: Node, children: Iterable[Any]) -> Any:
        """Build a dict from JSON object members."""
        _, members, _ = children
        if isinstance(members, list):
            members = members[0]
        else:
            members = []
        members = [x for x in members if x != '']
        return dict(members)

    def visit_json_array(self, node: Node, children: Iterable[Any]) -> Any:
        """Build a list from JSON array items."""
        _, values, _ = children
        if isinstance(values, list):
            values = values[0]
        else:
            values = []
        return values

    def visit_json_mapping(self, node: Node, children: Iterable[Any]) -> Any:
        """Return a (key, value) pair for a JSON object member."""
        key, _, value = children
        return key, value

    def visit_json_string(self, node: Node, children: Iterable[Any]) -> Any:
        """Unescape a JSON string value."""
        return json_unescape(node.text)

    def visit_json_number(self, node: Node, children: Iterable[Any]) -> Any:
        """Convert a JSON number to int or float."""
        if '.' in node.text:
            return float(node.text)
        return int(node.text)

    def visit_json_true_val(self, node: Node, children: Iterable[Any]) -> Any:
        """JSON true literal."""
        return True

    def visit_json_false_val(self, node: Node, children: Iterable[Any]) -> Any:
        """JSON false literal."""
        return False

    def visit_json_null_val(self, node: Node, children: Iterable[Any]) -> Any:
        """JSON null literal."""
        return None

    def generic_visit(self, node: Node, visited_children: Iterable[Any]) -> Any:
        """
        Handle all undefined rules.

        This method processes all user-defined rules. Each rule results in
        a dictionary with a single key corresponding to the rule name, with
        a value corresponding to the data value following the rule keywords.

        If no value exists, the value True is used. If the rule is not a
        rule with possible repeated values, a single value is used. If the
        rule can have repeated values, a list of values is returned.

        """
        if node.expr_name.startswith('json'):
            return visited_children or node.text

        # Call a grammar rule
        if node.expr_name in type(self).rule_info:
            n_keywords = type(self).rule_info[node.expr_name]['n_keywords']
            repeats = type(self).rule_info[node.expr_name]['repeats']

            # If this is the top-level command, create the final result
            if node.expr_name.endswith('_cmd'):
                final = merge_dicts(flatten(visited_children)[n_keywords:])
                for k, v in type(self).rule_info.items():
                    if k.endswith('_cmd') or k.endswith('_') or k.startswith('_'):
                        continue
                    if k not in final and k not in self._handled:
                        final[k] = BUILTIN_DEFAULTS.get(k, v['default'])
                return final

            # Filter out stray empty strings
            out = [x for x in flatten(visited_children)[n_keywords:] if x]

            # Remove underscore prefixes from rule name
            key_name = re.sub(r'^_+', r'', node.expr_name)

            if repeats or len(out) > 1:
                self._handled.add(node.expr_name)
                # If all outputs are dicts, merge them
                if len(out) > 1 and not repeats:
                    is_dicts = [x for x in out if isinstance(x, dict)]
                    if len(is_dicts) == len(out):
                        return {key_name: merge_dicts(out)}
                return {key_name: out}

            self._handled.add(node.expr_name)
            return {key_name: out[0] if out else True}

        if hasattr(node, 'match'):
            if not visited_children and not node.match.groups():
                return node.text
            return visited_children or list(node.match.groups())

        return visited_children or node.text

    def validate_rule(self, rule: str, value: Any) -> Any:
        """
        Validate the value of the given rule.

        Parameters
        ----------
        rule : str
            Name of the grammar rule the value belongs to
        value : Any
            Value parsed from the query

        Returns
        -------
        Any - result of the validator function

        """
        validator = type(self).validators.get(rule)
        if validator is not None:
            return validator(value)
        return value
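# A minimal sketch (not part of this module) of how a handler is typically
# defined: the SQL grammar and help text live in the class docstring, ``run``
# receives the parsed parameters, and ``register`` adds the command to the
# Fusion registry. The command name and grammar below are hypothetical.
#
#     class ShowStuffHandler(SQLHandler):
#         """
#         SHOW STUFF [ <like> ] [ <limit> ];
#
#         Description
#         -----------
#         Show hypothetical stuff, optionally filtered by a LIKE pattern.
#
#         Example
#         -------
#         Example 1::
#
#             SHOW STUFF LIKE '%x%' LIMIT 10;
#
#         """
#
#         def run(self, params: Dict[str, Any]) -> Optional[result.FusionSQLResult]:
#             # ``params`` would contain keys such as 'like' and 'limit' here.
#             res = self.create_result()
#             res.set_rows([])
#             return res
#
#     ShowStuffHandler.register(overwrite=True)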