CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
sagemathinc

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.

GitHub Repository: sagemathinc/cocalc
Path: blob/master/src/packages/frontend/account/user-defined-llm.tsx
Views: 687
1
import {
2
Alert,
3
Button,
4
Flex,
5
Form,
6
Input,
7
List,
8
Modal,
9
Popconfirm,
10
Select,
11
Skeleton,
12
Space,
13
Tooltip,
14
} from "antd";
15
import { useWatch } from "antd/es/form/Form";
16
import { sortBy } from "lodash";
17
import { FormattedMessage, useIntl } from "react-intl";
18
19
import {
20
useEffect,
21
useState,
22
useTypedRedux,
23
} from "@cocalc/frontend/app-framework";
24
import {
25
A,
26
HelpIcon,
27
Icon,
28
RawPrompt,
29
Text,
30
Title,
31
} from "@cocalc/frontend/components";
32
import { LanguageModelVendorAvatar } from "@cocalc/frontend/components/language-model-icon";
33
import { webapp_client } from "@cocalc/frontend/webapp-client";
34
import { OTHER_SETTINGS_USERDEFINED_LLM as KEY } from "@cocalc/util/db-schema/defaults";
35
import {
36
LLM_PROVIDER,
37
SERVICES,
38
UserDefinedLLM,
39
UserDefinedLLMService,
40
isLLMServiceName,
41
toUserLLMModelName,
42
} from "@cocalc/util/db-schema/llm-utils";
43
import { trunc, unreachable } from "@cocalc/util/misc";
44
45
interface Props {
  // Persists one account setting: called with the settings key and its new
  // value — here always the user-defined-LLM key and a JSON-encoded list.
  on_change: (name: string, value: any) => void;
}
48
49
export function UserDefinedLLMComponent({ on_change }: Props) {
50
const intl = useIntl();
51
const user_defined_llm = useTypedRedux("customize", "user_defined_llm");
52
const other_settings = useTypedRedux("account", "other_settings");
53
const [form] = Form.useForm();
54
const [editLLM, setEditLLM] = useState<UserDefinedLLM | null>(null);
55
const [tmpLLM, setTmpLLM] = useState<UserDefinedLLM | null>(null);
56
const [loading, setLoading] = useState(false);
57
const [llms, setLLMs] = useState<UserDefinedLLM[]>([]);
58
const [error, setError] = useState<string | null>(null);
59
60
const [needAPIKey, setNeedAPIKey] = useState(false);
61
const [needEndpoint, setNeedEndpoint] = useState(false);
62
63
const service: UserDefinedLLMService = useWatch("service", form);
64
useEffect(() => {
65
const v = service === "custom_openai" || service === "ollama";
66
setNeedAPIKey(!v);
67
setNeedEndpoint(v);
68
}, [service]);
69
70
useEffect(() => {
71
setLoading(true);
72
const val = other_settings?.get(KEY) ?? "[]";
73
try {
74
const data: UserDefinedLLM[] = JSON.parse(val);
75
setLLMs(sortBy(data, "id"));
76
} catch (e) {
77
setError(`Error parsing custom LLMs: ${e}`);
78
setLLMs([]);
79
}
80
setLoading(false);
81
}, [other_settings?.get(KEY)]);
82
83
useEffect(() => {
84
if (editLLM != null) {
85
form.setFieldsValue(editLLM);
86
} else {
87
form.resetFields();
88
}
89
}, [editLLM]);
90
91
function getNextID(): number {
92
let id = 0;
93
llms.forEach((m) => (m.id > id ? (id = m.id) : null));
94
return id + 1;
95
}
96
97
function save(next: UserDefinedLLM, oldID: number) {
98
// trim each field in next
99
for (const key in next) {
100
if (typeof next[key] === "string") {
101
next[key] = next[key].trim();
102
}
103
}
104
// set id if not set
105
next.id ??= getNextID();
106
107
const { service, display, model, endpoint } = next;
108
if (
109
!display ||
110
!model ||
111
(needEndpoint && !endpoint) ||
112
(needAPIKey && !next.apiKey)
113
) {
114
setError("Please fill all fields – click the add button and fix it!");
115
return;
116
}
117
if (!SERVICES.includes(service as any)) {
118
setError(`Invalid service: ${service}`);
119
return;
120
}
121
try {
122
// replace an entry with the same ID, if it exists
123
const newModels = llms.filter((m) => m.id !== oldID);
124
newModels.push(next);
125
on_change(KEY, JSON.stringify(newModels));
126
setEditLLM(null);
127
} catch (err) {
128
setError(`Error saving custom LLM: ${err}`);
129
}
130
}
131
132
function deleteLLM(model: string) {
133
try {
134
const newModels = llms.filter((m) => m.model !== model);
135
on_change(KEY, JSON.stringify(newModels));
136
} catch (err) {
137
setError(`Error deleting custom LLM: ${err}`);
138
}
139
}
140
141
function addLLM() {
142
return (
143
<Button
144
block
145
icon={<Icon name="plus-circle-o" />}
146
onClick={() => {
147
if (!error) {
148
setEditLLM({
149
id: getNextID(),
150
service: "custom_openai",
151
display: "",
152
endpoint: "",
153
model: "",
154
apiKey: "",
155
});
156
} else {
157
setEditLLM(tmpLLM);
158
setError(null);
159
}
160
}}
161
>
162
<FormattedMessage
163
id="account.user-defined-llm.add_button.label"
164
defaultMessage="Add your own Language Model"
165
/>
166
</Button>
167
);
168
}
169
170
async function test(llm: UserDefinedLLM) {
171
setLoading(true);
172
Modal.info({
173
closable: true,
174
title: `Test ${llm.display} (${llm.model})`,
175
content: <TestCustomLLM llm={llm} />,
176
okText: "Close",
177
});
178
setLoading(false);
179
}
180
181
function renderList() {
182
return (
183
<List
184
loading={loading}
185
itemLayout="horizontal"
186
dataSource={llms}
187
renderItem={(item: UserDefinedLLM) => {
188
const { display, model, endpoint, service } = item;
189
if (!isLLMServiceName(service)) return null;
190
191
return (
192
<List.Item
193
actions={[
194
<Button
195
icon={<Icon name="pen" />}
196
type="link"
197
onClick={() => {
198
setEditLLM(item);
199
}}
200
>
201
Edit
202
</Button>,
203
<Popconfirm
204
title={`Are you sure you want to delete the LLM ${display} (${model})?`}
205
onConfirm={() => deleteLLM(model)}
206
okText="Yes"
207
cancelText="No"
208
>
209
<Button icon={<Icon name="trash" />} type="link" danger>
210
Delete
211
</Button>
212
</Popconfirm>,
213
<Button
214
icon={<Icon name="play-circle" />}
215
type="link"
216
onClick={() => test(item)}
217
>
218
Test
219
</Button>,
220
]}
221
>
222
<Skeleton avatar title={false} loading={false} active>
223
<Tooltip
224
title={
225
<>
226
Model: {model}
227
<br />
228
Endpoint: {endpoint}
229
<br />
230
Service: {service}
231
</>
232
}
233
>
234
<List.Item.Meta
235
avatar={
236
<LanguageModelVendorAvatar
237
model={toUserLLMModelName(item)}
238
/>
239
}
240
title={display}
241
/>
242
</Tooltip>
243
</Skeleton>
244
</List.Item>
245
);
246
}}
247
/>
248
);
249
}
250
251
function renderExampleModel() {
252
switch (service) {
253
case "custom_openai":
254
case "openai":
255
return "'gpt-4o'";
256
case "ollama":
257
return "'llama3:latest', 'phi3:instruct', ...";
258
case "anthropic":
259
return "'claude-3-sonnet-20240229'";
260
case "mistralai":
261
return "'open-mixtral-8x22b'";
262
case "google":
263
return "'gemini-1.5-flash'";
264
default:
265
unreachable(service);
266
return "'llama3:latest'";
267
}
268
}
269
270
function renderForm() {
271
if (!editLLM) return null;
272
return (
273
<Modal
274
open={editLLM != null}
275
title="Edit Language Model"
276
onOk={() => {
277
const vals = form.getFieldsValue(true);
278
setTmpLLM(vals);
279
save(vals, editLLM.id);
280
setEditLLM(null);
281
}}
282
onCancel={() => {
283
setEditLLM(null);
284
}}
285
>
286
<Form
287
form={form}
288
layout="horizontal"
289
labelCol={{ span: 8 }}
290
wrapperCol={{ span: 16 }}
291
>
292
<Form.Item
293
label="Display Name"
294
name="display"
295
rules={[{ required: true }]}
296
help="e.g. 'MyLLM'"
297
>
298
<Input />
299
</Form.Item>
300
<Form.Item
301
label="Service"
302
name="service"
303
rules={[{ required: true }]}
304
help="Select the kind of server to talk to. Probably 'OpenAI API' or 'Ollama'"
305
>
306
<Select popupMatchSelectWidth={false}>
307
{SERVICES.map((option) => {
308
const { name, desc } = LLM_PROVIDER[option];
309
return (
310
<Select.Option key={option} value={option}>
311
<Tooltip title={desc} placement="right">
312
<Text strong>{name}</Text>: {trunc(desc, 50)}
313
</Tooltip>
314
</Select.Option>
315
);
316
})}
317
</Select>
318
</Form.Item>
319
<Form.Item
320
label="Model Name"
321
name="model"
322
rules={[{ required: true }]}
323
help={`This depends on the available models. e.g. ${renderExampleModel()}.`}
324
>
325
<Input />
326
</Form.Item>
327
<Form.Item
328
label="Endpoint URL"
329
name="endpoint"
330
rules={[{ required: needEndpoint }]}
331
help={
332
needEndpoint
333
? "e.g. 'https://your.ollama.server:11434/' or 'https://api.openai.com/v1'"
334
: "This setting is ignored."
335
}
336
>
337
<Input disabled={!needEndpoint} />
338
</Form.Item>
339
<Form.Item
340
label="API Key"
341
name="apiKey"
342
help="A secret string, which you got from the service provider."
343
rules={[{ required: needAPIKey }]}
344
>
345
<Input />
346
</Form.Item>
347
</Form>
348
</Modal>
349
);
350
}
351
352
function renderError() {
353
if (!error) return null;
354
return <Alert message={error} type="error" closable />;
355
}
356
357
const title = intl.formatMessage({
358
id: "account.user-defined-llm.title",
359
defaultMessage: "Bring your own Language Model",
360
});
361
362
function renderContent() {
363
if (user_defined_llm) {
364
return (
365
<>
366
{renderForm()}
367
{renderList()}
368
{addLLM()}
369
{renderError()}
370
</>
371
);
372
} else {
373
return <Alert banner type="info" message="This feature is disabled." />;
374
}
375
}
376
377
return (
378
<>
379
<Title level={5}>
380
{title}{" "}
381
<HelpIcon style={{ float: "right" }} maxWidth="300px" title={title}>
382
<FormattedMessage
383
id="account.user-defined-llm.info"
384
defaultMessage={`This allows you to call a {llm} of your own.
385
You either need an API key or run it on your own server.
386
Make sure to click on "Test" to check, that the communication to the API actually works.
387
Most likely, the type you are looking for is "Custom OpenAI" or "Ollama".`}
388
values={{
389
llm: (
390
<A href={"https://en.wikipedia.org/wiki/Large_language_model"}>
391
Large Language Model
392
</A>
393
),
394
}}
395
/>
396
</HelpIcon>
397
</Title>
398
399
{renderContent()}
400
</>
401
);
402
}
403
404
/**
 * Body of the "Test" modal: lets the user send a short prompt to the given
 * user-defined LLM via the backend and shows the streamed reply (or an error).
 */
function TestCustomLLM({ llm }: { llm: UserDefinedLLM }) {
  const [querying, setQuerying] = useState<boolean>(false);
  const [prompt, setPrompt] = useState<string>("Capital of Australia?");
  const [reply, setReply] = useState<string>("");
  const [error, setError] = useState<string>("");

  // Send the current prompt to the backend and stream tokens into `reply`.
  async function doQuery() {
    setQuerying(true);
    setError("");
    setReply("");
    try {
      const llmStream = webapp_client.openai_client.queryStream({
        input: prompt,
        project_id: null,
        tag: "userdefined-llm-test",
        model: toUserLLMModelName(llm),
        system: "This is a test. Reply briefly.",
        maxTokens: 100,
      });

      // accumulate tokens locally (avoid shadowing the `reply` state);
      // a falsy token signals the end of the stream
      let acc = "";
      llmStream.on("token", (token) => {
        if (token) {
          acc += token;
          setReply(acc);
        } else {
          setQuerying(false);
        }
      });

      llmStream.on("error", (err) => {
        // err may be undefined — never pass undefined to the string state
        setError(err?.toString() ?? "Unknown error");
        setQuerying(false);
      });
    } catch (e) {
      // catch variables are `unknown` under strict TS, and non-Error values
      // can be thrown — stringify instead of assuming e.message exists
      setError(`${e}`);
      setReply("");
      setQuerying(false);
    }
  }

  return (
    <Space direction="vertical">
      <Flex vertical={false} align="center" gap={5}>
        <Flex>Prompt: </Flex>
        <Input
          value={prompt}
          onChange={(e) => setPrompt(e.target.value)}
          onPressEnter={doQuery}
        />
        <Button loading={querying} type="primary" onClick={doQuery}>
          Test
        </Button>
      </Flex>
      {reply ? (
        <>
          Reply:
          <RawPrompt input={reply} />
        </>
      ) : null}
      {error ? <Alert banner message={error} type="error" /> : null}
    </Space>
  );
}
469
470