Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/mysql/converters.py
469 views
1
import datetime
2
import time
3
from decimal import Decimal
4
from typing import Any
5
from typing import Callable
6
from typing import Dict
7
from typing import Optional
8
from typing import Tuple
9
from typing import Union
10
11
from ..converters import converters as decoders
12
from .err import ProgrammingError
13
14
try:
15
import numpy as np
16
has_numpy = True
17
except ImportError:
18
has_numpy = False
19
20
try:
21
import shapely.geometry
22
import shapely.wkt
23
has_shapely = True
24
except ImportError:
25
has_shapely = False
26
27
try:
28
import pygeos
29
has_pygeos = True
30
except ImportError:
31
has_pygeos = False
32
33
34
# Type of an encoder table: maps a Python type (looked up by exact
# ``type(value)`` in escape_item) to the callable that escapes it.  Most
# encoders return SQL literal text; escape_dict returns a dict of escaped
# values, hence the Union in the return type.
Encoders = Dict[
    type,
    Callable[..., Union[str, Dict[str, str]]],
]
35
36
37
def escape_item(val: Any, charset: str, mapping: Optional[Encoders] = None) -> str:
    """
    Escape a single value using the encoder registered for its type.

    Parameters
    ----------
    val : Any
        The value to escape.
    charset : str
        Character set name. It is forwarded to the container encoders
        (escape_dict, escape_sequence, escape_set), which pass it back into
        escape_item for each element.
    mapping : Encoders, optional
        Type-to-encoder table; defaults to the module-level ``encoders``.

    Returns
    -------
    str
        The escaped literal (escape_dict produces a dict instead).

    Raises
    ------
    TypeError
        If no encoder matches ``type(val)`` and *mapping* has no ``str``
        entry to fall back on.

    """
    if mapping is None:
        mapping = encoders

    # Exact-type lookup: subclasses are not matched and use the str fallback.
    encoder = mapping.get(type(val), None)
    if encoder is None:
        try:
            encoder = mapping[str]
        except KeyError:
            raise TypeError('no default type converter defined') from None

    # Container encoders take (val, charset, mapping); escape_set shares that
    # signature, so it must be dispatched the same way or the mapping would be
    # passed in the charset slot and lost for nested items.
    if encoder in (escape_dict, escape_sequence, escape_set):
        return encoder(val, charset, mapping)
    return encoder(val, mapping)
54
55
56
def escape_dict(
    val: Dict[str, Any],
    charset: str,
    mapping: Optional[Encoders] = None,
) -> Dict[str, str]:
    """
    Escape every value of *val*, leaving the keys untouched.

    Each value is passed through escape_item with the given *charset* and
    *mapping*; a new dict with the same keys (in the same order) is returned.

    """
    return {key: escape_item(item, charset, mapping) for key, item in val.items()}
66
67
68
def escape_sequence(
    val: Any,
    charset: str,
    mapping: Optional[Encoders] = None,
) -> str:
    """
    Escape each element of the iterable *val* and wrap them as ``(a,b,...)``.

    Elements are escaped with escape_item and joined by commas with no
    spaces, producing a parenthesized list suitable for e.g. ``IN (...)``.

    """
    escaped_items = [escape_item(element, charset, mapping) for element in val]
    return '(' + ','.join(escaped_items) + ')'
78
79
80
def escape_set(val: Any, charset: str, mapping: Optional[Encoders] = None) -> str:
    """
    Escape each member of *val* and join them with commas (no parentheses).
    """
    escaped_members = (escape_item(member, charset, mapping) for member in val)
    return ','.join(escaped_members)
82
83
84
def escape_bool(value: Any, mapping: Optional[Encoders] = None) -> str:
    """
    Encode a boolean as its integer text (``True`` -> ``'1'``).

    *mapping* is accepted for encoder-signature compatibility and is unused.

    """
    as_int = int(value)
    return str(as_int)
86
87
88
def escape_int(value: Any, mapping: Optional[Encoders] = None) -> str:
    """
    Encode an integer as its decimal text via ``str``.

    *mapping* is accepted for encoder-signature compatibility and is unused.

    """
    text = str(value)
    return text
90
91
92
def escape_float(
    value: Any,
    mapping: Optional[Encoders] = None,
    nan_as_null: bool = False,
    inf_as_null: bool = False,
) -> str:
    """
    Convert a floating-point value to SQL literal text in exponent notation.

    If the textual form has no exponent, ``e0`` is appended (``1.5`` becomes
    ``1.5e0``).  NOTE(review): presumably this makes the server treat the
    value as an approximate (floating-point) literal rather than an exact
    decimal -- confirm against the server docs.

    Parameters
    ----------
    value : Any
        A Python ``float`` or a numpy floating-point scalar (both are
        registered with this encoder).
    mapping : Encoders, optional
        Unused; accepted for encoder-signature compatibility.
    nan_as_null : bool, optional
        Return ``NULL`` for NaN instead of raising.
    inf_as_null : bool, optional
        Return ``NULL`` for positive or negative infinity instead of raising.

    Returns
    -------
    str

    Raises
    ------
    ProgrammingError
        If *value* is NaN or infinite and the matching ``*_as_null`` flag is
        not set.

    """
    # Use str() rather than repr(): for Python floats they are identical, but
    # numpy >= 2 reprs scalars as e.g. 'np.float64(1.5)', which is not valid
    # SQL and would also defeat the nan/inf checks below.
    s = str(value)
    if s == 'nan':
        if nan_as_null:
            return 'NULL'
        raise ProgrammingError(0, '%s can not be used with SingleStoreDB' % s)
    # Negative infinity must be handled too; otherwise it becomes '-infe0'.
    if s in ('inf', '-inf'):
        if inf_as_null:
            return 'NULL'
        raise ProgrammingError(0, '%s can not be used with SingleStoreDB' % s)
    if 'e' not in s:
        s += 'e0'
    return s
110
111
112
# Translation table for str.translate: entry i is the replacement for chr(i).
# ASCII characters map to themselves except for the ones that must be
# backslash-escaped inside a quoted SQL string literal.
_escape_table = list(map(chr, range(128)))
for _raw, _escaped in (
    ('\0', '\\0'),
    ('\\', '\\\\'),
    ('\n', '\\n'),
    ('\r', '\\r'),
    ('\032', '\\Z'),
    ('"', '\\"'),
    ("'", "\\'"),
):
    _escape_table[ord(_raw)] = _escaped
del _raw, _escaped
120
121
122
def escape_string(value: str, mapping: Optional[Encoders] = None) -> str:
    """
    Backslash-escape special characters in *value* without adding quotes.

    *value* should be a ``str`` (unicode).  NUL, backslash, newline,
    carriage return, Ctrl-Z (``\\032``) and both quote characters are
    replaced according to ``_escape_table``.  *mapping* is unused.

    """
    escaped = value.translate(_escape_table)
    return escaped
130
131
132
def escape_bytes_prefixed(value: bytes, mapping: Optional[Encoders] = None) -> str:
    """
    Encode *value* as a hex literal carrying the ``_binary`` introducer.

    *mapping* is accepted for encoder-signature compatibility and is unused.

    """
    hex_digits = value.hex()
    return "_binary X'{}'".format(hex_digits)
134
135
136
def escape_bytes(value: bytes, mapping: Optional[Encoders] = None) -> str:
    """
    Encode *value* as a ``X'...'`` hexadecimal literal.

    *mapping* is accepted for encoder-signature compatibility and is unused.

    """
    hex_digits = value.hex()
    return "X'{}'".format(hex_digits)
138
139
140
def escape_str(value: str, mapping: Optional[Encoders] = None) -> str:
    """
    Render *value* as a single-quoted, backslash-escaped string literal.

    *value* is first converted with ``str()``, so this also serves as the
    fallback encoder for types with no dedicated entry.

    """
    text = str(value)
    return "'{}'".format(escape_string(text, mapping))
142
143
144
def escape_None(value: Any, mapping: Optional[Encoders] = None) -> str:
    """
    Encode ``None`` as the SQL ``NULL`` keyword.

    *value* is registered for ``type(None)``, so it is annotated ``Any``
    rather than ``str``.  Neither argument is inspected.

    """
    return 'NULL'
146
147
148
def escape_timedelta(obj: datetime.timedelta, mapping: Optional[Encoders] = None) -> str:
    """
    Convert a timedelta to a quoted ``'[-]HH:MM:SS[.ffffff]'`` literal.

    Hours are not wrapped at 24, so one day and two hours is rendered as
    ``'26:00:00'``.  Negative deltas get a single leading minus sign applied
    to the whole magnitude, e.g. ``timedelta(seconds=-1)`` -> ``'-00:00:01'``.
    The fractional part is only emitted when microseconds are non-zero.
    *mapping* is unused.

    """
    # timedelta normalizes negative values to (days < 0, seconds >= 0,
    # microseconds >= 0): -1s is days=-1, seconds=86399.  Formatting those
    # fields directly yields '-1:59:59', so format the absolute value instead.
    sign = ''
    if obj < datetime.timedelta(0):
        sign = '-'
        obj = -obj
    hours, remainder = divmod(obj.days * 86400 + obj.seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    if obj.microseconds:
        fmt = "'{0}{1:02d}:{2:02d}:{3:02d}.{4:06d}'"
    else:
        fmt = "'{0}{1:02d}:{2:02d}:{3:02d}'"
    return fmt.format(sign, hours, minutes, seconds, obj.microseconds)
157
158
159
def escape_time(obj: datetime.time, mapping: Optional[Encoders] = None) -> str:
    """
    Format a time as a quoted ``'HH:MM:SS'`` literal.

    A six-digit fractional part is appended only when ``obj.microsecond`` is
    non-zero.  Any tzinfo is ignored; *mapping* is unused.

    """
    fmt = (
        "'{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
        if obj.microsecond
        else "'{0.hour:02}:{0.minute:02}:{0.second:02}'"
    )
    return fmt.format(obj)
165
166
167
def escape_datetime(obj: datetime.datetime, mapping: Optional[Encoders] = None) -> str:
    """
    Format a datetime as a quoted ``'YYYY-MM-DD HH:MM:SS'`` literal.

    A six-digit fractional part is appended only when ``obj.microsecond`` is
    non-zero.  Any tzinfo is ignored; *mapping* is unused.

    """
    if not obj.microsecond:
        return (
            "'{0.year:04}-{0.month:02}-{0.day:02} "
            "{0.hour:02}:{0.minute:02}:{0.second:02}'"
        ).format(obj)
    return (
        "'{0.year:04}-{0.month:02}-{0.day:02} "
        "{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
    ).format(obj)
175
176
177
def escape_date(obj: datetime.date, mapping: Optional[Encoders] = None) -> str:
    """
    Format a date as a quoted, zero-padded ``'YYYY-MM-DD'`` literal.

    *mapping* is unused.

    """
    return "'{0.year:04}-{0.month:02}-{0.day:02}'".format(obj)
180
181
182
def escape_struct_time(obj: Tuple[Any, ...], mapping: Optional[Encoders] = None) -> str:
    """
    Encode a ``time.struct_time`` (or similar tuple) as a datetime literal.

    Only the first six fields (year through second) are used; they are
    turned into a ``datetime.datetime`` and formatted by escape_datetime.
    *mapping* is not forwarded.

    """
    leading_fields = obj[:6]
    as_datetime = datetime.datetime(*leading_fields)
    return escape_datetime(as_datetime)
184
185
186
def Decimal2Literal(o: Any, d: Any) -> str:
    """
    Render a Decimal in plain positional notation.

    ``format(..., 'f')`` expands exponents, so ``Decimal('1E+3')`` becomes
    ``'1000'``.  The second argument occupies the encoder ``mapping`` slot
    and is unused.

    """
    plain_text = format(o, 'f')
    return plain_text
188
189
190
def through(x: Any) -> Any:
    """Identity converter: hand back *x* unchanged (same object)."""
    return x
192
193
194
# def convert_bit(b):
195
# b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
196
# return struct.unpack(">Q", b)[0]
197
#
198
# the snippet above is right, but MySQLdb doesn't process bits,
199
# so we shouldn't either
200
convert_bit = through
201
202
203
# Default type -> encoder table used by escape_item when no mapping is given.
# Lookup is by exact ``type(value)``; unmatched types fall back to the ``str``
# entry (escape_str), which quotes ``str(value)``.
encoders: Encoders = {
    bool: escape_bool,
    int: escape_int,
    float: escape_float,
    str: escape_str,
    bytes: escape_bytes,
    # bytearray and memoryview expose .hex() like bytes.  Without these
    # entries they would fall back to escape_str and be sent as their repr
    # (e.g. "bytearray(b'...')") instead of as binary data.
    bytearray: escape_bytes,
    memoryview: escape_bytes,
    tuple: escape_sequence,
    list: escape_sequence,
    set: escape_sequence,
    frozenset: escape_sequence,
    dict: escape_dict,
    type(None): escape_None,
    datetime.date: escape_date,
    datetime.datetime: escape_datetime,
    datetime.timedelta: escape_timedelta,
    datetime.time: escape_time,
    time.struct_time: escape_struct_time,
    Decimal: Decimal2Literal,
}
222
223
if has_numpy:

    def escape_numpy(value: Any, mapping: Optional[Encoders] = None) -> str:
        """
        Encode a numpy array as an ``X'...'`` hex literal of its raw bytes.

        The bytes come from ``value.tobytes()``, so the element dtype and
        byte order are whatever the array holds -- callers must ensure they
        match what the target column expects.

        """
        return escape_bytes(value.tobytes(), mapping=mapping)

    encoders[np.ndarray] = escape_numpy
    # numpy booleans are not Python bools, so without this entry they would
    # fall back to escape_str and be sent as the strings 'True'/'False'.
    encoders[np.bool_] = escape_bool
    encoders[np.float16] = escape_float
    encoders[np.float32] = escape_float
    encoders[np.float64] = escape_float
    # float128 only exists on some platforms.
    if hasattr(np, 'float128'):
        encoders[np.float128] = escape_float
    encoders[np.uint] = escape_int
    encoders[np.uint8] = escape_int
    encoders[np.uint16] = escape_int
    encoders[np.uint32] = escape_int
    encoders[np.uint64] = escape_int
    encoders[np.integer] = escape_int
    encoders[np.int_] = escape_int
    encoders[np.int8] = escape_int
    encoders[np.int16] = escape_int
    encoders[np.int32] = escape_int
    encoders[np.int64] = escape_int
246
247
if has_shapely:

    def escape_shapely(value: Any, mapping: Optional[Encoders] = None) -> str:
        """Encode a shapely geometry as a quoted WKT string literal."""
        wkt_text = shapely.wkt.dumps(value)
        return escape_str(wkt_text, mapping=mapping)

    # Register the same encoder for each supported geometry class.
    encoders.update(dict.fromkeys(
        (
            shapely.geometry.Polygon,
            shapely.geometry.Point,
            shapely.geometry.LineString,
        ),
        escape_shapely,
    ))
256
257
if has_pygeos:

    def escape_pygeos(value: Any, mapping: Optional[Encoders] = None) -> str:
        """Encode a pygeos geometry as a quoted WKT string literal."""
        wkt_text = pygeos.io.to_wkt(value)
        return escape_str(wkt_text, mapping=mapping)

    encoders.update({pygeos.Geometry: escape_pygeos})
264
265
266
# MySQLdb compatibility: one combined table holding the encoders defined in
# this module plus the decoders imported from ``..converters`` (decoder
# entries replace encoder entries on key collisions), and the legacy
# ``Thing2Literal`` name for escape_str.
conversions = dict(encoders)  # type: ignore
conversions.update(decoders)  # type: ignore
Thing2Literal = escape_str
270
271
# Run doctests with `pytest --doctest-modules singlestoredb/mysql/converters.py`
272
273