Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/api/client/client.go
2649 views
1
package client
2
3
import (
4
"bytes"
5
"context"
6
"encoding/base64"
7
"encoding/json"
8
"errors"
9
"fmt"
10
"github.com/kardolus/chatgpt-cli/api"
11
"github.com/kardolus/chatgpt-cli/api/http"
12
"github.com/kardolus/chatgpt-cli/cmd/chatgpt/utils"
13
"github.com/kardolus/chatgpt-cli/config"
14
"github.com/kardolus/chatgpt-cli/internal"
15
"go.uber.org/zap"
16
"golang.org/x/text/cases"
17
"golang.org/x/text/language"
18
"io"
19
"mime/multipart"
20
"net/textproto"
21
"net/url"
22
"os"
23
"path/filepath"
24
"sort"
25
"strings"
26
"time"
27
"unicode/utf8"
28
29
"github.com/kardolus/chatgpt-cli/history"
30
stdhttp "net/http"
31
)
32
33
// Exported roles, error message templates, limits, and Apify endpoint pieces.
const (
	// Chat message roles written into the history.
	AssistantRole = "assistant"
	SystemRole    = "system"
	UserRole      = "user"
	FunctionRole  = "function"

	// Error messages; ErrMissingMCPAPIKey is a format string taking the provider.
	ErrEmptyResponse       = "empty response"
	ErrMissingMCPAPIKey    = "the %s api key is not configured"
	ErrUnsupportedProvider = "unsupported MCP provider"
	ErrHistoryTracking     = "history tracking needs to be enabled to use this feature"

	// MaxTokenBufferPercentage is the percentage of the context window held back
	// when deciding whether the history has to be truncated.
	MaxTokenBufferPercentage = 20

	// InteractiveThreadPrefix prefixes auto-generated interactive thread names.
	InteractiveThreadPrefix = "int_"

	// SearchModelPattern marks model names for which temperature/top_p are not sent.
	SearchModelPattern = "-search"

	// Apify actor endpoint: ApifyURL + <function> + ApifyPath.
	ApifyURL         = "https://api.apify.com/v2/acts/"
	ApifyPath        = "/run-sync-get-dataset-items"
	ApifyProxyConfig = "proxyConfiguration"
)

// Unexported model-name patterns, payload type tags, URL schemes, and sizes.
const (
	// Model-name patterns used by ListModels and GetCapabilities.
	gptPrefix    = "gpt"
	o1Prefix     = "o1"
	o1ProPattern = "o1-pro"
	gpt5Pattern  = "gpt-5"

	// Content/item type tags in request and response payloads.
	audioType      = "input_audio"
	imageURLType   = "image_url"
	messageType    = "message"
	outputTextType = "output_text"

	// imageContent formats a data URL from a MIME type and base64 payload.
	imageContent = "data:%s;base64,%s"

	// Schemes accepted by isValidURL.
	httpScheme  = "http"
	httpsScheme = "https"

	// bufferSize is how many leading bytes are read for content sniffing.
	bufferSize = 512
)
61
62
// Timer supplies the timestamps recorded on history entries. It is an
// interface so the clock can be substituted (e.g. by a fixed clock in tests).
type Timer interface {
	// Now returns the current time.
	Now() time.Time
}
65
66
// RealTime is the production Timer backed by the system clock.
type RealTime struct{}
68
69
// Now reports the current local time by delegating to time.Now.
func (rt *RealTime) Now() time.Time {
	current := time.Now()
	return current
}
72
73
// FileReader abstracts the file-system reads the client performs, so they can
// be replaced in tests.
type FileReader interface {
	// Open opens the named file for reading.
	Open(name string) (*os.File, error)
	// ReadFile returns the full contents of the named file.
	ReadFile(name string) ([]byte, error)
	// ReadBufferFromFile returns a leading chunk of file; the client uses it
	// for MIME-type and audio-format sniffing.
	ReadBufferFromFile(file *os.File) ([]byte, error)
}
78
79
// RealFileReader implements FileReader directly on top of the os package.
type RealFileReader struct{}
80
81
// Open opens name read-only via os.Open; the caller must close the file.
func (r *RealFileReader) Open(name string) (*os.File, error) {
	f, err := os.Open(name)
	return f, err
}
84
85
// ReadFile loads the whole file named by name via os.ReadFile.
func (r *RealFileReader) ReadFile(name string) ([]byte, error) {
	contents, err := os.ReadFile(name)
	return contents, err
}
88
89
// ReadBufferFromFile performs a single Read of up to bufferSize bytes from
// file. The returned slice is always bufferSize bytes long: anything the Read
// call did not fill (for example in a file shorter than bufferSize) is left
// as zero bytes. Any error from Read is returned alongside the buffer;
// for an empty file os.File.Read reports io.EOF.
func (r *RealFileReader) ReadBufferFromFile(file *os.File) ([]byte, error) {
	head := make([]byte, bufferSize)
	if _, err := file.Read(head); err != nil {
		return head, err
	}
	return head, nil
}
95
96
// FileWriter abstracts the file-system writes the client performs, so they
// can be replaced in tests.
type FileWriter interface {
	// Create creates (or truncates) the named file for writing.
	Create(name string) (*os.File, error)
	// Write writes buf to file.
	Write(file *os.File, buf []byte) error
}
100
101
// RealFileWriter implements FileWriter directly on top of the os package.
type RealFileWriter struct{}
102
103
// Create creates or truncates name via os.Create; the caller must close it.
func (w *RealFileWriter) Create(name string) (*os.File, error) {
	f, err := os.Create(name)
	return f, err
}
106
107
// Write writes all of buf to file. os.File.Write reports a non-nil error
// whenever fewer than len(buf) bytes are written, so only the error is
// returned.
//
// The receiver is named w, matching RealFileWriter.Create.
func (w *RealFileWriter) Write(file *os.File, buf []byte) error {
	_, err := file.Write(buf)
	return err
}
111
112
// Client sends chat, media, and MCP requests to the configured API and keeps
// the conversation history for the selected thread.
type Client struct {
	// Config holds the runtime configuration (model, endpoint paths, keys, limits).
	Config config.Config
	// History is the in-memory conversation; initHistory overwrites entry 0's
	// content with Config.Role.
	History []history.History
	// caller performs the HTTP requests.
	caller http.Caller
	// historyStore loads and persists History.
	historyStore history.Store
	// timer supplies history timestamps.
	timer Timer
	// reader reads input files (images, audio).
	reader FileReader
	// writer creates and writes output files (speech, images).
	writer FileWriter
}
121
122
func New(callerFactory http.CallerFactory, hs history.Store, t Timer, r FileReader, w FileWriter, cfg config.Config, interactiveMode bool) *Client {
123
caller := callerFactory(cfg)
124
125
if interactiveMode && cfg.AutoCreateNewThread {
126
hs.SetThread(internal.GenerateUniqueSlug(InteractiveThreadPrefix))
127
} else {
128
hs.SetThread(cfg.Thread)
129
}
130
131
return &Client{
132
Config: cfg,
133
caller: caller,
134
historyStore: hs,
135
timer: t,
136
reader: r,
137
writer: w,
138
}
139
}
140
141
// WithContextWindow overrides Config.ContextWindow. That is the budget
// truncateHistory compares estimated token counts against. It returns the
// same client for chaining.
func (c *Client) WithContextWindow(window int) *Client {
	cfg := &c.Config
	cfg.ContextWindow = window
	return c
}
145
146
// WithServiceURL overrides the base URL that endpoint paths are appended to.
// It returns the same client for chaining.
//
// The parameter is named serviceURL rather than url so it does not shadow the
// net/url package imported by this file.
func (c *Client) WithServiceURL(serviceURL string) *Client {
	c.Config.URL = serviceURL
	return c
}
150
151
// InjectMCPContext calls an MCP plugin (e.g. Apify) with the given parameters,
152
// retrieves the result, and adds it to the chat history as a function message.
153
// The result is formatted as a string and tagged with the function name.
154
func (c *Client) InjectMCPContext(mcp api.MCPRequest) error {
155
if c.Config.OmitHistory {
156
return errors.New(ErrHistoryTracking)
157
}
158
159
endpoint, headers, body, err := c.buildMCPRequest(mcp)
160
if err != nil {
161
return err
162
}
163
164
c.printRequestDebugInfo(endpoint, body, headers)
165
166
raw, err := c.caller.PostWithHeaders(endpoint, body, headers)
167
if err != nil {
168
return err
169
}
170
171
c.printResponseDebugInfo(raw)
172
173
formatted := formatMCPResponse(raw, mcp.Function)
174
175
c.initHistory()
176
c.History = append(c.History, history.History{
177
Message: api.Message{
178
Role: FunctionRole,
179
Name: strings.ReplaceAll(mcp.Function, "~", "-"),
180
Content: formatted,
181
},
182
Timestamp: c.timer.Now(),
183
})
184
c.truncateHistory()
185
186
return c.historyStore.Write(c.History)
187
}
188
189
// ListModels retrieves a list of all available models from the OpenAI API.
190
// The models are returned as a slice of strings, each entry representing a model ID.
191
// Models that have an ID starting with 'gpt' are included.
192
// The currently active model is marked with an asterisk (*) in the list.
193
// In case of an error during the retrieval or processing of the models,
194
// the method returns an error. If the API response is empty, an error is returned as well.
195
func (c *Client) ListModels() ([]string, error) {
196
var result []string
197
198
endpoint := c.getEndpoint(c.Config.ModelsPath)
199
200
c.printRequestDebugInfo(endpoint, nil, nil)
201
202
raw, err := c.caller.Get(c.getEndpoint(c.Config.ModelsPath))
203
c.printResponseDebugInfo(raw)
204
205
if err != nil {
206
return nil, err
207
}
208
209
var response api.ListModelsResponse
210
if err := c.processResponse(raw, &response); err != nil {
211
return nil, err
212
}
213
214
sort.Slice(response.Data, func(i, j int) bool {
215
return response.Data[i].Id < response.Data[j].Id
216
})
217
218
for _, model := range response.Data {
219
if strings.HasPrefix(model.Id, gptPrefix) || strings.HasPrefix(model.Id, o1Prefix) {
220
if model.Id != c.Config.Model {
221
result = append(result, fmt.Sprintf("- %s", model.Id))
222
continue
223
}
224
result = append(result, fmt.Sprintf("* %s (current)", model.Id))
225
}
226
}
227
228
return result, nil
229
}
230
231
// ProvideContext adds custom context to the client's history by converting the
232
// provided string into a series of messages. This allows the ChatGPT API to have
233
// prior knowledge of the provided context when generating responses.
234
//
235
// The context string should contain the text you want to provide as context,
236
// and the method will split it into messages, preserving punctuation and special
237
// characters.
238
func (c *Client) ProvideContext(context string) {
239
c.initHistory()
240
historyEntries := c.createHistoryEntriesFromString(context)
241
c.History = append(c.History, historyEntries...)
242
}
243
244
// Query sends a query to the API, returning the response as a string along with the token usage.
245
//
246
// It takes a context `ctx` and an input string, constructs a request body, and makes a POST API call.
247
// The context allows for request scoping, timeouts, and cancellation handling.
248
//
249
// Returns the API response string, the number of tokens used, and an error if any issues occur.
250
// If the response contains choices, it decodes the JSON and returns the content of the first choice.
251
//
252
// Parameters:
253
// - ctx: A context.Context that controls request cancellation and deadlines.
254
// - input: The query string to send to the API.
255
//
256
// Returns:
257
// - string: The content of the first response choice from the API.
258
// - int: The total number of tokens used in the request.
259
// - error: An error if the request fails or the response is invalid.
260
func (c *Client) Query(ctx context.Context, input string) (string, int, error) {
261
c.prepareQuery(input)
262
263
body, err := c.createBody(ctx, false)
264
if err != nil {
265
return "", 0, err
266
}
267
268
endpoint := c.getChatEndpoint()
269
270
c.printRequestDebugInfo(endpoint, body, nil)
271
272
raw, err := c.caller.Post(endpoint, body, false)
273
c.printResponseDebugInfo(raw)
274
275
if err != nil {
276
return "", 0, err
277
}
278
279
var (
280
response string
281
tokensUsed int
282
)
283
284
caps := GetCapabilities(c.Config.Model)
285
286
if caps.UsesResponsesAPI {
287
var res api.ResponsesResponse
288
if err := c.processResponse(raw, &res); err != nil {
289
return "", 0, err
290
}
291
tokensUsed = res.Usage.TotalTokens
292
293
for _, output := range res.Output {
294
if output.Type != messageType {
295
continue
296
}
297
for _, content := range output.Content {
298
if content.Type == outputTextType {
299
response = content.Text
300
break
301
}
302
}
303
}
304
305
if response == "" {
306
return "", tokensUsed, errors.New("no response returned")
307
}
308
} else {
309
var res api.CompletionsResponse
310
if err := c.processResponse(raw, &res); err != nil {
311
return "", 0, err
312
}
313
tokensUsed = res.Usage.TotalTokens
314
315
if len(res.Choices) == 0 {
316
return "", tokensUsed, errors.New("no responses returned")
317
}
318
319
var ok bool
320
response, ok = res.Choices[0].Message.Content.(string)
321
if !ok {
322
return "", tokensUsed, errors.New("response cannot be converted to a string")
323
}
324
}
325
326
c.updateHistory(response)
327
328
return response, tokensUsed, nil
329
}
330
331
// Stream sends a query to the API and processes the response as a stream.
332
//
333
// It takes a context `ctx` and an input string, constructs a request body, and makes a POST API call.
334
// The context allows for request scoping, timeouts, and cancellation handling.
335
//
336
// The method creates a request body with the input and calls the API using the `Post` method.
337
// The actual processing of the streamed response is handled inside the `Post` method.
338
//
339
// Parameters:
340
// - ctx: A context.Context that controls request cancellation and deadlines.
341
// - input: The query string to send to the API.
342
//
343
// Returns:
344
// - error: An error if the request fails or the response is invalid.
345
func (c *Client) Stream(ctx context.Context, input string) error {
346
c.prepareQuery(input)
347
348
body, err := c.createBody(ctx, true)
349
if err != nil {
350
return err
351
}
352
353
endpoint := c.getChatEndpoint()
354
355
c.printRequestDebugInfo(endpoint, body, nil)
356
357
result, err := c.caller.Post(endpoint, body, true)
358
if err != nil {
359
return err
360
}
361
362
c.updateHistory(string(result))
363
364
return nil
365
}
366
367
// SynthesizeSpeech converts the given input text into speech using the configured TTS model,
368
// and writes the resulting audio to the specified output file.
369
//
370
// The audio format is inferred from the output file's extension (e.g., "mp3", "wav") and sent
371
// as the "response_format" in the request to the OpenAI speech synthesis endpoint.
372
//
373
// Parameters:
374
// - inputText: The text to synthesize into speech.
375
// - outputPath: The path to the output audio file. The file extension determines the response format.
376
//
377
// Returns an error if the request fails, the response cannot be written, or the file cannot be created.
378
func (c *Client) SynthesizeSpeech(inputText, outputPath string) error {
379
req := api.Speech{
380
Model: c.Config.Model,
381
Voice: c.Config.Voice,
382
Input: inputText,
383
ResponseFormat: getExtension(outputPath),
384
}
385
return c.postAndWriteBinaryOutput(c.getEndpoint(c.Config.SpeechPath), req, outputPath, "binary", nil)
386
}
387
388
// GenerateImage sends a prompt to the configured image generation model (e.g., gpt-image-1)
389
// and writes the resulting image to the specified output path.
390
//
391
// The method performs the following steps:
392
// 1. Sends a POST request to the image generation endpoint with the provided prompt.
393
// 2. Parses the response and extracts the base64-encoded image data.
394
// 3. Decodes the image bytes and writes them to the given outputPath.
395
// 4. Logs the number of bytes written using debug output.
396
//
397
// Parameters:
398
// - inputText: The prompt describing the image to be generated.
399
// - outputPath: The file path where the generated image (e.g., .png) will be saved.
400
//
401
// Returns:
402
// - An error if any part of the request, decoding, or file writing fails.
403
func (c *Client) GenerateImage(inputText, outputPath string) error {
404
req := api.Draw{
405
Model: c.Config.Model,
406
Prompt: inputText,
407
}
408
409
return c.postAndWriteBinaryOutput(
410
c.getEndpoint(c.Config.ImageGenerationsPath),
411
req,
412
outputPath,
413
"image",
414
func(respBytes []byte) ([]byte, error) {
415
var response struct {
416
Data []struct {
417
B64 string `json:"b64_json"`
418
} `json:"data"`
419
}
420
if err := json.Unmarshal(respBytes, &response); err != nil {
421
return nil, fmt.Errorf("failed to decode response: %w", err)
422
}
423
if len(response.Data) == 0 {
424
return nil, fmt.Errorf("no image data returned")
425
}
426
decoded, err := base64.StdEncoding.DecodeString(response.Data[0].B64)
427
if err != nil {
428
return nil, fmt.Errorf("failed to decode base64 image: %w", err)
429
}
430
return decoded, nil
431
},
432
)
433
}
434
435
// EditImage edits an input image using a text prompt and writes the modified image to the specified output path.
436
//
437
// This method sends a multipart/form-data POST request to the image editing endpoint
438
// (typically OpenAI's /v1/images/edits). The request includes:
439
// - The image file to edit.
440
// - A text prompt describing how the image should be modified.
441
// - The model ID (e.g., gpt-image-1).
442
//
443
// The response is expected to contain a base64-encoded image, which is decoded and written to the outputPath.
444
//
445
// Parameters:
446
// - inputText: A text prompt describing the desired modifications to the image.
447
// - inputPath: The file path to the source image (must be a supported format: PNG, JPEG, or WebP).
448
// - outputPath: The file path where the edited image will be saved.
449
//
450
// Returns:
451
// - An error if any step of the process fails: reading the file, building the request, sending it,
452
// decoding the response, or writing the output image.
453
//
454
// Example:
455
//
456
// err := client.EditImage("Add a rainbow in the sky", "input.png", "output.png")
457
// if err != nil {
458
// log.Fatal(err)
459
// }
460
func (c *Client) EditImage(inputText, inputPath, outputPath string) error {
461
endpoint := c.getEndpoint(c.Config.ImageEditsPath)
462
463
file, err := c.reader.Open(inputPath)
464
if err != nil {
465
return fmt.Errorf("failed to open input image: %w", err)
466
}
467
defer file.Close()
468
469
var buf bytes.Buffer
470
writer := multipart.NewWriter(&buf)
471
472
mimeType, err := c.getMimeTypeFromFileContent(inputPath)
473
if err != nil {
474
return fmt.Errorf("failed to detect MIME type: %w", err)
475
}
476
if !strings.HasPrefix(mimeType, "image/") {
477
return fmt.Errorf("unsupported MIME type: %s", mimeType)
478
}
479
480
header := make(textproto.MIMEHeader)
481
header.Set("Content-Disposition", fmt.Sprintf(`form-data; name="image"; filename="%s"`, filepath.Base(inputPath)))
482
header.Set("Content-Type", mimeType)
483
484
part, err := writer.CreatePart(header)
485
if err != nil {
486
return fmt.Errorf("failed to create image part: %w", err)
487
}
488
if _, err := io.Copy(part, file); err != nil {
489
return fmt.Errorf("failed to copy image data: %w", err)
490
}
491
492
if err := writer.WriteField("prompt", inputText); err != nil {
493
return fmt.Errorf("failed to add prompt: %w", err)
494
}
495
if err := writer.WriteField("model", c.Config.Model); err != nil {
496
return fmt.Errorf("failed to add model: %w", err)
497
}
498
499
if err := writer.Close(); err != nil {
500
return fmt.Errorf("failed to close multipart writer: %w", err)
501
}
502
503
c.printRequestDebugInfo(endpoint, buf.Bytes(), map[string]string{
504
"Content-Type": writer.FormDataContentType(),
505
})
506
507
respBytes, err := c.caller.PostWithHeaders(endpoint, buf.Bytes(), map[string]string{
508
c.Config.AuthHeader: fmt.Sprintf("%s %s", c.Config.AuthTokenPrefix, c.Config.APIKey),
509
"Content-Type": writer.FormDataContentType(),
510
})
511
if err != nil {
512
return fmt.Errorf("failed to edit image: %w", err)
513
}
514
515
// Parse the JSON and extract b64_json
516
var response struct {
517
Data []struct {
518
B64 string `json:"b64_json"`
519
} `json:"data"`
520
}
521
if err := json.Unmarshal(respBytes, &response); err != nil {
522
return fmt.Errorf("failed to decode response: %w", err)
523
}
524
if len(response.Data) == 0 {
525
return fmt.Errorf("no image data returned")
526
}
527
528
imgBytes, err := base64.StdEncoding.DecodeString(response.Data[0].B64)
529
if err != nil {
530
return fmt.Errorf("failed to decode base64 image: %w", err)
531
}
532
533
outFile, err := c.writer.Create(outputPath)
534
if err != nil {
535
return fmt.Errorf("failed to create output file: %w", err)
536
}
537
defer outFile.Close()
538
539
if err := c.writer.Write(outFile, imgBytes); err != nil {
540
return fmt.Errorf("failed to write image: %w", err)
541
}
542
543
c.printResponseDebugInfo([]byte(fmt.Sprintf("[image] %d bytes written to %s", len(imgBytes), outputPath)))
544
return nil
545
}
546
547
// Transcribe uploads an audio file to the OpenAI transcription endpoint and returns the transcribed text.
548
//
549
// It reads the audio file from the provided `audioPath`, creates a multipart/form-data request with the model name
550
// and the audio file, and sends it to the endpoint defined by the `TranscriptionsPath` in the client config.
551
// The method expects a JSON response containing a "text" field with the transcription result.
552
//
553
// Parameters:
554
// - audioPath: The local file path to the audio file to be transcribed.
555
//
556
// Returns:
557
// - string: The transcribed text from the audio file.
558
// - error: An error if the file can't be read, the request fails, or the response is invalid.
559
//
560
// This method supports formats like mp3, mp4, mpeg, mpga, m4a, wav, and webm, depending on API compatibility.
561
func (c *Client) Transcribe(audioPath string) (string, error) {
562
c.initHistory()
563
564
file, err := c.reader.Open(audioPath)
565
if err != nil {
566
return "", fmt.Errorf("failed to open audio file: %w", err)
567
}
568
defer file.Close()
569
570
var buf bytes.Buffer
571
writer := multipart.NewWriter(&buf)
572
573
_ = writer.WriteField("model", c.Config.Model)
574
575
part, err := writer.CreateFormFile("file", filepath.Base(audioPath))
576
if err != nil {
577
return "", err
578
}
579
if _, err := io.Copy(part, file); err != nil {
580
return "", err
581
}
582
583
if err := writer.Close(); err != nil {
584
return "", err
585
}
586
587
endpoint := c.getEndpoint(c.Config.TranscriptionsPath)
588
headers := map[string]string{
589
"Content-Type": writer.FormDataContentType(),
590
c.Config.AuthHeader: fmt.Sprintf("%s %s", c.Config.AuthTokenPrefix, c.Config.APIKey),
591
}
592
593
c.printRequestDebugInfo(endpoint, buf.Bytes(), headers)
594
595
raw, err := c.caller.PostWithHeaders(endpoint, buf.Bytes(), headers)
596
if err != nil {
597
return "", err
598
}
599
600
c.printResponseDebugInfo(raw)
601
602
var res struct {
603
Text string `json:"text"`
604
}
605
if err := json.Unmarshal(raw, &res); err != nil {
606
return "", fmt.Errorf("failed to parse transcription: %w", err)
607
}
608
609
c.History = append(c.History, history.History{
610
Message: api.Message{
611
Role: UserRole,
612
Content: fmt.Sprintf("[transcribe] %s", filepath.Base(audioPath)),
613
},
614
Timestamp: c.timer.Now(),
615
})
616
617
c.History = append(c.History, history.History{
618
Message: api.Message{
619
Role: AssistantRole,
620
Content: res.Text,
621
},
622
Timestamp: c.timer.Now(),
623
})
624
625
c.truncateHistory()
626
627
if !c.Config.OmitHistory {
628
_ = c.historyStore.Write(c.History)
629
}
630
631
return res.Text, nil
632
}
633
634
func (c *Client) appendMediaMessages(ctx context.Context, messages []api.Message) ([]api.Message, error) {
635
if data, ok := ctx.Value(internal.BinaryDataKey).([]byte); ok {
636
content, err := c.createImageContentFromBinary(data)
637
if err != nil {
638
return nil, err
639
}
640
messages = append(messages, api.Message{
641
Role: UserRole,
642
Content: []api.ImageContent{content},
643
})
644
} else if path, ok := ctx.Value(internal.ImagePathKey).(string); ok {
645
content, err := c.createImageContentFromURLOrFile(path)
646
if err != nil {
647
return nil, err
648
}
649
messages = append(messages, api.Message{
650
Role: UserRole,
651
Content: []api.ImageContent{content},
652
})
653
} else if path, ok := ctx.Value(internal.AudioPathKey).(string); ok {
654
content, err := c.createAudioContentFromFile(path)
655
if err != nil {
656
return nil, err
657
}
658
messages = append(messages, api.Message{
659
Role: UserRole,
660
Content: []api.AudioContent{content},
661
})
662
}
663
return messages, nil
664
}
665
666
// createBody builds the request for the model's chat endpoint and marshals it
// to JSON:
//   - a Responses API request when UsesResponsesAPI is set;
//   - otherwise a Chat Completions request.
//
// stream toggles streaming in either request.
func (c *Client) createBody(ctx context.Context, stream bool) ([]byte, error) {
	var (
		payload interface{}
		err     error
	)
	if GetCapabilities(c.Config.Model).UsesResponsesAPI {
		payload, err = c.createResponsesRequest(ctx, stream)
	} else {
		payload, err = c.createCompletionsRequest(ctx, stream)
	}
	if err != nil {
		return nil, err
	}
	return json.Marshal(payload)
}
683
684
func (c *Client) createCompletionsRequest(ctx context.Context, stream bool) (*api.CompletionsRequest, error) {
685
var messages []api.Message
686
caps := GetCapabilities(c.Config.Model)
687
688
for index, item := range c.History {
689
if caps.OmitFirstSystemMsg && index == 0 {
690
continue
691
}
692
messages = append(messages, item.Message)
693
}
694
695
messages, err := c.appendMediaMessages(ctx, messages)
696
if err != nil {
697
return nil, err
698
}
699
700
req := &api.CompletionsRequest{
701
Messages: messages,
702
Model: c.Config.Model,
703
MaxTokens: c.Config.MaxTokens,
704
FrequencyPenalty: c.Config.FrequencyPenalty,
705
PresencePenalty: c.Config.PresencePenalty,
706
Seed: c.Config.Seed,
707
Stream: stream,
708
}
709
710
if caps.SupportsTemperature {
711
req.Temperature = c.Config.Temperature
712
req.TopP = c.Config.TopP
713
}
714
715
return req, nil
716
}
717
718
func (c *Client) createResponsesRequest(ctx context.Context, stream bool) (*api.ResponsesRequest, error) {
719
var messages []api.Message
720
caps := GetCapabilities(c.Config.Model)
721
722
for index, item := range c.History {
723
if caps.OmitFirstSystemMsg && index == 0 {
724
continue
725
}
726
messages = append(messages, item.Message)
727
}
728
729
messages, err := c.appendMediaMessages(ctx, messages)
730
if err != nil {
731
return nil, err
732
}
733
734
req := &api.ResponsesRequest{
735
Model: c.Config.Model,
736
Input: messages,
737
MaxOutputTokens: c.Config.MaxTokens,
738
Reasoning: api.Reasoning{
739
Effort: c.Config.Effort,
740
},
741
Stream: stream,
742
Temperature: c.Config.Temperature,
743
TopP: c.Config.TopP,
744
}
745
746
return req, nil
747
}
748
749
func (c *Client) createImageContentFromBinary(binary []byte) (api.ImageContent, error) {
750
mime, err := getMimeTypeFromBytes(binary)
751
if err != nil {
752
return api.ImageContent{}, err
753
}
754
755
encoded := base64.StdEncoding.EncodeToString(binary)
756
content := api.ImageContent{
757
Type: imageURLType,
758
ImageURL: struct {
759
URL string `json:"url"`
760
}{
761
URL: fmt.Sprintf(imageContent, mime, encoded),
762
},
763
}
764
765
return content, nil
766
}
767
768
func (c *Client) createAudioContentFromFile(audio string) (api.AudioContent, error) {
769
770
format, err := c.detectAudioFormat(audio)
771
if err != nil {
772
return api.AudioContent{}, err
773
}
774
775
encodedAudio, err := c.base64Encode(audio)
776
if err != nil {
777
return api.AudioContent{}, err
778
}
779
780
return api.AudioContent{
781
Type: audioType,
782
InputAudio: api.InputAudio{
783
Data: encodedAudio,
784
Format: format,
785
},
786
}, nil
787
}
788
789
func (c *Client) createImageContentFromURLOrFile(image string) (api.ImageContent, error) {
790
var content api.ImageContent
791
792
if isValidURL(image) {
793
content = api.ImageContent{
794
Type: imageURLType,
795
ImageURL: struct {
796
URL string `json:"url"`
797
}{
798
URL: image,
799
},
800
}
801
} else {
802
mime, err := c.getMimeTypeFromFileContent(image)
803
if err != nil {
804
return content, err
805
}
806
807
encodedImage, err := c.base64Encode(image)
808
if err != nil {
809
return content, err
810
}
811
812
content = api.ImageContent{
813
Type: imageURLType,
814
ImageURL: struct {
815
URL string `json:"url"`
816
}{
817
URL: fmt.Sprintf(imageContent, mime, encodedImage),
818
},
819
}
820
}
821
822
return content, nil
823
}
824
825
func (c *Client) initHistory() {
826
if len(c.History) != 0 {
827
return
828
}
829
830
if !c.Config.OmitHistory {
831
c.History, _ = c.historyStore.Read()
832
}
833
834
if len(c.History) == 0 {
835
c.History = []history.History{{
836
Message: api.Message{
837
Role: SystemRole,
838
},
839
Timestamp: c.timer.Now(),
840
}}
841
}
842
843
c.History[0].Content = c.Config.Role
844
}
845
846
func (c *Client) addQuery(query string) {
847
message := api.Message{
848
Role: UserRole,
849
Content: query,
850
}
851
852
c.History = append(c.History, history.History{
853
Message: message,
854
Timestamp: c.timer.Now(),
855
})
856
c.truncateHistory()
857
}
858
859
// getChatEndpoint returns the URL for chat requests:
//   - the Responses path when the model uses the Responses API;
//   - the Chat Completions path otherwise.
func (c *Client) getChatEndpoint() string {
	path := c.Config.CompletionsPath
	if GetCapabilities(c.Config.Model).UsesResponsesAPI {
		path = c.Config.ResponsesPath
	}
	return c.getEndpoint(path)
}
870
871
// getEndpoint joins the base URL and path by plain concatenation. No
// separator is inserted or removed, so path is expected to carry its own
// leading slash.
func (c *Client) getEndpoint(path string) string {
	base := c.Config.URL
	return base + path
}
874
875
// prepareQuery ensures the history is initialized, then records input as the
// newest user message (addQuery also truncates the history to the context
// window).
func (c *Client) prepareQuery(input string) {
	c.initHistory()
	c.addQuery(input)
}
879
880
// processResponse decodes the JSON response body raw into v.
//
// A missing or zero-length body yields ErrEmptyResponse. Malformed JSON is
// reported as "failed to decode response: ..." wrapping the json error.
func (c *Client) processResponse(raw []byte, v interface{}) error {
	// Checking len (not just nil) also catches an empty, non-nil body. Such a
	// body would otherwise reach json.Unmarshal and fail with the less helpful
	// "unexpected end of JSON input".
	if len(raw) == 0 {
		return errors.New(ErrEmptyResponse)
	}

	if err := json.Unmarshal(raw, v); err != nil {
		return fmt.Errorf("failed to decode response: %w", err)
	}

	return nil
}
891
892
// truncateHistory drops the oldest non-system messages until the estimated
// token count (see countTokens) fits the effective context window.
//
// The effective window is Config.ContextWindow minus MaxTokenBufferPercentage
// percent. Entry 0, the system message, is always kept. If dropping every
// other entry would still not be enough, the history is left unchanged.
func (c *Client) truncateHistory() {
	tokens, rolling := countTokens(c.History)
	budget := calculateEffectiveContextWindow(c.Config.ContextWindow, MaxTokenBufferPercentage)

	if tokens <= budget {
		return
	}

	// Dropping entries 1..i removes `removed` tokens and leaves
	// tokens-removed. That fits the budget exactly when removed >= excess,
	// so the comparison must be >= (with > one extra message was dropped
	// whenever removed == excess).
	excess := tokens - budget
	removed := 0
	cut := 0
	for i := 1; i < len(rolling); i++ {
		removed += rolling[i]
		if removed >= excess {
			cut = i
			break
		}
	}

	if cut == 0 {
		return
	}
	c.History = append(c.History[:1], c.History[cut+1:]...)
}
914
915
func (c *Client) updateHistory(response string) {
916
c.History = append(c.History, history.History{
917
Message: api.Message{
918
Role: AssistantRole,
919
Content: response,
920
},
921
Timestamp: c.timer.Now(),
922
})
923
924
if !c.Config.OmitHistory {
925
_ = c.historyStore.Write(c.History)
926
}
927
}
928
929
func (c *Client) base64Encode(path string) (string, error) {
930
imageData, err := c.reader.ReadFile(path)
931
if err != nil {
932
return "", err
933
}
934
935
return base64.StdEncoding.EncodeToString(imageData), nil
936
}
937
938
func (c *Client) createHistoryEntriesFromString(input string) []history.History {
939
var result []history.History
940
941
words := strings.Fields(input)
942
943
for i := 0; i < len(words); i += 100 {
944
end := i + 100
945
if end > len(words) {
946
end = len(words)
947
}
948
949
content := strings.Join(words[i:end], " ")
950
951
item := history.History{
952
Message: api.Message{
953
Role: UserRole,
954
Content: content,
955
},
956
Timestamp: c.timer.Now(),
957
}
958
result = append(result, item)
959
}
960
961
return result
962
}
963
964
// detectAudioFormat sniffs the leading bytes of the file at path.
//
// It returns one of "wav", "mp3", "flac", "ogg", "m4a", or "mp4". Content that
// matches none of these yields "unknown" with a nil error. Open and read
// errors are returned.
//
// Every signature check is bounds-aware, so a header shorter than a check
// needs simply fails that check instead of panicking. RealFileReader always
// returns bufferSize bytes, but the FileReader interface does not promise a
// minimum length.
func (c *Client) detectAudioFormat(path string) (string, error) {
	file, err := c.reader.Open(path)
	if err != nil {
		return "", err
	}
	defer file.Close()

	buf, err := c.reader.ReadBufferFromFile(file)
	if err != nil {
		return "", err
	}

	// at reports whether buf contains sig starting at byte offset off.
	at := func(off int, sig string) bool {
		return len(buf) >= off+len(sig) && string(buf[off:off+len(sig)]) == sig
	}

	switch {
	case at(0, "RIFF") && at(8, "WAVE"):
		return "wav", nil
	// MP3: an ID3v2 tag, or an MPEG frame sync (11 leading set bits).
	case at(0, "ID3") || (len(buf) >= 2 && buf[0] == 0xFF && (buf[1]&0xE0) == 0xE0):
		return "mp3", nil
	case at(0, "fLaC"):
		return "flac", nil
	case at(0, "OggS"):
		return "ogg", nil
	case at(4, "ftyp"):
		// ISO base media container: these major brands are reported as m4a,
		// anything else as mp4.
		if at(8, "M4A ") || at(8, "isom") || at(8, "mp42") {
			return "m4a", nil
		}
		return "mp4", nil
	}

	return "unknown", nil
}
1006
1007
// getMimeTypeFromFileContent opens path and reads its leading chunk through
// the FileReader. It classifies the chunk with net/http's DetectContentType,
// which always returns a type and falls back to "application/octet-stream".
// Open and read errors are returned.
func (c *Client) getMimeTypeFromFileContent(path string) (string, error) {
	f, err := c.reader.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close()

	head, err := c.reader.ReadBufferFromFile(f)
	if err != nil {
		return "", err
	}
	return stdhttp.DetectContentType(head), nil
}
1023
1024
// printRequestDebugInfo logs, at debug level, an approximate cURL command for
// the request about to be sent. The method is POST when body is non-nil and
// GET otherwise.
//
// When headers are supplied:
//   - they are printed in sorted key order (Go map iteration order is random);
//   - the value of any credential header (Authorization or the configured
//     Config.AuthHeader, matched case-insensitively) is replaced with
//     "<redacted>", so live API keys never reach the logs.
//
// Without headers, an Authorization placeholder referencing the
// ${<NAME>_API_KEY} environment variable and a JSON Content-Type are printed.
//
// Single quotes in the body are escaped for a POSIX shell.
func (c *Client) printRequestDebugInfo(endpoint string, body []byte, headers map[string]string) {
	sugar := zap.S()
	sugar.Debugf("\nGenerated cURL command:\n")

	method := "POST"
	if body == nil {
		method = "GET"
	}
	sugar.Debugf("curl --location --insecure --request %s '%s' \\", method, endpoint)

	if len(headers) > 0 {
		keys := make([]string, 0, len(headers))
		for k := range headers {
			keys = append(keys, k)
		}
		sort.Strings(keys)

		for _, k := range keys {
			v := headers[k]
			isCredential := strings.EqualFold(k, "Authorization") ||
				(c.Config.AuthHeader != "" && strings.EqualFold(k, c.Config.AuthHeader))
			if isCredential {
				v = "<redacted>"
			}
			sugar.Debugf(" --header '%s: %s' \\", k, v)
		}
	} else {
		sugar.Debugf(" --header \"Authorization: Bearer ${%s_API_KEY}\" \\", strings.ToUpper(c.Config.Name))
		sugar.Debugf(" --header 'Content-Type: application/json' \\")
	}

	if body != nil {
		// Close the single-quoted string, emit a double-quoted ', then reopen.
		bodyString := strings.ReplaceAll(string(body), "'", "'\"'\"'")
		sugar.Debugf(" --data-raw '%s'", bodyString)
	}
}
1048
1049
// printResponseDebugInfo logs raw, at debug level, under a "Response" heading.
func (c *Client) printResponseDebugInfo(raw []byte) {
	logger := zap.S()
	logger.Debugf("\nResponse\n")
	logger.Debugf("%s\n", raw)
}
1054
1055
// postAndWriteBinaryOutput sends requestBody and saves the result to a file.
//
// Steps:
//  1. Marshal requestBody to JSON.
//  2. POST it (non-streaming) to endpoint.
//  3. Pass the raw response through transform, when transform is non-nil.
//  4. Write the bytes to a newly created outputPath.
//
// debugLabel names the payload in the write error and in the debug summary
// line.
func (c *Client) postAndWriteBinaryOutput(endpoint string, requestBody interface{}, outputPath, debugLabel string, transform func([]byte) ([]byte, error)) error {
	payload, err := json.Marshal(requestBody)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}
	c.printRequestDebugInfo(endpoint, payload, nil)

	data, err := c.caller.Post(endpoint, payload, false)
	if err != nil {
		return fmt.Errorf("API request failed: %w", err)
	}

	if transform != nil {
		if data, err = transform(data); err != nil {
			return err
		}
	}

	out, err := c.writer.Create(outputPath)
	if err != nil {
		return fmt.Errorf("failed to create output file: %w", err)
	}
	defer out.Close()

	if err = c.writer.Write(out, data); err != nil {
		return fmt.Errorf("failed to write %s: %w", debugLabel, err)
	}

	summary := fmt.Sprintf("[%s] %d bytes written to %s", debugLabel, len(data), outputPath)
	c.printResponseDebugInfo([]byte(summary))
	return nil
}
1088
1089
func (c *Client) buildMCPRequest(mcp api.MCPRequest) (string, map[string]string, []byte, error) {
1090
mcp.Provider = strings.ToLower(mcp.Provider)
1091
params := mcp.Params
1092
1093
if mcp.Provider != utils.ApifyProvider {
1094
return "", nil, nil, errors.New(ErrUnsupportedProvider)
1095
}
1096
1097
apiKey := c.Config.ApifyAPIKey
1098
if apiKey == "" {
1099
return "", nil, nil, fmt.Errorf(ErrMissingMCPAPIKey, mcp.Provider)
1100
}
1101
1102
params[ApifyProxyConfig] = api.ProxyConfiguration{UseApifyProxy: true}
1103
endpoint := ApifyURL + mcp.Function + ApifyPath
1104
1105
headers := map[string]string{
1106
"Content-Type": "application/json",
1107
"Authorization": fmt.Sprintf("Bearer %s", apiKey),
1108
}
1109
1110
body, err := json.Marshal(params)
1111
if err != nil {
1112
return "", nil, nil, fmt.Errorf("failed to marshal request: %w", err)
1113
}
1114
1115
return endpoint, headers, body, nil
1116
}
1117
1118
// ModelCapabilities describes how requests must differ per model. See
// GetCapabilities for how each flag is derived from the model name.
type ModelCapabilities struct {
	// SupportsTemperature gates Temperature/TopP in Chat Completions requests.
	SupportsTemperature bool
	// SupportsStreaming reports whether streaming is available. It is not
	// consulted by the request builders in this file.
	SupportsStreaming bool
	// UsesResponsesAPI routes requests to the Responses endpoint instead of
	// Chat Completions.
	UsesResponsesAPI bool
	// OmitFirstSystemMsg leaves history entry 0 (the system message) out of
	// requests.
	OmitFirstSystemMsg bool
}
1124
1125
func GetCapabilities(model string) ModelCapabilities {
1126
return ModelCapabilities{
1127
SupportsTemperature: !strings.Contains(model, SearchModelPattern),
1128
SupportsStreaming: !strings.Contains(model, o1ProPattern),
1129
UsesResponsesAPI: strings.Contains(model, o1ProPattern) || strings.Contains(model, gpt5Pattern),
1130
OmitFirstSystemMsg: strings.HasPrefix(model, o1Prefix) && !strings.Contains(model, o1ProPattern),
1131
}
1132
}
1133
1134
func formatMCPResponse(raw []byte, function string) string {
1135
var result interface{}
1136
if err := json.Unmarshal(raw, &result); err != nil {
1137
return fmt.Sprintf("[MCP: %s] (failed to decode response)", function)
1138
}
1139
1140
var lines []string
1141
1142
switch v := result.(type) {
1143
case []interface{}:
1144
if len(v) == 0 {
1145
return fmt.Sprintf("[MCP: %s] (no data returned)", function)
1146
}
1147
if obj, ok := v[0].(map[string]interface{}); ok {
1148
lines = formatKeyValues(obj)
1149
} else {
1150
return fmt.Sprintf("[MCP: %s] (unexpected response format)", function)
1151
}
1152
case map[string]interface{}:
1153
lines = formatKeyValues(v)
1154
default:
1155
return fmt.Sprintf("[MCP: %s] (unexpected response format)", function)
1156
}
1157
1158
sort.Strings(lines)
1159
return fmt.Sprintf("[MCP: %s]\n%s", function, strings.Join(lines, "\n"))
1160
}
1161
1162
func formatKeyValues(obj map[string]interface{}) []string {
1163
var lines []string
1164
caser := cases.Title(language.English)
1165
for k, val := range obj {
1166
label := caser.String(strings.ReplaceAll(k, "_", " "))
1167
lines = append(lines, fmt.Sprintf("%s: %v", label, val))
1168
}
1169
return lines
1170
}
1171
1172
// calculateEffectiveContextWindow returns what remains of window after
// reserving bufferPercentage percent of it. It uses integer arithmetic:
// window*(100-bufferPercentage)/100, truncated toward zero.
func calculateEffectiveContextWindow(window int, bufferPercentage int) int {
	keepPercent := 100 - bufferPercentage
	return window * keepPercent / 100
}
1177
1178
// countTokens estimates token usage per history entry and in total.
//
// Each entry's estimate is (runes + words) / 2 over the whitespace-separated
// words of its content. This is a rough heuristic, not a real tokenizer.
// Entries whose Content is not a plain string (for example nil, or structured
// content decoded from stored JSON) count as zero tokens. Previously the
// failed type assertion panicked on such entries.
//
// It returns the total and one estimate per entry, in entry order.
func countTokens(entries []history.History) (int, []int) {
	var total int
	var rolling []int

	for _, entry := range entries {
		text, ok := entry.Content.(string)
		if !ok {
			rolling = append(rolling, 0)
			continue
		}

		words := strings.Fields(text)
		runes := 0
		for _, word := range words {
			runes += utf8.RuneCountInString(word)
		}

		// Simple approximation; adjust for the target model's tokenizer if needed.
		estimate := (runes + len(words)) / 2
		total += estimate
		rolling = append(rolling, estimate)
	}

	return total, rolling
}
1200
1201
// getExtension returns path's file extension without the leading dot
// ("clip.mp4" gives "mp4"), or "" when the final path element has none.
func getExtension(path string) string {
	return strings.TrimPrefix(filepath.Ext(path), ".")
}
1208
1209
// getMimeTypeFromBytes classifies data with net/http's DetectContentType,
// which inspects at most the first 512 bytes. The returned error is always nil.
func getMimeTypeFromBytes(data []byte) (string, error) {
	return stdhttp.DetectContentType(data), nil
}
1214
1215
func isValidURL(input string) bool {
1216
parsedURL, err := url.ParseRequestURI(input)
1217
if err != nil {
1218
return false
1219
}
1220
1221
// Ensure that the URL has a valid scheme
1222
schemes := []string{httpScheme, httpsScheme}
1223
for _, scheme := range schemes {
1224
if strings.HasPrefix(parsedURL.Scheme, scheme) {
1225
return true
1226
}
1227
}
1228
1229
return false
1230
}
1231
1232