Path: blob/main/singlestoredb/ai/embeddings.py
import os
from typing import Any
from typing import Callable
from typing import Optional
from typing import Union

import httpx

from singlestoredb import manage_workspaces
from singlestoredb.management.inference_api import InferenceAPIInfo

try:
    from langchain_openai import OpenAIEmbeddings
except ImportError:
    raise ImportError(
        'Could not import langchain_openai python package. '
        'Please install it with `pip install langchain_openai`.',
    )

try:
    from langchain_aws import BedrockEmbeddings
except ImportError:
    raise ImportError(
        'Could not import langchain-aws python package. '
        'Please install it with `pip install langchain-aws`.',
    )

import boto3
from botocore import UNSIGNED
from botocore.config import Config


def SingleStoreEmbeddingsFactory(
    model_name: str,
    api_key: Optional[str] = None,
    http_client: Optional[httpx.Client] = None,
    obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
    base_url: Optional[str] = None,
    hosting_platform: Optional[str] = None,
    **kwargs: Any,
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
    """Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings)."""
    # handle model info
    if base_url is None:
        base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
    if hosting_platform is None:
        hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
    if base_url is None or hosting_platform is None:
        inference_api_manager = (
            manage_workspaces().organizations.current.inference_apis
        )
        info = inference_api_manager.get(model_name=model_name)
        if not info.internal_connection_url:
            info.internal_connection_url = info.connection_url
    else:
        info = InferenceAPIInfo(
            service_id='',
            model_name=model_name,
            name='',
            connection_url=base_url,
            internal_connection_url=base_url,
            project_id='',
            hosting_platform=hosting_platform,
        )
    if base_url is not None:
        info.connection_url = base_url
        info.internal_connection_url = base_url
    if hosting_platform is not None:
        info.hosting_platform = hosting_platform

    # Extract timeouts from http_client if provided
    t = http_client.timeout if http_client is not None else None
    connect_timeout = None
    read_timeout = None
    if t is not None:
        if isinstance(t, httpx.Timeout):
            if t.connect is not None:
                connect_timeout = float(t.connect)
            if t.read is not None:
                read_timeout = float(t.read)
            if connect_timeout is None and read_timeout is not None:
                connect_timeout = read_timeout
            if read_timeout is None and connect_timeout is not None:
                read_timeout = connect_timeout
        elif isinstance(t, (int, float)):
            connect_timeout = float(t)
            read_timeout = float(t)

    if info.hosting_platform == 'Amazon':
        # Instantiate Bedrock client
        cfg_kwargs = {
            'signature_version': UNSIGNED,
            'retries': {'max_attempts': 1, 'mode': 'standard'},
        }
        if read_timeout is not None:
            cfg_kwargs['read_timeout'] = read_timeout
        if connect_timeout is not None:
            cfg_kwargs['connect_timeout'] = connect_timeout

        cfg = Config(**cfg_kwargs)
        client = boto3.client(
            'bedrock-runtime',
            endpoint_url=info.internal_connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            config=cfg,
        )

        def _inject_headers(request: Any, **_ignored: Any) -> None:
            """Inject dynamic auth/OBO headers prior to Bedrock sending."""
            token_env_val = os.environ.get('SINGLESTOREDB_USER_TOKEN')
            token_val = api_key if api_key is not None else token_env_val
            if token_val:
                request.headers['Authorization'] = f'Bearer {token_val}'
            if obo_token_getter is not None:
                obo_val = obo_token_getter()
                if obo_val:
                    request.headers['X-S2-OBO'] = obo_val
            request.headers.pop('X-Amz-Date', None)
            request.headers.pop('X-Amz-Security-Token', None)

        emitter = client._endpoint._event_emitter
        emitter.register_first(
            'before-send.bedrock-runtime.InvokeModel',
            _inject_headers,
        )
        emitter.register_first(
            'before-send.bedrock-runtime.InvokeModelWithResponseStream',
            _inject_headers,
        )

        return BedrockEmbeddings(
            model_id=model_name,
            endpoint_url=info.internal_connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            client=client,
            **kwargs,
        )

    # OpenAI / Azure OpenAI path
    token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
    token = api_key if api_key is not None else token_env

    openai_kwargs = dict(
        base_url=info.internal_connection_url,
        api_key=token,
        model=model_name,
    )
    if http_client is not None:
        openai_kwargs['http_client'] = http_client
    return OpenAIEmbeddings(
        **openai_kwargs,
        **kwargs,
    )
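# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). The model name,
# token value, and timeout settings below are assumptions for the example;
# substitute whatever your deployment actually exposes. The factory returns a
# standard LangChain embeddings object, so `embed_query` / `embed_documents`
# are called the same way as on any other LangChain embeddings model.
#
#     os.environ['SINGLESTOREDB_USER_TOKEN'] = '<api-token>'   # assumed token source
#     embeddings = SingleStoreEmbeddingsFactory(
#         'text-embedding-3-small',                            # hypothetical model name
#         http_client=httpx.Client(
#             timeout=httpx.Timeout(30.0, connect=5.0),        # read/connect timeouts
#         ),
#     )
#     vector = embeddings.embed_query('hello world')
#     vectors = embeddings.embed_documents(['doc one', 'doc two'])
# ---------------------------------------------------------------------------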