GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/fusion/handler.py
#!/usr/bin/env python3
import abc
import functools
import os
import re
import sys
import textwrap
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple

from parsimonious import Grammar
from parsimonious import ParseError
from parsimonious.nodes import Node
from parsimonious.nodes import NodeVisitor

from . import result
from ..connection import Connection
CORE_GRAMMAR = r'''
    ws = ~r"(\s+|(\s*/\*.*\*/\s*)+)"
    qs = ~r"\"([^\"]*)\"|'([^\']*)'|([A-Za-z0-9_\-\.]+)|`([^\`]+)`" ws*
    number = ~r"[-+]?(\d*\.)?\d+(e[-+]?\d+)?"i ws*
    integer = ~r"-?\d+" ws*
    comma = ws* "," ws*
    eq = ws* "=" ws*
    open_paren = ws* "(" ws*
    close_paren = ws* ")" ws*
    open_repeats = ws* ~r"[\(\[\{]" ws*
    close_repeats = ws* ~r"[\)\]\}]" ws*
    statement = ~r"[\s\S]*" ws*
    table = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*
    column = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*
    link_name = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*
    catalog_name = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*

    json = ws* json_object ws*
    json_object = ~r"{\s*" json_members? ~r"\s*}"
    json_members = json_mapping (~r"\s*,\s*" json_mapping)*
    json_mapping = json_string ~r"\s*:\s*" json_value
    json_array = ~r"\[\s*" json_items? ~r"\s*\]"
    json_items = json_value (~r"\s*,\s*" json_value)*
    json_value = json_object / json_array / json_string / json_true_val / json_false_val / json_null_val / json_number
    json_true_val = "true"
    json_false_val = "false"
    json_null_val = "null"
    json_string = ~r"\"[ !#-\[\]-\U0010ffff]*(?:\\(?:[\"\\/bfnrt]|u[0-9A-Fa-f]{4})[ !#-\[\]-\U0010ffff]*)*\""
    json_number = ~r"-?(0|[1-9][0-9]*)(\.\d*)?([eE][-+]?\d+)?"
''' # noqa: E501

BUILTINS = {
    '<order-by>': r'''
    order_by = ORDER BY order_by_key_,...
    order_by_key_ = '<key>' [ ASC | DESC ]
    ''',
    '<like>': r'''
    like = LIKE '<pattern>'
    ''',
    '<extended>': r'''
    extended = EXTENDED
    ''',
    '<limit>': r'''
    limit = LIMIT <integer>
    ''',
    '<integer>': '',
    '<number>': '',
    '<json>': '',
    '<table>': '',
    '<column>': '',
    '<catalog-name>': '',
    '<link-name>': '',
    '<file-type>': r'''
    file_type = { FILE | FOLDER }
    ''',
    '<statement>': '',
}

BUILTIN_DEFAULTS = {  # type: ignore
    'order_by': {'by': []},
    'like': None,
    'extended': False,
    'limit': None,
    'json': {},
}

# JSON escape sequences; unicode escapes use exactly four hex digits,
# matching the json_string rule in CORE_GRAMMAR above.
_json_unesc_re = re.compile(r'\\(["/\\bfnrt]|u[0-9A-Fa-f]{4})')
_json_unesc_map = {
    '"': '"',
    '/': '/',
    '\\': '\\',
    'b': '\b',
    'f': '\f',
    'n': '\n',
    'r': '\r',
    't': '\t',
}


def _json_unescape(m: Any) -> str:
    """Resolve a single JSON escape sequence matched by ``_json_unesc_re``."""
    c = m.group(1)
    if c[0] == 'u':
        return chr(int(c[1:], 16))
    c2 = _json_unesc_map.get(c)
    if not c2:
        raise ValueError(f'invalid escape sequence: {m.group(0)}')
    return c2


def json_unescape(s: str) -> str:
    """Strip surrounding quotes and unescape a JSON string literal."""
    return _json_unesc_re.sub(_json_unescape, s[1:-1])
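
# Illustrative example (editor's sketch, not part of the original source):
#
#     >>> json_unescape('"a\\tb"')
#     'a\tb'
#
# The surrounding double quotes are stripped and backslash escapes are
# resolved through ``_json_unesc_map`` (or ``chr`` for ``\uXXXX`` escapes).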

def get_keywords(grammar: str) -> Tuple[str, ...]:
    """Return the run of all-caps keywords at the beginning of the grammar."""
    m = re.match(r'^\s*((?:[@A-Z0-9_]+)(\s+|$|;))+', grammar)
    if not m:
        return tuple()
    return tuple(re.split(r'\s+', m.group(0).replace(';', '').strip()))


def is_bool(grammar: str) -> bool:
    """Determine if the rule is a boolean."""
    return bool(re.match(r'^[@A-Z0-9_\s*]+$', grammar))
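
# Illustrative examples (editor's sketch, not part of the original source):
#
#     >>> get_keywords('SHOW THINGS [ <like> ];')
#     ('SHOW', 'THINGS')
#     >>> is_bool('EXTENDED')
#     True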

def process_optional(m: Any) -> str:
    """Create options or groups of options."""
    sql = m.group(1).strip()
    if '|' in sql:
        return f'( {sql} )*'
    return f'( {sql} )?'


def process_alternates(m: Any) -> str:
    """Make alternates mandatory groups."""
    sql = m.group(1).strip()
    if '|' in sql:
        return f'( {sql} )'
    raise ValueError(f'alternates must contain "|": {sql}')


def process_repeats(m: Any) -> str:
    """Add repeated patterns."""
    sql = m.group(1).strip()
    return f'open_repeats? {sql} ws* ( comma {sql} ws* )* close_repeats?'
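
# Illustrative examples (editor's sketch, not part of the original source);
# these functions are used as ``re.sub`` callbacks in ``process_grammar``:
#
#     >>> re.sub(r'\[([^\]]+)\]', process_optional, '[ EXTENDED ]')
#     '( EXTENDED )?'
#     >>> re.sub(r'\{([^\}]+)\}', process_alternates, '{ FILE | FOLDER }')
#     '( FILE | FOLDER )'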

def lower_and_regex(m: Any) -> str:
    """Lowercase and convert literal to regex."""
    start = m.group(1) or ''
    sql = m.group(2)
    return f'~"{start}{sql.lower()}"i'


def split_unions(grammar: str) -> str:
    """
    Convert grammar in the form '[ x | y ]' to '[ x ] [ y ]'.

    Parameters
    ----------
    grammar : str
        SQL grammar

    Returns
    -------
    str

    """
    in_alternate = False
    out = []
    for c in grammar:
        if c == '{':
            in_alternate = True
            out.append(c)
        elif c == '}':
            in_alternate = False
            out.append(c)
        elif not in_alternate and c == '|':
            out.append(']')
            out.append(' ')
            out.append('[')
        else:
            out.append(c)
    return ''.join(out)
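
# Illustrative example (editor's sketch, not part of the original source):
# unions outside of { ... } alternates are split into separate optional
# groups:
#
#     >>> split_unions('[ ASC | DESC ]')
#     '[ ASC ] [ DESC ]'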

def expand_rules(rules: Dict[str, str], m: Any) -> str:
    """
    Return expanded grammar syntax for given rule.

    Parameters
    ----------
    rules : Dict[str, str]
        Dictionary of rules in the grammar
    m : Any
        Regex match containing the name of the rule to expand

    Returns
    -------
    str

    """
    txt = m.group(1)
    if txt in rules:
        return f' {rules[txt]} '
    return f' <{txt}> '


def build_cmd(grammar: str) -> str:
    """Pre-process grammar to construct top-level command."""
    if ';' not in grammar:
        raise ValueError('a semi-colon must exist at the end of the primary rule')

    # Pre-space
    m = re.match(r'^\s*', grammar)
    space = m.group(0) if m else ''

    # Split on ';' on a line by itself
    begin, end = grammar.split(';', 1)

    # Get statement keywords
    keywords = get_keywords(begin)
    cmd = '_'.join(x.lower() for x in keywords) + '_cmd'

    # Collapse multi-line to one
    begin = re.sub(r'\s+', r' ', begin)

    return f'{space}{cmd} ={begin}\n{end}'
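
# Illustrative example (editor's sketch, not part of the original source):
# the leading keywords of the primary rule become the name of the generated
# top-level rule:
#
#     >>> build_cmd('SHOW THINGS;').startswith('show_things_cmd =')
#     True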

def build_syntax(grammar: str) -> str:
    """Construct full syntax."""
    if ';' not in grammar:
        raise ValueError('a semi-colon must exist at the end of the primary rule')

    # Split on ';' on a line by itself
    cmd, end = grammar.split(';', 1)

    name = ''
    rules: Dict[str, Any] = {}
    for line in end.split('\n'):
        line = line.strip()
        if line.startswith('&'):
            rules[name] += '\n' + line
            continue
        if not line:
            continue
        name, value = line.split('=', 1)
        name = name.strip()
        value = value.strip()
        rules[name] = value

    while re.search(r' [a-z0-9_]+\b', cmd):
        cmd = re.sub(r' ([a-z0-9_]+)\b', functools.partial(expand_rules, rules), cmd)

    def add_indent(m: Any) -> str:
        return ' ' + (len(m.group(1)) * ' ')

    # Indent line-continuations
    cmd = re.sub(r'^(\&+)\s*', add_indent, cmd, flags=re.M)

    cmd = textwrap.dedent(cmd).rstrip() + ';'
    cmd = re.sub(r'(\S) +', r'\1 ', cmd)
    cmd = re.sub(r'<comma>', ',', cmd)
    cmd = re.sub(r'\s+,\s*\.\.\.', ',...', cmd)

    return cmd

def _format_examples(ex: str) -> str:
    """Convert examples into sections."""
    return re.sub(r'(^Example\s+\d+.*$)', r'### \1', ex, flags=re.M)


def _format_arguments(arg: str) -> str:
    """Format arguments as subsections."""
    out = []
    for line in arg.split('\n'):
        if line.startswith('<'):
            out.append(f'### {line.replace("<", "&lt;").replace(">", "&gt;")}')
            out.append('')
        else:
            out.append(line.strip())
    return '\n'.join(out)


def _to_markdown(txt: str) -> str:
    """Convert formatting to markdown."""
    txt = re.sub(r'`([^`]+)\s+\<([^\>]+)>`_', r'[\1](\2)', txt)
    txt = txt.replace('``', '`')

    # Format code blocks
    lines = re.split(r'\n', txt)
    out = []
    while lines:
        line = lines.pop(0)
        if line.endswith('::'):
            out.append(line[:-2] + '.')
            code = []
            while lines and (not lines[0].strip() or lines[0].startswith(' ')):
                code.append(lines.pop(0).rstrip())
            code_str = re.sub(r'^\s*\n', r'', '\n'.join(code).rstrip())
            out.extend([f'```sql\n{code_str}\n```\n'])
        else:
            out.append(line)

    return '\n'.join(out)

def build_help(syntax: str, grammar: str) -> str:
    """Build full help text."""
    cmd = re.match(r'([A-Z0-9_ ]+)', syntax.strip())
    if not cmd:
        raise ValueError(f'no command found: {syntax}')

    out = [f'# {cmd.group(1)}\n\n']

    sections: Dict[str, str] = {}
    grammar = textwrap.dedent(grammar.rstrip())
    desc_re = re.compile(r'^([A-Z][\S ]+)\s*^\-\-\-\-+\s*$', flags=re.M)
    if desc_re.search(grammar):
        _, *txt = desc_re.split(grammar)
        txt = [x.strip() for x in txt]
        sections = {}
        while txt:
            key = txt.pop(0)
            value = txt.pop(0)
            sections[key.lower()] = _to_markdown(value).strip()

    if 'description' in sections:
        out.extend([sections['description'], '\n\n'])

    out.append(f'## Syntax\n\n```sql{syntax}\n```\n\n')

    if 'arguments' in sections:
        out.extend([
            '## Arguments\n\n',
            _format_arguments(sections['arguments']),
            '\n\n',
        ])
    if 'argument' in sections:
        out.extend([
            '## Argument\n\n',
            _format_arguments(sections['argument']),
            '\n\n',
        ])

    if 'remarks' in sections:
        out.extend(['## Remarks\n\n', sections['remarks'], '\n\n'])

    if 'examples' in sections:
        out.extend(['## Examples\n\n', _format_examples(sections['examples']), '\n\n'])
    elif 'example' in sections:
        out.extend(['## Example\n\n', _format_examples(sections['example']), '\n\n'])

    if 'see also' in sections:
        out.extend(['## See Also\n\n', sections['see also'], '\n\n'])

    return ''.join(out).rstrip() + '\n'


def strip_comments(grammar: str) -> str:
    """Strip comments from grammar."""
    desc_re = re.compile(r'(^\s*Description\s*^\s*-----------\s*$)', flags=re.M)
    grammar = desc_re.split(grammar, maxsplit=1)[0]
    return re.sub(r'^\s*#.*$', r'', grammar, flags=re.M)

def get_rule_info(grammar: str) -> Dict[str, Any]:
    """Compute metadata about a rule, used in coalescing parsed output."""
    return dict(
        n_keywords=len(get_keywords(grammar)),
        repeats=',...' in grammar,
        default=False if is_bool(grammar) else [] if ',...' in grammar else None,
    )

def inject_builtins(grammar: str) -> str:
    """Inject complex builtin rules."""
    for k, v in BUILTINS.items():
        if re.search(k, grammar):
            grammar = re.sub(
                k,
                k.replace('<', '').replace('>', '').replace('-', '_'),
                grammar,
            )
            grammar += v
    return grammar
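
# Illustrative example (editor's sketch, not part of the original source):
# a ``<like>`` placeholder in a handler grammar is renamed to ``like`` and
# the corresponding builtin rule text is appended:
#
#     >>> 'like = LIKE' in inject_builtins('SHOW THINGS [ <like> ];')
#     True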

def process_grammar(
    grammar: str,
) -> Tuple[Grammar, Tuple[str, ...], Dict[str, Any], str, str]:
    """
    Convert SQL grammar to a Parsimonious grammar.

    Parameters
    ----------
    grammar : str
        The SQL grammar

    Returns
    -------
    (Grammar, Tuple[str, ...], Dict[str, Any], str, str) - Grammar is the
    parsimonious grammar object. The tuple is the series of keywords that
    start the command. The dictionary is a set of metadata about each rule.
    The final two strings are a human-readable syntax summary and the full
    help text, used for documentation and error messages.

    """
    out = []
    rules = {}
    rule_info = {}

    full_grammar = grammar
    grammar = strip_comments(grammar)
    grammar = inject_builtins(grammar)
    command_key = get_keywords(grammar)
    syntax_txt = build_syntax(grammar)
    help_txt = build_help(syntax_txt, full_grammar)
    grammar = build_cmd(grammar)

    # Remove line-continuations
    grammar = re.sub(r'\n\s*&+', r'', grammar)

    # Make sure grouping characters all have whitespace around them
    grammar = re.sub(r' *(\[|\{|\||\}|\]) *', r' \1 ', grammar)

    grammar = re.sub(r'\(', r' open_paren ', grammar)
    grammar = re.sub(r'\)', r' close_paren ', grammar)

    for line in grammar.split('\n'):
        if not line.strip():
            continue

        op, sql = line.split('=', 1)
        op = op.strip()
        sql = sql.strip()
        sql = split_unions(sql)

        rules[op] = sql
        rule_info[op] = get_rule_info(sql)

        # Convert consecutive optionals to a union
        sql = re.sub(r'\]\s+\[', r' | ', sql)

        # Lower-case keywords and make them case-insensitive
        sql = re.sub(r'(\b|@+)([A-Z0-9_]+)\b', lower_and_regex, sql)

        # Convert literal strings to 'qs'
        sql = re.sub(r"'[^']+'", r'qs', sql)

        # Convert special characters to literal tokens
        sql = re.sub(r'([=]) ', r' eq ', sql)

        # Convert [...] groups to (...)*
        sql = re.sub(r'\[([^\]]+)\]', process_optional, sql)

        # Convert {...} groups to (...)
        sql = re.sub(r'\{([^\}]+)\}', process_alternates, sql)

        # Convert <...> to ... (<...> is the form for core types)
        sql = re.sub(r'<([a-z0-9_]+)>', r'\1', sql)

        # Insert ws between every token to allow for whitespace and comments
        sql = ' ws '.join(re.split(r'\s+', sql)) + ' ws'

        # Remove ws in optional groupings
        sql = sql.replace('( ws', '(')
        sql = sql.replace('| ws', '|')

        # Convert | to /
        sql = sql.replace('|', '/')

        # Remove ws after operation names, all operations contain ws at the end
        sql = re.sub(r'(\s+[a-z0-9_]+)\s+ws\b', r'\1', sql)

        # Convert foo,... to foo ("," foo)*
        sql = re.sub(r'(\S+),...', process_repeats, sql)

        # Remove ws before / and )
        sql = re.sub(r'(\s*\S+\s+)ws\s+/', r'\1/', sql)
        sql = re.sub(r'(\s*\S+\s+)ws\s+\)', r'\1)', sql)

        # Make sure every operation ends with ws
        sql = re.sub(r'\s+ws\s+ws$', r' ws', sql + ' ws')
        sql = re.sub(r'(\s+ws)*\s+ws\*$', r' ws*', sql)
        sql = re.sub(r'\s+ws$', r' ws*', sql)
        sql = re.sub(r'\s+ws\s+\(', r' ws* (', sql)
        sql = re.sub(r'\)\s+ws\s+', r') ws* ', sql)
        sql = re.sub(r'\s+ws\s+', r' ws* ', sql)
        sql = re.sub(r'\?\s+ws\+', r'? ws*', sql)

        # Remove extra ws around eq
        sql = re.sub(r'ws\+\s*eq\b', r'eq', sql)

        # Remove optional groupings when mandatory groupings are specified
        sql = re.sub(r'open_paren\s+ws\*\s+open_repeats\?', r'open_paren', sql)
        sql = re.sub(r'close_repeats\?\s+ws\*\s+close_paren', r'close_paren', sql)
        sql = re.sub(r'open_paren\s+open_repeats\?', r'open_paren', sql)
        sql = re.sub(r'close_repeats\?\s+close_paren', r'close_paren', sql)

        out.append(f'{op} = {sql}')

    for k, v in list(rules.items()):
        while re.search(r' ([a-z0-9_]+) ', v):
            v = re.sub(r' ([a-z0-9_]+) ', functools.partial(expand_rules, rules), v)
            rules[k] = v

    for k, v in list(rules.items()):
        while re.search(r' <([a-z0-9_]+)> ', v):
            v = re.sub(r' <([a-z0-9_]+)> ', r' \1 ', v)
            rules[k] = v

    cmds = ' / '.join(x for x in rules if x.endswith('_cmd'))
    cmds = f'init = ws* ( {cmds} ) ws* ";"? ws*\n'

    grammar = cmds + CORE_GRAMMAR + '\n'.join(out)

    try:
        return (
            Grammar(grammar), command_key,
            rule_info, syntax_txt, help_txt,
        )
    except ParseError:
        print(grammar, file=sys.stderr)
        raise
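
# Illustrative usage (editor's sketch, not part of the original source): a
# handler grammar written in SQL-like syntax is compiled into a Parsimonious
# grammar plus metadata about the command:
#
#     >>> g, key, info, syntax, help_txt = process_grammar(
#     ...     'SHOW THINGS [ <like> ] [ <limit> ];',
#     ... )
#     >>> key
#     ('SHOW', 'THINGS')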

def flatten(items: Iterable[Any]) -> List[Any]:
    """Flatten a list of iterables."""
    out = []
    for x in items:
        if isinstance(x, (str, bytes, dict)):
            out.append(x)
        elif isinstance(x, Iterable):
            for sub_x in flatten(x):
                if sub_x is not None:
                    out.append(sub_x)
        elif x is not None:
            out.append(x)
    return out


def merge_dicts(items: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge list of dictionaries together."""
    out: Dict[str, Any] = {}
    for x in items:
        if isinstance(x, dict):
            same = list(set(x.keys()).intersection(set(out.keys())))
            if same:
                raise ValueError(f"found duplicate rules for '{same[0]}'")
            out.update(x)
    return out
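
# Illustrative examples (editor's sketch, not part of the original source):
#
#     >>> flatten([['a', ['b']], None, 'c'])
#     ['a', 'b', 'c']
#     >>> merge_dicts([{'like': 'foo%'}, {'limit': 10}])
#     {'like': 'foo%', 'limit': 10}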

class SQLHandler(NodeVisitor):
    """Base class for all SQL handler classes."""

    #: Parsimonious grammar object
    grammar: Grammar = Grammar(CORE_GRAMMAR)

    #: SQL keywords that start the command
    command_key: Tuple[str, ...] = ()

    #: Metadata about the parse rules
    rule_info: Dict[str, Any] = {}

    #: Syntax string for use in error messages
    syntax: str = ''

    #: Full help for the command
    help: str = ''

    #: Rule validation functions
    validators: Dict[str, Callable[..., Any]] = {}

    _grammar: str = CORE_GRAMMAR
    _is_compiled: bool = False
    _enabled: bool = True

    def __init__(self, connection: Connection):
        self.connection = connection
        self._handled: Set[str] = set()

    @classmethod
    def compile(cls, grammar: str = '') -> None:
        """
        Compile the grammar held in the docstring.

        This method modifies attributes on the class: ``grammar``,
        ``command_key``, ``rule_info``, ``syntax``, and ``help``.

        Parameters
        ----------
        grammar : str, optional
            Grammar to use instead of docstring

        """
        if cls._is_compiled:
            return

        cls.grammar, cls.command_key, cls.rule_info, cls.syntax, cls.help = \
            process_grammar(grammar or cls.__doc__ or '')

        cls._grammar = grammar or cls.__doc__ or ''
        cls._is_compiled = True

    @classmethod
    def register(cls, overwrite: bool = False) -> None:
        """
        Register the handler class.

        Parameters
        ----------
        overwrite : bool, optional
            Overwrite an existing command with the same name?

        """
        if not cls._enabled and \
                os.environ.get('SINGLESTOREDB_FUSION_ENABLE_HIDDEN', '0').lower() not in \
                ['1', 't', 'true', 'y', 'yes']:
            return

        from . import registry
        cls.compile()
        registry.register_handler(cls, overwrite=overwrite)
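
    # Illustrative usage (editor's sketch, not part of the original source):
    # a concrete handler declares its SQL grammar in the class docstring and
    # implements ``run``; registering it makes the command available to the
    # Fusion SQL layer.
    #
    #     class ShowThingsHandler(SQLHandler):
    #         """
    #         SHOW THINGS [ <like> ] [ <limit> ];
    #
    #         Description
    #         -----------
    #         List hypothetical "things" visible to the current user.
    #
    #         """
    #
    #         def run(self, params):
    #             res = self.create_result()
    #             # ... build columns/rows using params['like'], params['limit']
    #             return res
    #
    #     ShowThingsHandler.register(overwrite=True)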

    def create_result(self) -> result.FusionSQLResult:
        """
        Create a new result object.

        Returns
        -------
        FusionSQLResult
            A new result object for this handler

        """
        return result.FusionSQLResult()

    def execute(self, sql: str) -> result.FusionSQLResult:
        """
        Parse the SQL and invoke the handler method.

        Parameters
        ----------
        sql : str
            SQL statement to execute

        Returns
        -------
        FusionSQLResult

        """
        type(self).compile()
        self._handled = set()
        try:
            params = self.visit(type(self).grammar.parse(sql))
            for k, v in params.items():
                params[k] = self.validate_rule(k, v)

            res = self.run(params)

            self._handled = set()

            if res is not None:
                res.format_results(self.connection)
                return res

            res = result.FusionSQLResult()
            res.set_rows([])
            res.format_results(self.connection)
            return res

        except ParseError as exc:
            s = str(exc)
            msg = ''
            m = re.search(r'(The non-matching portion.*$)', s)
            if m:
                msg = ' ' + m.group(1)
            m = re.search(r"(Rule) '.+?'( didn't match at.*$)", s)
            if m:
                msg = ' ' + m.group(1) + m.group(2)
            raise ValueError(
                f'Could not parse statement.{msg} '
                'Expecting:\n' + textwrap.indent(type(self).syntax, ' '),
            )
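
    # Illustrative call sequence (editor's sketch, not part of the original
    # source): the Fusion layer constructs a handler with an open connection
    # and hands it the raw SQL text.
    #
    #     handler = ShowThingsHandler(conn)    # hypothetical handler / connection
    #     res = handler.execute("SHOW THINGS LIKE 'foo%' LIMIT 10")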
@abc.abstractmethod
692
def run(self, params: Dict[str, Any]) -> Optional[result.FusionSQLResult]:
693
"""
694
Run the handler command.
695
696
Parameters
697
----------
698
params : Dict[str, Any]
699
Values parsed from the SQL query. Each rule in the grammar
700
results in a key/value pair in the ``params` dictionary.
701
702
Returns
703
-------
704
SQLResult - tuple containing the column definitions and
705
rows of data in the result
706
707
"""
708
raise NotImplementedError
709
710
def visit_qs(self, node: Node, visited_children: Iterable[Any]) -> Any:
711
"""Quoted strings."""
712
if node is None:
713
return None
714
return flatten(visited_children)[0]
715
716
def visit_compound(self, node: Node, visited_children: Iterable[Any]) -> Any:
717
"""Compound name."""
718
print(visited_children)
719
return flatten(visited_children)[0]
720
721
def visit_number(self, node: Node, visited_children: Iterable[Any]) -> Any:
722
"""Numeric value."""
723
return float(flatten(visited_children)[0])
724
725
def visit_integer(self, node: Node, visited_children: Iterable[Any]) -> Any:
726
"""Integer value."""
727
return int(flatten(visited_children)[0])
728
729
def visit_ws(self, node: Node, visited_children: Iterable[Any]) -> Any:
730
"""Whitespace and comments."""
731
return
732
733
def visit_eq(self, node: Node, visited_children: Iterable[Any]) -> Any:
734
"""Equals sign."""
735
return
736
737
def visit_comma(self, node: Node, visited_children: Iterable[Any]) -> Any:
738
"""Single comma."""
739
return
740
741
def visit_open_paren(self, node: Node, visited_children: Iterable[Any]) -> Any:
742
"""Open parenthesis."""
743
return
744
745
def visit_close_paren(self, node: Node, visited_children: Iterable[Any]) -> Any:
746
"""Close parenthesis."""
747
return
748
749
def visit_open_repeats(self, node: Node, visited_children: Iterable[Any]) -> Any:
750
"""Open repeat grouping."""
751
return
752
753
def visit_close_repeats(self, node: Node, visited_children: Iterable[Any]) -> Any:
754
"""Close repeat grouping."""
755
return
756
757
def visit_init(self, node: Node, visited_children: Iterable[Any]) -> Any:
758
"""Entry point of the grammar."""
759
_, out, *_ = visited_children
760
return out
761
762
def visit_statement(self, node: Node, visited_children: Iterable[Any]) -> Any:
763
out = ' '.join(flatten(visited_children)).strip()
764
return {'statement': out}
765
766
def visit_order_by(self, node: Node, visited_children: Iterable[Any]) -> Any:
767
"""Handle ORDER BY."""
768
by = []
769
ascending = []
770
data = [x for x in flatten(visited_children)[2:] if x]
771
for item in data:
772
value = item.popitem()[-1]
773
if not isinstance(value, list):
774
value = [value]
775
value.append('A')
776
by.append(value[0])
777
ascending.append(value[1].upper().startswith('A'))
778
return {'order_by': {'by': by, 'ascending': ascending}}
779
780
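
    # Illustrative result shape (editor's sketch, not part of the original
    # source): an ``ORDER BY name DESC, id`` clause yields roughly
    # ``{'order_by': {'by': ['name', 'id'], 'ascending': [False, True]}}``.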
def _delimited(self, node: Node, children: Iterable[Any]) -> Any:
781
children = list(children)
782
items = [children[0]]
783
items.extend(item for _, item in children[1])
784
return items
785
786
def _atomic(self, node: Node, children: Iterable[Any]) -> Any:
787
return list(children)[0]
788
789
# visitors
790
visit_json_value = _atomic
791
visit_json_members = visit_json_items = _delimited
792
793
def visit_json_object(self, node: Node, children: Iterable[Any]) -> Any:
794
_, members, _ = children
795
if isinstance(members, list):
796
members = members[0]
797
else:
798
members = []
799
members = [x for x in members if x != '']
800
return dict(members)
801
802
def visit_json_array(self, node: Node, children: Iterable[Any]) -> Any:
803
_, values, _ = children
804
if isinstance(values, list):
805
values = values[0]
806
else:
807
values = []
808
return values
809
810
def visit_json_mapping(self, node: Node, children: Iterable[Any]) -> Any:
811
key, _, value = children
812
return key, value
813
814
def visit_json_string(self, node: Node, children: Iterable[Any]) -> Any:
815
return json_unescape(node.text)
816
817
def visit_json_number(self, node: Node, children: Iterable[Any]) -> Any:
818
if '.' in node.text:
819
return float(node.text)
820
return int(node.text)
821
822
def visit_json_true_val(self, node: Node, children: Iterable[Any]) -> Any:
823
return True
824
825
def visit_json_false_val(self, node: Node, children: Iterable[Any]) -> Any:
826
return False
827
828
def visit_json_null_val(self, node: Node, children: Iterable[Any]) -> Any:
829
return None
830
831
def generic_visit(self, node: Node, visited_children: Iterable[Any]) -> Any:
832
"""
833
Handle all undefined rules.
834
835
This method processes all user-defined rules. Each rule results in
836
a dictionary with a single key corresponding to the rule name, with
837
a value corresponding to the data value following the rule keywords.
838
839
If no value exists, the value True is used. If the rule is not a
840
rule with possible repeated values, a single value is used. If the
841
rule can have repeated values, a list of values is returned.
842
843
"""
844
if node.expr_name.startswith('json'):
845
return visited_children or node.text
846
847
# Call a grammar rule
848
if node.expr_name in type(self).rule_info:
849
n_keywords = type(self).rule_info[node.expr_name]['n_keywords']
850
repeats = type(self).rule_info[node.expr_name]['repeats']
851
852
# If this is the top-level command, create the final result
853
if node.expr_name.endswith('_cmd'):
854
final = merge_dicts(flatten(visited_children)[n_keywords:])
855
for k, v in type(self).rule_info.items():
856
if k.endswith('_cmd') or k.endswith('_') or k.startswith('_'):
857
continue
858
if k not in final and k not in self._handled:
859
final[k] = BUILTIN_DEFAULTS.get(k, v['default'])
860
return final
861
862
# Filter out stray empty strings
863
out = [x for x in flatten(visited_children)[n_keywords:] if x]
864
865
# Remove underscore prefixes from rule name
866
key_name = re.sub(r'^_+', r'', node.expr_name)
867
868
if repeats or len(out) > 1:
869
self._handled.add(node.expr_name)
870
# If all outputs are dicts, merge them
871
if len(out) > 1 and not repeats:
872
is_dicts = [x for x in out if isinstance(x, dict)]
873
if len(is_dicts) == len(out):
874
return {key_name: merge_dicts(out)}
875
return {key_name: out}
876
877
self._handled.add(node.expr_name)
878
return {key_name: out[0] if out else True}
879
880
if hasattr(node, 'match'):
881
if not visited_children and not node.match.groups():
882
return node.text
883
return visited_children or list(node.match.groups())
884
885
return visited_children or node.text
886
887
def validate_rule(self, rule: str, value: Any) -> Any:
888
"""
889
Validate the value of the given rule.
890
891
Paraemeters
892
-----------
893
rule : str
894
Name of the grammar rule the value belongs to
895
value : Any
896
Value parsed from the query
897
898
Returns
899
-------
900
Any - result of the validator function
901
902
"""
903
validator = type(self).validators.get(rule)
904
if validator is not None:
905
return validator(value)
906
return value
907
908