GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/ai/embeddings.py
import os
from typing import Any
from typing import Callable
from typing import Optional
from typing import Union

import httpx

from singlestoredb import manage_workspaces
from singlestoredb.management.inference_api import InferenceAPIInfo

try:
    from langchain_openai import OpenAIEmbeddings
except ImportError:
    raise ImportError(
        'Could not import langchain_openai python package. '
        'Please install it with `pip install langchain_openai`.',
    )

try:
    from langchain_aws import BedrockEmbeddings
except ImportError:
    raise ImportError(
        'Could not import langchain-aws python package. '
        'Please install it with `pip install langchain-aws`.',
    )

import boto3
from botocore import UNSIGNED
from botocore.config import Config


def SingleStoreEmbeddingsFactory(
    model_name: str,
    api_key: Optional[str] = None,
    http_client: Optional[httpx.Client] = None,
    obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
    base_url: Optional[str] = None,
    hosting_platform: Optional[str] = None,
    **kwargs: Any,
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
    """Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings)."""
    # handle model info
    if base_url is None:
        base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
    if hosting_platform is None:
        hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
    if base_url is None or hosting_platform is None:
        inference_api_manager = (
            manage_workspaces().organizations.current.inference_apis
        )
        info = inference_api_manager.get(model_name=model_name)
        if not info.internal_connection_url:
            info.internal_connection_url = info.connection_url
    else:
        info = InferenceAPIInfo(
            service_id='',
            model_name=model_name,
            name='',
            connection_url=base_url,
            internal_connection_url=base_url,
            project_id='',
            hosting_platform=hosting_platform,
        )
    if base_url is not None:
        info.connection_url = base_url
        info.internal_connection_url = base_url
    if hosting_platform is not None:
        info.hosting_platform = hosting_platform

    # Extract timeouts from http_client if provided
    t = http_client.timeout if http_client is not None else None
    connect_timeout = None
    read_timeout = None
    if t is not None:
        if isinstance(t, httpx.Timeout):
            if t.connect is not None:
                connect_timeout = float(t.connect)
            if t.read is not None:
                read_timeout = float(t.read)
            if connect_timeout is None and read_timeout is not None:
                connect_timeout = read_timeout
            if read_timeout is None and connect_timeout is not None:
                read_timeout = connect_timeout
        elif isinstance(t, (int, float)):
            connect_timeout = float(t)
            read_timeout = float(t)

    if info.hosting_platform == 'Amazon':
        # Instantiate Bedrock client
        cfg_kwargs = {
            'signature_version': UNSIGNED,
            'retries': {'max_attempts': 1, 'mode': 'standard'},
        }
        if read_timeout is not None:
            cfg_kwargs['read_timeout'] = read_timeout
        if connect_timeout is not None:
            cfg_kwargs['connect_timeout'] = connect_timeout

        cfg = Config(**cfg_kwargs)
        client = boto3.client(
            'bedrock-runtime',
            endpoint_url=info.internal_connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            config=cfg,
        )

        def _inject_headers(request: Any, **_ignored: Any) -> None:
            """Inject dynamic auth/OBO headers prior to Bedrock sending."""
            token_env_val = os.environ.get('SINGLESTOREDB_USER_TOKEN')
            token_val = api_key if api_key is not None else token_env_val
            if token_val:
                request.headers['Authorization'] = f'Bearer {token_val}'
            if obo_token_getter is not None:
                obo_val = obo_token_getter()
                if obo_val:
                    request.headers['X-S2-OBO'] = obo_val
            request.headers.pop('X-Amz-Date', None)
            request.headers.pop('X-Amz-Security-Token', None)

        emitter = client._endpoint._event_emitter
        emitter.register_first(
            'before-send.bedrock-runtime.InvokeModel',
            _inject_headers,
        )
        emitter.register_first(
            'before-send.bedrock-runtime.InvokeModelWithResponseStream',
            _inject_headers,
        )

        return BedrockEmbeddings(
            model_id=model_name,
            endpoint_url=info.internal_connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            client=client,
            **kwargs,
        )

    # OpenAI / Azure OpenAI path
    token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
    token = api_key if api_key is not None else token_env

    openai_kwargs = dict(
        base_url=info.internal_connection_url,
        api_key=token,
        model=model_name,
    )
    if http_client is not None:
        openai_kwargs['http_client'] = http_client
    return OpenAIEmbeddings(
        **openai_kwargs,
        **kwargs,
    )
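

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one way the
# factory above might be called. The model names and endpoint URL below are
# hypothetical placeholders; in practice the bearer token comes from the
# `api_key` argument or the SINGLESTOREDB_USER_TOKEN environment variable,
# and base_url/hosting_platform can also be supplied via the
# SINGLESTOREDB_INFERENCE_API_* environment variables.
#
# OpenAI-compatible path (any hosting platform other than 'Amazon' falls
# through to OpenAIEmbeddings):
#
#     embeddings = SingleStoreEmbeddingsFactory(
#         model_name='text-embedding-3-small',          # hypothetical model
#         base_url='https://inference.example.com/v1',  # hypothetical endpoint
#         hosting_platform='OpenAI',
#     )
#     vectors = embeddings.embed_documents(['hello', 'world'])
#
# Bedrock path (hosting_platform == 'Amazon' returns BedrockEmbeddings backed
# by the unsigned boto3 client and the header-injection hook defined above):
#
#     embeddings = SingleStoreEmbeddingsFactory(
#         model_name='amazon.titan-embed-text-v2:0',    # hypothetical model
#         base_url='https://inference.example.com',     # hypothetical endpoint
#         hosting_platform='Amazon',
#     )
#     vector = embeddings.embed_query('hello world')
# ---------------------------------------------------------------------------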