Path: blob/main/singlestoredb/tests/utils.py
798 views
#!/usr/bin/env python
# type: ignore
"""Utilities for testing."""
import glob
import logging
import os
import re
import uuid
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from urllib.parse import urlparse

import singlestoredb as s2
from singlestoredb.connection import build_params


logger = logging.getLogger(__name__)

# File-name pattern for version-specific SQL files, e.g. ``test_9_1.sql``.
_VERSION_SQL_FILE_RE = re.compile(r'test_(\d+)_(\d+)\.sql$')


def _init_db_args() -> Dict[str, Any]:
    """
    Build connection arguments for setup / teardown operations.

    If ``SINGLESTOREDB_INIT_DB_URL`` is set it is passed as ``host``;
    otherwise no host is passed and ``s2.connect`` picks its own default.

    Returns
    -------
    Dict[str, Any]
        Keyword arguments for ``s2.connect(**args)``

    """
    args: Dict[str, Any] = {}
    if 'SINGLESTOREDB_INIT_DB_URL' in os.environ:
        args['host'] = os.environ['SINGLESTOREDB_INIT_DB_URL']
    return args


def apply_template(content: str, vars: Dict[str, Any]) -> str:
    """
    Replace ``{{NAME}}`` placeholders in `content`.

    Parameters
    ----------
    content : str
        Text that may contain ``{{NAME}}`` placeholders
    vars : Dict[str, Any]
        Placeholder name -> replacement value. Values are converted with
        ``str()`` before substitution.

    Returns
    -------
    str
        `content` with every placeholder named in `vars` replaced

    """
    for k, v in vars.items():
        key = '{{%s}}' % k
        if key in content:
            content = content.replace(key, str(v))
    return content


def _execute_sql_script(
    cursor: Any,
    content: str,
    template_vars: Dict[str, Any],
) -> None:
    """
    Execute a script of statements separated by a semicolon + newline.

    Each statement is stripped and template-expanded. Empty statements
    are skipped. A terminating ``;`` is re-appended before execution.

    Parameters
    ----------
    cursor : Cursor
        Database cursor to execute statements with
    content : str
        Full text of the SQL script
    template_vars : Dict[str, Any]
        Template variables to apply to each statement

    """
    for cmd in content.split(';\n'):
        cmd = apply_template(cmd.strip(), template_vars)
        if cmd:
            cursor.execute(cmd + ';')


def get_server_version(cursor: Any) -> Tuple[int, int]:
    """
    Get the server version as a (major, minor) tuple.

    Parameters
    ----------
    cursor : Cursor
        Database cursor to execute queries

    Returns
    -------
    (int, int)
        Tuple of (major_version, minor_version)

    """
    cursor.execute('SELECT @@memsql_version')
    version_str = cursor.fetchone()[0]
    # Parse version string like "9.1.2" or "9.1.2-abc123"
    version_parts = version_str.split('-')[0].split('.')
    major = int(version_parts[0])
    minor = int(version_parts[1]) if len(version_parts) > 1 else 0
    logger.info(f'Detected server version: {major}.{minor} (full: {version_str})')
    return (major, minor)


def find_version_specific_sql_files(base_dir: str) -> List[Tuple[int, int, str]]:
    """
    Find all version-specific SQL files in the given directory.

    Looks for files matching the pattern test_X_Y.sql where X is major
    version and Y is minor version.

    Parameters
    ----------
    base_dir : str
        Directory to search for SQL files

    Returns
    -------
    List[Tuple[int, int, str]]
        List of (major, minor, filepath) tuples sorted by version

    """
    pattern = os.path.join(base_dir, 'test_*_*.sql')
    files: List[Tuple[int, int, str]] = []

    for filepath in glob.glob(pattern):
        filename = os.path.basename(filepath)
        # The glob's ``*`` matches anything, so confirm the name really is
        # test_<digits>_<digits>.sql before parsing the version from it.
        match = _VERSION_SQL_FILE_RE.match(filename)
        if match is None:
            continue
        major, minor = int(match.group(1)), int(match.group(2))
        files.append((major, minor, filepath))
        logger.debug(
            f'Found version-specific SQL file: {filename} '
            f'(v{major}.{minor})',
        )

    # Tuples sort by (major, minor) first; filepath only breaks ties.
    files.sort()
    return files


def load_version_specific_sql(
    cursor: Any,
    base_dir: str,
    server_version: Tuple[int, int],
    template_vars: Dict[str, Any],
) -> None:
    """
    Load version-specific SQL files based on server version.

    A file named test_X_Y.sql is executed only when the server version
    is >= X.Y. Files are processed in ascending version order.

    Parameters
    ----------
    cursor : Cursor
        Database cursor to execute queries
    base_dir : str
        Directory containing SQL files
    server_version : Tuple[int, int]
        Server version as (major, minor)
    template_vars : Dict[str, Any]
        Template variables to apply to SQL content

    """
    server_major, server_minor = server_version
    sql_files = find_version_specific_sql_files(base_dir)

    for file_major, file_minor, filepath in sql_files:
        filename = os.path.basename(filepath)
        if (server_major, server_minor) >= (file_major, file_minor):
            logger.info(
                f'Loading version-specific SQL: {filename} '
                f'(requires {file_major}.{file_minor}, '
                f'server is {server_major}.{server_minor})',
            )
            with open(filepath, 'r') as sql_file:
                _execute_sql_script(cursor, sql_file.read(), template_vars)
        else:
            logger.info(
                f'Skipping version-specific SQL: {filename} '
                f'(requires {file_major}.{file_minor}, '
                f'server is {server_major}.{server_minor})',
            )


def load_sql(sql_file: str) -> Tuple[str, bool]:
    """
    Load a file containing SQL code.

    If no existing database is configured through the environment, a new
    uniquely named ``TEST_<uuid>`` database is created and `sql_file` is
    executed in it. Version-specific SQL files next to `sql_file` are
    then loaded as a best-effort step.

    Parameters
    ----------
    sql_file : str
        Name of the SQL file to load.

    Returns
    -------
    (str, bool)
        Name of database created for SQL file and a boolean indicating
        whether the database already existed (meaning that it should not
        be deleted when tests are finished).

    """
    dbname = None

    # Use an existing database name if given.
    if 'SINGLESTOREDB_URL' in os.environ:
        dbname = build_params(host=os.environ['SINGLESTOREDB_URL']).get('database')
    elif 'SINGLESTOREDB_HOST' in os.environ:
        dbname = build_params(host=os.environ['SINGLESTOREDB_HOST']).get('database')
    elif 'SINGLESTOREDB_DATABASE' in os.environ:
        dbname = os.environ['SINGLESTOREDB_DATABASE']

    # Use initializer URL if given for setup operations.
    # HTTP can't change databases or execute certain commands like SET GLOBAL,
    # so we always use the MySQL protocol URL for initialization.
    args: Dict[str, Any] = {'local_infile': True, **_init_db_args()}
    if 'SINGLESTOREDB_INIT_DB_URL' in os.environ:
        logger.info(
            f'load_sql: Using SINGLESTOREDB_INIT_DB_URL for setup: '
            f'{os.environ["SINGLESTOREDB_INIT_DB_URL"]}',
        )

    # HTTP API port: taken from an http(s) SINGLESTOREDB_URL, overridden
    # by SINGLESTOREDB_HTTP_PORT when set. 0 means "don't start the proxy".
    http_port = 0
    if 'SINGLESTOREDB_URL' in os.environ:
        url = os.environ['SINGLESTOREDB_URL']
        if url.startswith('http:') or url.startswith('https:'):
            urlp = urlparse(url)
            if urlp.port:
                http_port = urlp.port

    if 'SINGLESTOREDB_HTTP_PORT' in os.environ:
        http_port = int(os.environ['SINGLESTOREDB_HTTP_PORT'])

    dbexisted = bool(dbname)

    template_vars = dict(DATABASE_NAME=dbname, TEST_PATH=os.path.dirname(sql_file))

    with open(sql_file, 'r') as infile:
        sql_script = infile.read()

    # Always use the default driver since not all operations are
    # permitted in the HTTP API.
    with s2.connect(**args) as conn:
        with conn.cursor() as cur:
            # Best-effort server tuning; presumably these may be rejected
            # on some servers/accounts, which is deliberately not fatal.
            try:
                cur.execute('SET GLOBAL default_partitions_per_leaf=2')
                cur.execute('SET GLOBAL log_file_size_partitions=1048576')
                cur.execute('SET GLOBAL log_file_size_ref_dbs=1048576')
            except s2.OperationalError:
                pass

            if not dbname:
                dbname = 'TEST_{}'.format(uuid.uuid4()).replace('-', '_')
                cur.execute(f'CREATE DATABASE {dbname};')
                cur.execute(f'USE {dbname};')

                template_vars['DATABASE_NAME'] = dbname

                # Execute statements in SQL.
                _execute_sql_script(cur, sql_script, template_vars)

            elif not conn.driver.startswith('http'):
                cur.execute(f'USE {dbname};')

            # Start HTTP server as needed.
            if http_port and not conn.driver.startswith('http'):
                cur.execute(f'SET GLOBAL HTTP_PROXY_PORT={http_port};')
                cur.execute('SET GLOBAL HTTP_API=ON;')
                cur.execute('RESTART PROXY;')

            # Load version-specific SQL files (e.g., test_9_1.sql for 9.1+).
            # Failures are logged rather than raised so the base fixture
            # remains usable on servers where these files don't apply.
            try:
                server_version = get_server_version(cur)
                sql_dir = os.path.dirname(sql_file)
                load_version_specific_sql(
                    cur,
                    sql_dir,
                    server_version,
                    template_vars,
                )
            except Exception as e:
                logger.warning(
                    f'Failed to load version-specific SQL files: {e}',
                )

    return dbname, dbexisted


def drop_database(name: str) -> None:
    """Drop the database `name`; does nothing if `name` is empty."""
    if not name:
        return
    with s2.connect(**_init_db_args()) as conn:
        with conn.cursor() as cur:
            cur.execute(f'DROP DATABASE {name};')


def create_user(name: str, password: str, dbname: str) -> None:
    """
    Create a user for the test database.

    Any existing user with the same name is dropped first. The new user
    is granted ALL privileges on `dbname`. Does nothing if `name` is empty.

    Parameters
    ----------
    name : str
        User name
    password : str
        Password for the user
    dbname : str
        Database the user is granted access to

    """
    if not name:
        return
    with s2.connect(**_init_db_args()) as conn:
        with conn.cursor() as cur:
            cur.execute(f'DROP USER IF EXISTS {name};')
            cur.execute(f'CREATE USER "{name}"@"%" IDENTIFIED BY "{password}"')
            cur.execute(f'GRANT ALL ON {dbname}.* to "{name}"@"%"')


def drop_user(name: str) -> None:
    """Drop the user `name` if it exists; does nothing if `name` is empty."""
    if not name:
        return
    with s2.connect(**_init_db_args()) as conn:
        with conn.cursor() as cur:
            cur.execute(f'DROP USER IF EXISTS {name};')