GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/api/client/llm.go

package client

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strings"

	"github.com/kardolus/chatgpt-cli/api"
	"github.com/kardolus/chatgpt-cli/history"
)

const (
	ErrEmptyResponse   = "empty response"
	ErrRealTime        = "model %q requires the Realtime API (WebSocket/WebRTC) and is not supported yet"
	ErrWebSearch       = "model %q is not compatible with the web search feature"
	SearchModelPattern = "-search"
	gptPrefix          = "gpt"
	o1Prefix           = "o1"
	o1ProPattern       = "o1-pro"
	gpt5Pattern        = "gpt-5"
	realTimePattern    = "realtime"
	messageType        = "message"
	outputTextType     = "output_text"
)

// ListModels retrieves a list of all available models from the OpenAI API.
// The models are returned as a slice of strings, each entry representing a model ID.
// Models with an ID starting with "gpt" or "o1" are included.
// The currently active model is marked with an asterisk (*) in the list.
// If an error occurs while retrieving or processing the models, it is
// returned; an empty API response is also reported as an error.
func (c *Client) ListModels() ([]string, error) {
	var result []string

	endpoint := c.getEndpoint(c.Config.ModelsPath)

	c.printRequestDebugInfo(endpoint, nil, nil)

	raw, err := c.Caller.Get(endpoint)
	c.printResponseDebugInfo(raw)

	if err != nil {
		return nil, err
	}

	var response api.ListModelsResponse
	if err := c.processResponse(raw, &response); err != nil {
		return nil, err
	}

	sort.Slice(response.Data, func(i, j int) bool {
		return response.Data[i].Id < response.Data[j].Id
	})

	for _, model := range response.Data {
		if strings.HasPrefix(model.Id, gptPrefix) || strings.HasPrefix(model.Id, o1Prefix) {
			if model.Id != c.Config.Model {
				result = append(result, fmt.Sprintf("- %s", model.Id))
				continue
			}
			result = append(result, fmt.Sprintf("* %s (current)", model.Id))
		}
	}

	return result, nil
}
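
// Usage sketch (illustrative, not part of the original file): assumes a
// constructed *Client in c; it prints the "- model" / "* model (current)"
// lines built by ListModels above.
//
//	models, err := c.ListModels()
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, m := range models {
//		fmt.Println(m) // e.g. "- gpt-4o" or "* o1-mini (current)"
//	}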

// Query sends a query to the API, returning the response as a string along with the token usage.
//
// It constructs a request body from the input and the conversation history and makes a POST API call.
// The context allows for request scoping, timeouts, and cancellation handling.
//
// Depending on the model's capabilities, the payload is decoded as either a
// Responses API or a Completions API response; in both cases the first text
// output is returned together with the total token usage.
//
// Parameters:
// - ctx: A context.Context that controls request cancellation and deadlines.
// - input: The query string to send to the API.
//
// Returns:
// - string: The content of the first response choice or message output from the API.
// - int: The total number of tokens used in the request.
// - error: An error if the request fails or the response is invalid.
func (c *Client) Query(ctx context.Context, input string) (string, int, error) {
	c.prepareQuery(input)

	body, err := c.createBody(ctx, false)
	if err != nil {
		return "", 0, err
	}

	endpoint := c.getChatEndpoint()

	c.printRequestDebugInfo(endpoint, body, nil)

	raw, err := c.Caller.Post(endpoint, body, false)
	c.printResponseDebugInfo(raw)

	if err != nil {
		return "", 0, err
	}

	var (
		response   string
		tokensUsed int
	)

	caps := GetCapabilities(c.Config.Model)

	if caps.UsesResponsesAPI {
		var res api.ResponsesResponse
		if err := c.processResponse(raw, &res); err != nil {
			return "", 0, err
		}
		tokensUsed = res.Usage.TotalTokens

		for _, output := range res.Output {
			if output.Type != messageType {
				continue
			}
			for _, content := range output.Content {
				if content.Type == outputTextType {
					response = content.Text
					break
				}
			}
		}

		if response == "" {
			return "", tokensUsed, errors.New("no response returned")
		}
	} else {
		var res api.CompletionsResponse
		if err := c.processResponse(raw, &res); err != nil {
			return "", 0, err
		}
		tokensUsed = res.Usage.TotalTokens

		if len(res.Choices) == 0 {
			return "", tokensUsed, errors.New("no responses returned")
		}

		var ok bool
		response, ok = res.Choices[0].Message.Content.(string)
		if !ok {
			return "", tokensUsed, errors.New("response cannot be converted to a string")
		}
	}

	c.updateHistory(response)

	return response, tokensUsed, nil
}
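
// Usage sketch (illustrative, not part of the original file): a hypothetical
// caller that bounds the request with a timeout through the ctx parameter
// (assumes a constructed *Client in c and the standard log/time packages).
//
//	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
//	defer cancel()
//
//	reply, tokens, err := c.Query(ctx, "What is the capital of France?")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Printf("%s (%d tokens used)\n", reply, tokens)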

// Stream sends a query to the API and processes the response as a stream.
//
// It constructs a request body from the input and makes a POST API call; the
// context allows for request scoping, timeouts, and cancellation handling.
// The actual processing of the streamed response is handled inside the `Post`
// method; Stream only records the assembled result in the history.
//
// Parameters:
// - ctx: A context.Context that controls request cancellation and deadlines.
// - input: The query string to send to the API.
//
// Returns:
// - error: An error if the request fails or the response is invalid.
func (c *Client) Stream(ctx context.Context, input string) error {
	c.prepareQuery(input)

	body, err := c.createBody(ctx, true)
	if err != nil {
		return err
	}

	endpoint := c.getChatEndpoint()

	c.printRequestDebugInfo(endpoint, body, nil)

	result, err := c.Caller.Post(endpoint, body, true)
	if err != nil {
		return err
	}

	c.updateHistory(string(result))

	return nil
}
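
// Usage sketch (illustrative, not part of the original file): streaming
// mirrors Query, but chunk handling happens inside Caller.Post, so a
// hypothetical caller only inspects the returned error.
//
//	if err := c.Stream(ctx, "Summarize the repository layout"); err != nil {
//		log.Fatal(err)
//	}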

// addQuery appends the user's query to the conversation history and then
// truncates the history via truncateHistory.
func (c *Client) addQuery(query string) {
	message := api.Message{
		Role:    UserRole,
		Content: query,
	}

	c.History = append(c.History, history.History{
		Message:   message,
		Timestamp: c.timer.Now(),
	})
	c.truncateHistory()
}

// createBody builds the JSON request body, choosing between the Responses API
// and the Completions API shape based on the model's capabilities and the
// web-search setting.
func (c *Client) createBody(ctx context.Context, stream bool) ([]byte, error) {
	caps := GetCapabilities(c.Config.Model)

	if caps.IsRealtime {
		return nil, fmt.Errorf(ErrRealTime, c.Config.Model)
	}

	if c.Config.Web && !caps.SupportsWebSearch {
		return nil, fmt.Errorf(ErrWebSearch, c.Config.Model)
	}

	if caps.UsesResponsesAPI || c.Config.Web {
		req, err := c.createResponsesRequest(ctx, stream)
		if err != nil {
			return nil, err
		}
		return json.Marshal(req)
	}

	req, err := c.createCompletionsRequest(ctx, stream)
	if err != nil {
		return nil, err
	}
	return json.Marshal(req)
}
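
// Routing sketch (illustrative, not part of the original file), derived from
// the checks above and the capability rules in GetCapabilities below:
//
//	"gpt-4o"                  -> CompletionsRequest
//	"gpt-5"                   -> ResponsesRequest
//	"gpt-5", Web enabled      -> ResponsesRequest with a web_search tool
//	"gpt-4o", Web enabled     -> error (ErrWebSearch)
//	"gpt-4o-realtime-preview" -> error (ErrRealTime)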

// createCompletionsRequest assembles a Completions API request from the
// history and configuration, skipping the first history entry (the system
// message) for models that do not accept one.
func (c *Client) createCompletionsRequest(ctx context.Context, stream bool) (*api.CompletionsRequest, error) {
	var messages []api.Message
	caps := GetCapabilities(c.Config.Model)

	for index, item := range c.History {
		if caps.OmitFirstSystemMsg && index == 0 {
			continue
		}
		messages = append(messages, item.Message)
	}

	messages, err := c.appendMediaMessages(ctx, messages)
	if err != nil {
		return nil, err
	}

	req := &api.CompletionsRequest{
		Messages:         messages,
		Model:            c.Config.Model,
		MaxTokens:        c.Config.MaxTokens,
		FrequencyPenalty: c.Config.FrequencyPenalty,
		PresencePenalty:  c.Config.PresencePenalty,
		Seed:             c.Config.Seed,
		Stream:           stream,
	}

	if caps.SupportsTemperature {
		req.Temperature = c.Config.Temperature
	}
	if caps.SupportsTopP {
		req.TopP = c.Config.TopP
	}

	return req, nil
}

// createResponsesRequest assembles a Responses API request from the history
// and configuration, attaching a web_search tool when web search is enabled.
func (c *Client) createResponsesRequest(ctx context.Context, stream bool) (*api.ResponsesRequest, error) {
	var messages []api.Message
	caps := GetCapabilities(c.Config.Model)

	for index, item := range c.History {
		if caps.OmitFirstSystemMsg && index == 0 {
			continue
		}
		messages = append(messages, item.Message)
	}

	messages, err := c.appendMediaMessages(ctx, messages)
	if err != nil {
		return nil, err
	}

	req := &api.ResponsesRequest{
		Model:           c.Config.Model,
		Input:           messages,
		MaxOutputTokens: c.Config.MaxTokens,
		Reasoning: api.Reasoning{
			Effort: c.Config.Effort,
		},
		Stream: stream,
	}

	if caps.SupportsTemperature {
		req.Temperature = c.Config.Temperature
	}
	if caps.SupportsTopP {
		req.TopP = c.Config.TopP
	}

	if c.Config.Web {
		req.Tools = append(req.Tools, api.Tool{
			Type:              "web_search",
			SearchContextSize: c.Config.WebContextSize,
		})
	}

	return req, nil
}

// getChatEndpoint returns the Responses or Completions endpoint, depending on
// which API the configured model uses.
func (c *Client) getChatEndpoint() string {
	caps := GetCapabilities(c.Config.Model)

	var endpoint string
	if caps.UsesResponsesAPI {
		endpoint = c.getEndpoint(c.Config.ResponsesPath)
	} else {
		endpoint = c.getEndpoint(c.Config.CompletionsPath)
	}
	return endpoint
}

// getEndpoint joins the configured base URL with the given path.
func (c *Client) getEndpoint(path string) string {
	return c.Config.URL + path
}

// prepareQuery resets or initializes the history as configured, then appends
// the new user query.
func (c *Client) prepareQuery(input string) {
	if c.Config.OmitHistory {
		c.History = nil
		c.addQuery(input)
		return
	}

	c.initHistory()
	c.addQuery(input)
}

// processResponse decodes a raw API response into v, returning an error for
// an empty payload or malformed JSON.
func (c *Client) processResponse(raw []byte, v interface{}) error {
	if raw == nil {
		return errors.New(ErrEmptyResponse)
	}

	if err := json.Unmarshal(raw, v); err != nil {
		return fmt.Errorf("failed to decode response: %w", err)
	}

	return nil
}

// ModelCapabilities describes which request features a given model supports
// and which API it is served by.
type ModelCapabilities struct {
	SupportsTemperature bool
	SupportsTopP        bool
	SupportsStreaming   bool
	SupportsWebSearch   bool
	UsesResponsesAPI    bool
	OmitFirstSystemMsg  bool
	IsRealtime          bool
}

// GetCapabilities derives a model's capabilities from substring patterns in
// its ID.
func GetCapabilities(model string) ModelCapabilities {
	isSearch := strings.Contains(model, SearchModelPattern)
	isGpt5 := strings.Contains(model, gpt5Pattern)

	supportsTemp := !isSearch
	supportsTopP := !isSearch && !isGpt5

	return ModelCapabilities{
		SupportsTemperature: supportsTemp,
		SupportsTopP:        supportsTopP,
		SupportsStreaming:   !strings.Contains(model, o1ProPattern),
		UsesResponsesAPI:    strings.Contains(model, o1ProPattern) || isGpt5,
		OmitFirstSystemMsg:  strings.HasPrefix(model, o1Prefix) && !strings.Contains(model, o1ProPattern),
		IsRealtime:          strings.Contains(model, realTimePattern),
		SupportsWebSearch:   isGpt5 && !isSearch,
	}
}
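
// Capability sketch (illustrative, not part of the original file): what the
// substring checks above yield for a few real model IDs.
//
//	GetCapabilities("gpt-4o")                  // Completions API; temperature, top_p, streaming
//	GetCapabilities("gpt-5")                   // Responses API; web search; temperature but no top_p
//	GetCapabilities("o1-pro")                  // Responses API; no streaming
//	GetCapabilities("o1-mini")                 // Completions API; first system message omitted
//	GetCapabilities("gpt-4o-realtime-preview") // IsRealtime; rejected by createBody
//	GetCapabilities("gpt-4o-search-preview")   // no temperature or top_p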