Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/tests/utils.py
798 views
1
#!/usr/bin/env python
2
# type: ignore
3
"""Utilities for testing."""
4
import glob
5
import logging
6
import os
7
import re
8
import uuid
9
from typing import Any
10
from typing import Dict
11
from typing import List
12
from typing import Tuple
13
from urllib.parse import urlparse
14
15
import singlestoredb as s2
16
from singlestoredb.connection import build_params
17
18
19
# Module-level logger, named after this module, used by the helpers below.
logger = logging.getLogger(__name__)
20
21
22
def apply_template(content: str, vars: Dict[str, Any]) -> str:
    """
    Substitute ``{{NAME}}`` placeholders in a string.

    Parameters
    ----------
    content : str
        Text that may contain ``{{NAME}}`` placeholders
    vars : Dict[str, Any]
        Mapping of placeholder names to replacement values. Values are
        converted with ``str()`` so non-string values (ints, paths, ...)
        can be substituted as well.

    Returns
    -------
    str
        ``content`` with every known placeholder replaced. Placeholders
        without an entry in ``vars`` are left untouched.

    """
    # Replacements are applied one key at a time in mapping order, so a
    # substituted value that itself contains a later placeholder is expanded.
    for k, v in vars.items():
        # ``str.replace`` only accepts strings; coerce so that the
        # ``Dict[str, Any]`` contract actually holds.
        content = content.replace('{{%s}}' % k, str(v))
    return content
28
29
30
def get_server_version(cursor: Any) -> Tuple[int, int]:
    """
    Query the server for its version and reduce it to (major, minor).

    Parameters
    ----------
    cursor : Cursor
        Database cursor used to run the version query

    Returns
    -------
    (int, int)
        Tuple of (major_version, minor_version); the minor version is 0
        when the reported version has no second component

    """
    cursor.execute('SELECT @@memsql_version')
    version_str = cursor.fetchone()[0]

    # Drop any build suffix first ("9.1.2-abc123" -> "9.1.2"), then split
    # the remaining dotted number into its components.
    numeric_part = version_str.partition('-')[0]
    pieces = numeric_part.split('.')

    major = int(pieces[0])
    minor = 0
    if len(pieces) > 1:
        minor = int(pieces[1])

    logger.info(f'Detected server version: {major}.{minor} (full: {version_str})')
    return (major, minor)
52
53
54
def find_version_specific_sql_files(base_dir: str) -> List[Tuple[int, int, str]]:
    """
    Collect the version-specific SQL files found in a directory.

    A file qualifies when its name has the form ``test_X_Y.sql`` where
    ``X`` (major version) and ``Y`` (minor version) are both decimal
    numbers.

    Parameters
    ----------
    base_dir : str
        Directory to search for SQL files

    Returns
    -------
    List[Tuple[int, int, str]]
        List of (major, minor, filepath) tuples sorted by version

    """
    # The glob is only a coarse pre-filter; the regex enforces that both
    # version components are numeric.
    name_re = re.compile(r'test_(\d+)_(\d+)\.sql$')
    found = []

    for filepath in glob.glob(os.path.join(base_dir, 'test_*_*.sql')):
        filename = os.path.basename(filepath)
        match = name_re.match(filename)
        if match is None:
            continue

        major, minor = (int(part) for part in match.groups())
        found.append((major, minor, filepath))
        logger.debug(
            f'Found version-specific SQL file: {filename} '
            f'(v{major}.{minor})',
        )

    # Tuples compare element-wise, so this orders by (major, minor) first.
    return sorted(found)
90
91
92
def load_version_specific_sql(
    cursor: Any,
    base_dir: str,
    server_version: Tuple[int, int],
    template_vars: Dict[str, Any],
) -> None:
    """
    Execute every version-specific SQL file the server is new enough for.

    Files are discovered with ``find_version_specific_sql_files`` and
    processed in ascending version order. A file is executed when the
    server version is greater than or equal to the version encoded in its
    name; otherwise it is skipped. Both outcomes are logged.

    Parameters
    ----------
    cursor : Cursor
        Database cursor to execute queries
    base_dir : str
        Directory containing SQL files
    server_version : Tuple[int, int]
        Server version as (major, minor)
    template_vars : Dict[str, Any]
        Template variables to apply to SQL content

    """
    sql_files = find_version_specific_sql_files(base_dir)
    server_major, server_minor = server_version

    for file_major, file_minor, filepath in sql_files:
        # Tuple comparison orders by major version first, then minor.
        if (server_major, server_minor) < (file_major, file_minor):
            logger.info(
                f'Skipping version-specific SQL: {os.path.basename(filepath)} '
                f'(requires {file_major}.{file_minor}, '
                f'server is {server_major}.{server_minor})',
            )
            continue

        logger.info(
            f'Loading version-specific SQL: {os.path.basename(filepath)} '
            f'(requires {file_major}.{file_minor}, '
            f'server is {server_major}.{server_minor})',
        )
        with open(filepath, 'r') as sql_file:
            script = sql_file.read()

        # Statements are separated by a semicolon at end of line; the
        # separator is stripped by the split and re-added before executing.
        for raw in script.split(';\n'):
            statement = apply_template(raw.strip(), template_vars)
            if statement:
                cursor.execute(statement + ';')
138
139
140
def load_sql(sql_file: str) -> Tuple[str, bool]:
    """
    Load a file containing SQL code.

    If a database name can be determined from the environment, that
    database is used as-is and the statements in ``sql_file`` are NOT
    executed. Otherwise a uniquely named ``TEST_<uuid>`` database is
    created and the statements in ``sql_file`` are executed in it.
    Version-specific SQL files located next to ``sql_file`` are loaded
    afterwards in both cases (best effort).

    Environment variables consulted: ``SINGLESTOREDB_URL``,
    ``SINGLESTOREDB_HOST``, ``SINGLESTOREDB_DATABASE``,
    ``SINGLESTOREDB_INIT_DB_URL`` and ``SINGLESTOREDB_HTTP_PORT``.

    Parameters
    ----------
    sql_file : str
        Name of the SQL file to load.

    Returns
    -------
    (str, bool)
        Name of database created for SQL file and a boolean indicating
        whether the database already existed (meaning that it should not
        be deleted when tests are finished).

    """
    dbname = None

    # Use an existing database name if given. Only the first variable that
    # is present is consulted, even if it yields no database name.
    if 'SINGLESTOREDB_URL' in os.environ:
        dbname = build_params(host=os.environ['SINGLESTOREDB_URL']).get('database')
    elif 'SINGLESTOREDB_HOST' in os.environ:
        dbname = build_params(host=os.environ['SINGLESTOREDB_HOST']).get('database')
    elif 'SINGLESTOREDB_DATABASE' in os.environ:
        # Fixed: this previously read the misspelled key
        # 'SINGLESTOREDB_DATBASE' and raised KeyError.
        dbname = os.environ['SINGLESTOREDB_DATABASE']

    # Use initializer URL if given for setup operations.
    # HTTP can't change databases or execute certain commands like SET GLOBAL,
    # so we always use the MySQL protocol URL for initialization.
    args = {'local_infile': True}
    if 'SINGLESTOREDB_INIT_DB_URL' in os.environ:
        args['host'] = os.environ['SINGLESTOREDB_INIT_DB_URL']
        logger.info(
            f'load_sql: Using SINGLESTOREDB_INIT_DB_URL for setup: '
            f'{os.environ["SINGLESTOREDB_INIT_DB_URL"]}',
        )

    # Work out which port the HTTP API should listen on (0 = do not start
    # it). An explicit SINGLESTOREDB_HTTP_PORT overrides the URL's port.
    http_port = 0
    if 'SINGLESTOREDB_URL' in os.environ:
        url = os.environ['SINGLESTOREDB_URL']
        if url.startswith('http:') or url.startswith('https:'):
            urlp = urlparse(url)
            if urlp.port:
                http_port = urlp.port

    if 'SINGLESTOREDB_HTTP_PORT' in os.environ:
        http_port = int(os.environ['SINGLESTOREDB_HTTP_PORT'])

    # Must be captured before a generated name is assigned below.
    dbexisted = bool(dbname)

    template_vars = dict(DATABASE_NAME=dbname, TEST_PATH=os.path.dirname(sql_file))

    # Always use the default driver since not all operations are
    # permitted in the HTTP API.
    with open(sql_file, 'r') as infile:
        with s2.connect(**args) as conn:
            with conn.cursor() as cur:
                try:
                    cur.execute('SET GLOBAL default_partitions_per_leaf=2')
                    cur.execute('SET GLOBAL log_file_size_partitions=1048576')
                    cur.execute('SET GLOBAL log_file_size_ref_dbs=1048576')
                except s2.OperationalError:
                    # Deliberate best effort: setup continues when the
                    # server rejects these global settings.
                    pass

                if not dbname:
                    dbname = 'TEST_{}'.format(uuid.uuid4()).replace('-', '_')
                    cur.execute(f'CREATE DATABASE {dbname};')
                    cur.execute(f'USE {dbname};')

                    template_vars['DATABASE_NAME'] = dbname

                    # Execute lines in SQL.
                    for cmd in infile.read().split(';\n'):
                        cmd = apply_template(cmd.strip(), template_vars)
                        if cmd:
                            cmd += ';'
                            cur.execute(cmd)

                elif not conn.driver.startswith('http'):
                    cur.execute(f'USE {dbname};')

                # Start HTTP server as needed.
                if http_port and not conn.driver.startswith('http'):
                    cur.execute(f'SET GLOBAL HTTP_PROXY_PORT={http_port};')
                    cur.execute('SET GLOBAL HTTP_API=ON;')
                    cur.execute('RESTART PROXY;')

                # Load version-specific SQL files (e.g., test_9_1.sql for 9.1+)
                # Failures here are logged and do not abort the setup.
                try:
                    server_version = get_server_version(cur)
                    sql_dir = os.path.dirname(sql_file)
                    load_version_specific_sql(
                        cur,
                        sql_dir,
                        server_version,
                        template_vars,
                    )
                except Exception as e:
                    logger.warning(
                        f'Failed to load version-specific SQL files: {e}',
                    )

    return dbname, dbexisted
244
245
246
def drop_database(name: str) -> None:
    """
    Drop the database with the given name.

    Does nothing when ``name`` is empty or None. Connects through
    ``SINGLESTOREDB_INIT_DB_URL`` when that environment variable is set.

    Parameters
    ----------
    name : str
        Name of the database to drop

    """
    if not name:
        return

    init_url = os.environ.get('SINGLESTOREDB_INIT_DB_URL')
    connect_args = {} if init_url is None else {'host': init_url}

    with s2.connect(**connect_args) as conn, conn.cursor() as cur:
        cur.execute(f'DROP DATABASE {name};')
255
256
257
def create_user(name: str, password: str, dbname: str) -> None:
    """
    Create a user with full privileges on the test database.

    Any existing user of the same name is dropped first. Does nothing
    when ``name`` is empty or None. Connects through
    ``SINGLESTOREDB_INIT_DB_URL`` when that environment variable is set.

    Parameters
    ----------
    name : str
        Name of the user to create
    password : str
        Password assigned to the new user
    dbname : str
        Database on which the user is granted all privileges

    """
    if not name:
        return

    init_url = os.environ.get('SINGLESTOREDB_INIT_DB_URL')
    connect_args = {} if init_url is None else {'host': init_url}

    # Executed in order: remove a stale user, recreate it, grant access.
    statements = (
        f'DROP USER IF EXISTS {name};',
        f'CREATE USER "{name}"@"%" IDENTIFIED BY "{password}"',
        f'GRANT ALL ON {dbname}.* to "{name}"@"%"',
    )

    with s2.connect(**connect_args) as conn, conn.cursor() as cur:
        for statement in statements:
            cur.execute(statement)
268
269
270
def drop_user(name: str) -> None:
    """
    Drop the user with the given name if it exists.

    Does nothing when ``name`` is empty or None. Connects through
    ``SINGLESTOREDB_INIT_DB_URL`` when that environment variable is set.

    Parameters
    ----------
    name : str
        Name of the user to drop

    """
    if not name:
        return

    init_url = os.environ.get('SINGLESTOREDB_INIT_DB_URL')
    connect_args = {} if init_url is None else {'host': init_url}

    with s2.connect(**connect_args) as conn, conn.cursor() as cur:
        cur.execute(f'DROP USER IF EXISTS {name};')
279
280