Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/fusion/graphql.py
469 views
1
import itertools
2
import os
3
from typing import Any
4
from typing import Callable
5
from typing import Dict
6
from typing import List
7
from typing import Optional
8
from typing import Tuple
9
from typing import Union
10
11
import requests
12
13
from . import result
14
from .result import FusionSQLResult
15
16
17
# Default GraphQL endpoint; GraphQuery uses it when no ``api_url`` is supplied.
API_URL: str = 'https://backend.singlestore.com/public'
18
19
20
def pass_through(x: Any) -> Any:
    """
    Return the argument unchanged.

    This serves as the default ``converter`` of ``GraphQueryField``, so a
    field value is kept exactly as it was decoded.

    """
    value = x
    return value
23
24
25
def find_path(d: Dict[str, Any], path: str) -> Tuple[bool, Any]:
    """
    Find key path in a dictionary.

    Parameters
    ----------
    d : Dict[str, Any]
        Dictionary to search
    path : str
        Period-delimited string indicating nested keys

    Returns
    -------
    (bool, Any) - bool indicating whether or not the path was found
        and the result itself; ``(False, None)`` when any key is missing
        or an intermediate value (including ``d`` itself) is not a dict

    """
    curr: Any = d
    for key in path.split('.'):
        # Only descend through real dicts.  Without this check a str root
        # would do a substring test, and None/list roots would raise.
        if not isinstance(curr, dict) or key not in curr:
            return False, None
        curr = curr[key]
    # The final value may be of any type (scalar, list, dict, None).
    return True, curr
54
55
56
class GraphQueryField(object):
    """
    Field in a GraphQuery result.

    Parameters
    ----------
    path : str
        Period-delimited path to the result
    dtype : int, optional
        MySQL data type of the result, defaults to string
    include : str or List[str], optional
        Value(s) stored on ``self.include`` as a list; empty strings
        (and ``None``) are dropped
    converter : function, optional
        Converter for the data value; ``None`` disables conversion

    """

    # Shared counter: each new field takes the next value, so fields can be
    # sorted in the order they were created (normally declaration order).
    _sort_index_count = itertools.count()

    def __init__(
        self,
        path: str,
        dtype: int = result.STRING,
        include: Union[str, List[str]] = '',
        converter: Optional[Callable[[Any], Any]] = pass_through,
    ) -> None:
        self.path = path
        self.dtype = dtype
        # Normalize to a list and drop empty entries.
        includes = [include] if isinstance(include, str) else list(include or [])
        self.include = [x for x in includes if x]
        self.converter = converter
        self._sort_index = next(type(self)._sort_index_count)

    def get_path(self, value: Any) -> Tuple[bool, Any]:
        """
        Retrieve the field path in the given object.

        Parameters
        ----------
        value : Any
            Record to look the path up in (presumably a dict decoded
            from the GraphQL JSON response)

        Returns
        -------
        (bool, Any) - bool indicating whether the path was found and
            the (converted) result itself; ``(False, None)`` if not found

        """
        found, out = find_path(value, self.path)
        # Only convert real values: running the converter on the ``None``
        # placeholder for a missing path could raise (e.g. ``int(None)``).
        if found and self.converter is not None:
            return found, self.converter(out)
        return found, out
106
107
108
class GraphQuery(object):
    """
    Base class for all GraphQL classes.

    Parameters
    ----------
    api_token : str, optional
        API token to access the GraphQL endpoint
    api_url : str, optional
        GraphQL endpoint

    """

    def __init__(
        self,
        api_token: str = '',
        api_url: str = API_URL,
    ) -> None:
        # An empty token makes ``run`` fall back to the
        # SINGLESTOREDB_BACKEND_TOKEN environment variable.
        self.api_token = api_token
        self.api_url = api_url

    @classmethod
    def get_query(cls) -> str:
        """Return the GraphQL query text, taken from the class docstring."""
        return cls.__doc__ or ''

    @classmethod
    def get_fields(cls) -> List[Tuple[str, GraphQueryField]]:
        """
        Return fields for the query.

        Only ``GraphQueryField`` attributes in the class's own namespace
        (``vars(cls)``) are considered.  They are ordered by the order in
        which the field objects were created.

        Returns
        -------
        List[Tuple[str, GraphQueryField]] - tuple pairs of field name and definition

        """
        attrs = [
            (name, value) for name, value in vars(cls).items()
            if isinstance(value, GraphQueryField)
        ]
        return sorted(attrs, key=lambda item: item[1]._sort_index)

    def run(
        self,
        variables: Optional[Dict[str, Any]] = None,
        *,
        filter_expr: str = '',
    ) -> FusionSQLResult:
        """
        Run the query.

        Parameters
        ----------
        variables : Dict[str, Any], optional
            Dictionary of substitution parameters
        filter_expr : str, optional
            Accepted for interface compatibility; not used by this method

        Returns
        -------
        FusionSQLResult

        Raises
        ------
        ValueError
            If the endpoint returns a non-200 status, or reports GraphQL
            errors without returning any data

        """
        payload = self._post(variables)
        records = self._extract_records(payload)
        return self._build_result(records)

    def _post(self, variables: Optional[Dict[str, Any]]) -> Any:
        """Send the query to the endpoint and return the decoded JSON body."""
        api_token = self.api_token or os.environ.get('SINGLESTOREDB_BACKEND_TOKEN')
        headers = {'Content-Type': 'application/json'}
        # Do not send a literal "Bearer None" when no token is configured.
        if api_token:
            headers['Authorization'] = f'Bearer {api_token}'
        res = requests.post(
            self.api_url,
            headers=headers,
            json={
                'query': type(self).get_query(),
                'variables': variables or {},
            },
        )
        if res.status_code != 200:
            raise ValueError(f'an error occurred: {res.text}')
        return res.json()

    @staticmethod
    def _extract_records(payload: Dict[str, Any]) -> List[Any]:
        """Return the list of records under the last root field of ``data``."""
        data = payload.get('data')
        if not data:
            # GraphQL reports failures as HTTP 200 with an ``errors`` list and
            # null/missing ``data``; surface them instead of returning nothing.
            errors = payload.get('errors')
            if errors:
                raise ValueError(f'an error occurred: {errors}')
            return []
        # Same item that ``dict.popitem()`` would return, without mutating.
        records = list(data.values())[-1]
        if records is None:
            return []
        if isinstance(records, dict):
            return [records]
        return records

    @classmethod
    def _build_result(cls, records: List[Any]) -> FusionSQLResult:
        """Convert decoded records into a FusionSQLResult."""
        fres = FusionSQLResult()
        fields = cls.get_fields()
        # The column set is fixed by the fields found in the first record;
        # later records contribute exactly those columns (``None`` if
        # missing) so every row lines up with the header.
        columns: List[GraphQueryField] = []
        rows = []
        for i, obj in enumerate(records):
            row = []
            if i == 0:
                for name, field in fields:
                    found, value = field.get_path(obj)
                    if found:
                        fres.add_field(name, field.dtype)
                        columns.append(field)
                        row.append(value)
            else:
                for field in columns:
                    found, value = field.get_path(obj)
                    row.append(value if found else None)
            rows.append(tuple(row))
        fres.set_rows(rows)
        return fres
214
215