Path: blob/main/singlestoredb/ai/chat.py
import os
from typing import Any
from typing import Callable
from typing import Optional
from typing import Union

import httpx

from singlestoredb import manage_workspaces
from singlestoredb.management.inference_api import InferenceAPIInfo

try:
    from langchain_openai import ChatOpenAI
except ImportError:
    raise ImportError(
        'Could not import langchain_openai python package. '
        'Please install it with `pip install langchain_openai`.',
    )

try:
    from langchain_aws import ChatBedrockConverse
except ImportError:
    raise ImportError(
        'Could not import langchain-aws python package. '
        'Please install it with `pip install langchain-aws`.',
    )

import boto3
from botocore import UNSIGNED
from botocore.config import Config


def SingleStoreChatFactory(
    model_name: str,
    api_key: Optional[str] = None,
    streaming: bool = True,
    http_client: Optional[httpx.Client] = None,
    obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
    base_url: Optional[str] = None,
    hosting_platform: Optional[str] = None,
    **kwargs: Any,
) -> Union[ChatOpenAI, ChatBedrockConverse]:
    """Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
    """
    # handle model info
    if base_url is None:
        base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
    if hosting_platform is None:
        hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
    if base_url is None or hosting_platform is None:
        inference_api_manager = (
            manage_workspaces().organizations.current.inference_apis
        )
        info = inference_api_manager.get(model_name=model_name)
        if not info.internal_connection_url:
            info.internal_connection_url = info.connection_url
    else:
        info = InferenceAPIInfo(
            service_id='',
            model_name=model_name,
            name='',
            connection_url=base_url,
            internal_connection_url=base_url,
            project_id='',
            hosting_platform=hosting_platform,
        )
    if base_url is not None:
        info.connection_url = base_url
        info.internal_connection_url = base_url
    if hosting_platform is not None:
        info.hosting_platform = hosting_platform

    # Extract timeouts from http_client if provided
    t = http_client.timeout if http_client is not None else None
    connect_timeout = None
    read_timeout = None
    if t is not None:
        if isinstance(t, httpx.Timeout):
            if t.connect is not None:
                connect_timeout = float(t.connect)
            if t.read is not None:
                read_timeout = float(t.read)
            if connect_timeout is None and read_timeout is not None:
                connect_timeout = read_timeout
            if read_timeout is None and connect_timeout is not None:
                read_timeout = connect_timeout
        elif isinstance(t, (int, float)):
            connect_timeout = float(t)
            read_timeout = float(t)

    if info.hosting_platform == 'Amazon':
        # Instantiate Bedrock client
        cfg_kwargs = {
            'signature_version': UNSIGNED,
            'retries': {'max_attempts': 1, 'mode': 'standard'},
        }
        if read_timeout is not None:
            cfg_kwargs['read_timeout'] = read_timeout
        if connect_timeout is not None:
            cfg_kwargs['connect_timeout'] = connect_timeout

        cfg = Config(**cfg_kwargs)
        client = boto3.client(
            'bedrock-runtime',
            endpoint_url=info.internal_connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            config=cfg,
        )

        def _inject_headers(request: Any, **_ignored: Any) -> None:
            """Inject dynamic auth/OBO headers prior to Bedrock sending."""
            token_env_val = os.environ.get('SINGLESTOREDB_USER_TOKEN')
            token_val = api_key if api_key is not None else token_env_val
            if token_val:
                request.headers['Authorization'] = f'Bearer {token_val}'
            if obo_token_getter is not None:
                obo_val = obo_token_getter()
                if obo_val:
                    request.headers['X-S2-OBO'] = obo_val
            request.headers.pop('X-Amz-Date', None)
            request.headers.pop('X-Amz-Security-Token', None)

        emitter = client._endpoint._event_emitter
        emitter.register_first(
            'before-send.bedrock-runtime.Converse',
            _inject_headers,
        )
        emitter.register_first(
            'before-send.bedrock-runtime.ConverseStream',
            _inject_headers,
        )
        emitter.register_first(
            'before-send.bedrock-runtime.InvokeModel',
            _inject_headers,
        )
        emitter.register_first(
            'before-send.bedrock-runtime.InvokeModelWithResponseStream',
            _inject_headers,
        )

        return ChatBedrockConverse(
            model_id=model_name,
            endpoint_url=info.internal_connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            disable_streaming=not streaming,
            client=client,
            **kwargs,
        )

    # OpenAI / Azure OpenAI path
    token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
    token = api_key if api_key is not None else token_env

    openai_kwargs = dict(
        base_url=info.internal_connection_url,
        api_key=token,
        model=model_name,
        streaming=streaming,
    )
    if http_client is not None:
        openai_kwargs['http_client'] = http_client
    return ChatOpenAI(
        **openai_kwargs,
        **kwargs,
    )
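Example usage (a minimal sketch, not part of the module above): the model name is hypothetical, and the call assumes either that the model is registered with the current organization's inference APIs or that SINGLESTOREDB_INFERENCE_API_BASE_URL and SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM are set in the environment.

from singlestoredb.ai.chat import SingleStoreChatFactory

chat = SingleStoreChatFactory(
    model_name='gpt-4o',  # hypothetical model name
    streaming=True,
)

# Both ChatOpenAI and ChatBedrockConverse implement the LangChain Runnable
# interface, so invoke() works regardless of which class is returned.
response = chat.invoke('Say hello in one sentence.')
print(response.content)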