Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
singlestore-labs
GitHub Repository: singlestore-labs/singlestoredb-python
Path: blob/main/singlestoredb/ai/chat.py
801 views
1
import os
2
from typing import Any
3
from typing import Callable
4
from typing import Optional
5
from typing import Union
6
7
import httpx
8
9
from singlestoredb import manage_workspaces
10
from singlestoredb.management.inference_api import InferenceAPIInfo
11
12
# Optional third-party chat integrations: fail fast at import time with an
# actionable install hint rather than a bare ModuleNotFoundError later.
try:
    from langchain_openai import ChatOpenAI
except ImportError:
    raise ImportError(
        'Could not import langchain_openai python package. '
        'Please install it with `pip install langchain_openai`.',
    )

try:
    from langchain_aws import ChatBedrockConverse
except ImportError:
    raise ImportError(
        'Could not import langchain-aws python package. '
        'Please install it with `pip install langchain-aws`.',
    )

# boto3/botocore are used below to build the raw Bedrock runtime client
# (UNSIGNED disables SigV4 signing; Config carries retry/timeout settings).
import boto3
from botocore import UNSIGNED
from botocore.config import Config
31
32
33
def SingleStoreChatFactory(
    model_name: str,
    api_key: Optional[str] = None,
    streaming: bool = True,
    http_client: Optional[httpx.Client] = None,
    obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
    base_url: Optional[str] = None,
    hosting_platform: Optional[str] = None,
    **kwargs: Any,
) -> Union[ChatOpenAI, ChatBedrockConverse]:
    """Return a chat model instance (ChatOpenAI or ChatBedrockConverse).

    The endpoint URL and hosting platform are taken from the explicit
    arguments first, then from the ``SINGLESTOREDB_INFERENCE_API_BASE_URL``
    / ``SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM`` environment
    variables, and finally looked up through the management API.

    Parameters
    ----------
    model_name : str
        Name of the model registered with the inference API.
    api_key : str, optional
        Bearer token; when ``None``, the ``SINGLESTOREDB_USER_TOKEN``
        environment variable is used at request time.
    streaming : bool, optional
        Whether the returned model streams responses.
    http_client : httpx.Client, optional
        Passed through to ChatOpenAI; for Bedrock only its timeout
        configuration is honored.
    obo_token_getter : callable, optional
        Zero-argument callable returning an on-behalf-of token to send
        as the ``X-S2-OBO`` header (Bedrock path only).
    base_url : str, optional
        Override for the inference endpoint URL.
    hosting_platform : str, optional
        Override for the hosting platform (``'Amazon'`` selects Bedrock).
    **kwargs : Any
        Forwarded to the underlying chat model constructor.

    Returns
    -------
    ChatOpenAI or ChatBedrockConverse
    """
    info = _resolve_inference_info(model_name, base_url, hosting_platform)
    connect_timeout, read_timeout = _extract_timeouts(http_client)

    if info.hosting_platform == 'Amazon':
        return _build_bedrock_chat(
            model_name,
            info,
            api_key,
            obo_token_getter,
            streaming,
            connect_timeout,
            read_timeout,
            **kwargs,
        )

    # OpenAI / Azure OpenAI path
    return _build_openai_chat(
        model_name,
        info,
        api_key,
        streaming,
        http_client,
        **kwargs,
    )


def _resolve_inference_info(
    model_name: str,
    base_url: Optional[str],
    hosting_platform: Optional[str],
) -> 'InferenceAPIInfo':
    """Resolve endpoint info for *model_name*.

    Explicit arguments take precedence over environment variables; if
    either the URL or the platform is still unknown, the management API
    is queried for the model's registered inference endpoint.
    """
    if base_url is None:
        base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
    if hosting_platform is None:
        hosting_platform = os.environ.get(
            'SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM',
        )

    if base_url is None or hosting_platform is None:
        # At least one piece is missing: ask the management API.
        inference_api_manager = (
            manage_workspaces().organizations.current.inference_apis
        )
        info = inference_api_manager.get(model_name=model_name)
        if not info.internal_connection_url:
            # Fall back to the public URL when no internal one is set.
            info.internal_connection_url = info.connection_url
    else:
        # Everything needed was supplied; build the record locally
        # without a management-API round trip.
        info = InferenceAPIInfo(
            service_id='',
            model_name=model_name,
            name='',
            connection_url=base_url,
            internal_connection_url=base_url,
            project_id='',
            hosting_platform=hosting_platform,
        )

    # Explicit/env overrides always win over the API lookup result.
    if base_url is not None:
        info.connection_url = base_url
        info.internal_connection_url = base_url
    if hosting_platform is not None:
        info.hosting_platform = hosting_platform
    return info


def _extract_timeouts(
    http_client: 'Optional[httpx.Client]',
) -> 'tuple[Optional[float], Optional[float]]':
    """Return ``(connect_timeout, read_timeout)`` derived from *http_client*.

    Either value may be ``None`` when no timeout information is
    available.  When only one of the two is configured on an
    ``httpx.Timeout``, the other is mirrored from it so both transports
    behave consistently.
    """
    timeout = http_client.timeout if http_client is not None else None
    connect_timeout: Optional[float] = None
    read_timeout: Optional[float] = None
    if isinstance(timeout, httpx.Timeout):
        if timeout.connect is not None:
            connect_timeout = float(timeout.connect)
        if timeout.read is not None:
            read_timeout = float(timeout.read)
        if connect_timeout is None and read_timeout is not None:
            connect_timeout = read_timeout
        if read_timeout is None and connect_timeout is not None:
            read_timeout = connect_timeout
    elif isinstance(timeout, (int, float)):
        # A bare number applies to both phases.
        connect_timeout = float(timeout)
        read_timeout = float(timeout)
    return connect_timeout, read_timeout


def _build_bedrock_chat(
    model_name: str,
    info: 'InferenceAPIInfo',
    api_key: Optional[str],
    obo_token_getter: Optional[Callable[[], Optional[str]]],
    streaming: bool,
    connect_timeout: Optional[float],
    read_timeout: Optional[float],
    **kwargs: Any,
) -> 'ChatBedrockConverse':
    """Build a ``ChatBedrockConverse`` that talks to the SingleStore proxy.

    Requests are not SigV4-signed; instead a bearer token (and optional
    on-behalf-of token) is injected into every outgoing request just
    before it is sent.
    """
    cfg_kwargs: 'dict[str, Any]' = {
        # The SingleStore proxy authenticates via bearer token, not SigV4.
        'signature_version': UNSIGNED,
        'retries': {'max_attempts': 1, 'mode': 'standard'},
    }
    if read_timeout is not None:
        cfg_kwargs['read_timeout'] = read_timeout
    if connect_timeout is not None:
        cfg_kwargs['connect_timeout'] = connect_timeout

    client = boto3.client(
        'bedrock-runtime',
        endpoint_url=info.internal_connection_url,
        # Region/credentials are required by boto3 but unused by the
        # proxy since requests are unsigned.
        region_name='us-east-1',
        aws_access_key_id='placeholder',
        aws_secret_access_key='placeholder',
        config=Config(**cfg_kwargs),
    )

    def _inject_headers(request: Any, **_ignored: Any) -> None:
        """Inject dynamic auth/OBO headers prior to Bedrock sending."""
        # Resolve the token at request time so env-var changes and
        # rotated tokens are picked up without rebuilding the client.
        token_env_val = os.environ.get('SINGLESTOREDB_USER_TOKEN')
        token_val = api_key if api_key is not None else token_env_val
        if token_val:
            request.headers['Authorization'] = f'Bearer {token_val}'
        if obo_token_getter is not None:
            obo_val = obo_token_getter()
            if obo_val:
                request.headers['X-S2-OBO'] = obo_val
        # Strip SigV4 artifacts the proxy does not expect.
        request.headers.pop('X-Amz-Date', None)
        request.headers.pop('X-Amz-Security-Token', None)

    # Hook every Bedrock runtime operation this model may invoke.
    # NOTE(review): relies on botocore private attributes
    # (client._endpoint._event_emitter); re-verify on botocore upgrades.
    emitter = client._endpoint._event_emitter
    for operation in (
        'Converse',
        'ConverseStream',
        'InvokeModel',
        'InvokeModelWithResponseStream',
    ):
        emitter.register_first(
            f'before-send.bedrock-runtime.{operation}',
            _inject_headers,
        )

    return ChatBedrockConverse(
        model_id=model_name,
        endpoint_url=info.internal_connection_url,
        region_name='us-east-1',
        aws_access_key_id='placeholder',
        aws_secret_access_key='placeholder',
        disable_streaming=not streaming,
        client=client,
        **kwargs,
    )


def _build_openai_chat(
    model_name: str,
    info: 'InferenceAPIInfo',
    api_key: Optional[str],
    streaming: bool,
    http_client: 'Optional[httpx.Client]',
    **kwargs: Any,
) -> 'ChatOpenAI':
    """Build a ``ChatOpenAI`` pointed at the SingleStore inference endpoint."""
    # Explicit key wins; otherwise fall back to the user-token env var.
    token = (
        api_key
        if api_key is not None
        else os.environ.get('SINGLESTOREDB_USER_TOKEN')
    )
    openai_kwargs: 'dict[str, Any]' = dict(
        base_url=info.internal_connection_url,
        api_key=token,
        model=model_name,
        streaming=streaming,
    )
    if http_client is not None:
        openai_kwargs['http_client'] = http_client
    return ChatOpenAI(
        **openai_kwargs,
        **kwargs,
    )
170
171