Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/api/client/client_test.go
2649 views
1
package client_test
2
3
import (
4
"bytes"
5
"context"
6
"encoding/base64"
7
"encoding/json"
8
"errors"
9
"fmt"
10
"github.com/golang/mock/gomock"
11
_ "github.com/golang/mock/mockgen/model"
12
"github.com/kardolus/chatgpt-cli/api"
13
"github.com/kardolus/chatgpt-cli/api/client"
14
"github.com/kardolus/chatgpt-cli/api/http"
15
"github.com/kardolus/chatgpt-cli/cmd/chatgpt/utils"
16
config2 "github.com/kardolus/chatgpt-cli/config"
17
"github.com/kardolus/chatgpt-cli/history"
18
"github.com/kardolus/chatgpt-cli/internal"
19
"github.com/kardolus/chatgpt-cli/test"
20
"io"
21
"os"
22
"strings"
23
"testing"
24
"time"
25
26
. "github.com/onsi/gomega"
27
"github.com/sclevine/spec"
28
"github.com/sclevine/spec/report"
29
)
30
31
//go:generate mockgen -destination=callermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/api/http Caller
32
//go:generate mockgen -destination=historymocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/history Store
33
//go:generate mockgen -destination=timermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/api/client Timer
34
//go:generate mockgen -destination=readermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/api/client FileReader
35
//go:generate mockgen -destination=writermocks_test.go -package=client_test github.com/kardolus/chatgpt-cli/api/client FileWriter
36
37
// Fixtures shared by every spec in this file.
const (
	// interactiveMode and commandLineMode are passed as the final argument of
	// client.New to select interactive vs. one-shot (command line) behavior.
	interactiveMode = true
	commandLineMode = false

	// envApiKey is exported into the "<NAME>_API_KEY" environment variable by
	// the it.Before hook in testClient.
	envApiKey = "api-key"
)
42
43
var (
44
mockCtrl *gomock.Controller
45
mockCaller *MockCaller
46
mockHistoryStore *MockStore
47
mockTimer *MockTimer
48
mockReader *MockFileReader
49
mockWriter *MockFileWriter
50
factory *clientFactory
51
apiKeyEnvVar string
52
config config2.Config
53
)
54
55
func TestUnitClient(t *testing.T) {
56
spec.Run(t, "Testing the client package", testClient, spec.Report(report.Terminal{}))
57
}
58
59
func testClient(t *testing.T, when spec.G, it spec.S) {
60
const query = "test query"
61
62
it.Before(func() {
63
RegisterTestingT(t)
64
mockCtrl = gomock.NewController(t)
65
mockCaller = NewMockCaller(mockCtrl)
66
mockHistoryStore = NewMockStore(mockCtrl)
67
mockTimer = NewMockTimer(mockCtrl)
68
mockReader = NewMockFileReader(mockCtrl)
69
mockWriter = NewMockFileWriter(mockCtrl)
70
config = MockConfig()
71
72
factory = newClientFactory(mockHistoryStore)
73
74
apiKeyEnvVar = strings.ToUpper(config.Name) + "_API_KEY"
75
Expect(os.Setenv(apiKeyEnvVar, envApiKey)).To(Succeed())
76
})
77
78
it.After(func() {
79
mockCtrl.Finish()
80
})
81
82
when("New()", func() {
83
it("should set a unique thread slug in interactive mode when AutoCreateNewThread is true", func() {
84
var capturedThread string
85
mockHistoryStore.EXPECT().SetThread(gomock.Any()).DoAndReturn(func(thread string) {
86
capturedThread = thread
87
}).Times(1)
88
89
client.New(mockCallerFactory, mockHistoryStore, mockTimer, mockReader, mockWriter, MockConfig(), interactiveMode)
90
91
Expect(capturedThread).To(HavePrefix(client.InteractiveThreadPrefix))
92
Expect(len(capturedThread)).To(Equal(8)) // "int_" (4 chars) + 4 random characters
93
})
94
it("should not overwrite the thread in interactive mode when AutoCreateNewThread is false", func() {
95
var capturedThread string
96
mockHistoryStore.EXPECT().SetThread(gomock.Any()).DoAndReturn(func(thread string) {
97
capturedThread = thread
98
}).Times(1)
99
100
cfg := MockConfig()
101
cfg.AutoCreateNewThread = false
102
103
client.New(mockCallerFactory, mockHistoryStore, mockTimer, mockReader, mockWriter, cfg, interactiveMode)
104
105
Expect(capturedThread).To(Equal(config.Thread))
106
})
107
it("should never overwrite the thread in non-interactive mode", func() {
108
var capturedThread string
109
mockHistoryStore.EXPECT().SetThread(config.Thread).DoAndReturn(func(thread string) {
110
capturedThread = thread
111
}).Times(1)
112
113
client.New(mockCallerFactory, mockHistoryStore, mockTimer, mockReader, mockWriter, MockConfig(), commandLineMode)
114
115
Expect(capturedThread).To(Equal(config.Thread))
116
})
117
})
118
when("Query()", func() {
119
var (
120
body []byte
121
messages []api.Message
122
err error
123
)
124
125
type TestCase struct {
126
description string
127
setupPostReturn func() ([]byte, error)
128
postError error
129
expectedError string
130
}
131
132
tests := []TestCase{
133
{
134
description: "throws an error when the http callout fails",
135
setupPostReturn: func() ([]byte, error) { return nil, nil },
136
postError: errors.New("error message"),
137
expectedError: "error message",
138
},
139
{
140
description: "throws an error when the response is empty",
141
setupPostReturn: func() ([]byte, error) { return nil, nil },
142
postError: nil,
143
expectedError: "empty response",
144
},
145
{
146
description: "throws an error when the response is a malformed json",
147
setupPostReturn: func() ([]byte, error) {
148
malformed := `{"invalid":"json"` // missing closing brace
149
return []byte(malformed), nil
150
},
151
postError: nil,
152
expectedError: "failed to decode response:",
153
},
154
{
155
description: "throws an error when the response is missing Choices",
156
setupPostReturn: func() ([]byte, error) {
157
response := &api.CompletionsResponse{
158
ID: "id",
159
Object: "object",
160
Created: 0,
161
Model: "model",
162
Choices: []api.Choice{},
163
}
164
165
respBytes, err := json.Marshal(response)
166
return respBytes, err
167
},
168
postError: nil,
169
expectedError: "no responses returned",
170
},
171
{
172
description: "throws an error when the response cannot be casted to a string",
173
setupPostReturn: func() ([]byte, error) {
174
response := &api.CompletionsResponse{
175
ID: "id",
176
Object: "object",
177
Created: 0,
178
Model: "model",
179
Choices: []api.Choice{
180
{
181
Message: api.Message{
182
Role: client.AssistantRole,
183
Content: 123, // cannot be converted to a string
184
},
185
FinishReason: "",
186
Index: 0,
187
},
188
},
189
}
190
191
respBytes, err := json.Marshal(response)
192
return respBytes, err
193
},
194
postError: nil,
195
expectedError: "response cannot be converted to a string",
196
},
197
}
198
199
for _, tt := range tests {
200
it(tt.description, func() {
201
factory.withoutHistory()
202
subject := factory.buildClientWithoutConfig()
203
204
messages = createMessages(nil, query)
205
body, err = createBody(messages, false)
206
Expect(err).NotTo(HaveOccurred())
207
208
respBytes, err := tt.setupPostReturn()
209
Expect(err).NotTo(HaveOccurred())
210
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, body, false).Return(respBytes, tt.postError)
211
212
mockTimer.EXPECT().Now().Return(time.Time{}).Times(2)
213
214
_, _, err = subject.Query(context.Background(), query)
215
Expect(err).To(HaveOccurred())
216
Expect(err.Error()).To(ContainSubstring(tt.expectedError))
217
})
218
}
219
220
when("a valid http response is received", func() {
221
testValidHTTPResponse := func(subject *client.Client, expectedBody []byte, omitHistory bool) {
222
const (
223
answer = "content"
224
tokens = 789
225
)
226
227
choice := api.Choice{
228
Message: api.Message{
229
Role: client.AssistantRole,
230
Content: answer,
231
},
232
FinishReason: "",
233
Index: 0,
234
}
235
response := &api.CompletionsResponse{
236
ID: "id",
237
Object: "object",
238
Created: 0,
239
Model: subject.Config.Model,
240
Choices: []api.Choice{choice},
241
Usage: api.Usage{
242
PromptTokens: 123,
243
CompletionTokens: 456,
244
TotalTokens: tokens,
245
},
246
}
247
248
respBytes, err := json.Marshal(response)
249
Expect(err).NotTo(HaveOccurred())
250
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(respBytes, nil)
251
252
var request api.CompletionsRequest
253
err = json.Unmarshal(expectedBody, &request)
254
Expect(err).NotTo(HaveOccurred())
255
256
mockTimer.EXPECT().Now().Return(time.Time{}).AnyTimes()
257
258
var h []history.History
259
if !omitHistory {
260
for _, msg := range request.Messages {
261
h = append(h, history.History{
262
Message: msg,
263
})
264
}
265
266
mockHistoryStore.EXPECT().Write(append(h, history.History{
267
Message: api.Message{
268
Role: client.AssistantRole,
269
Content: answer,
270
},
271
}))
272
}
273
274
result, usage, err := subject.Query(context.Background(), query)
275
Expect(err).NotTo(HaveOccurred())
276
Expect(result).To(Equal(answer))
277
Expect(usage).To(Equal(tokens))
278
}
279
it("returns the expected result for a non-empty history", func() {
280
h := []history.History{
281
{
282
Message: api.Message{
283
Role: client.SystemRole,
284
Content: config.Role,
285
},
286
},
287
{
288
Message: api.Message{
289
Role: client.UserRole,
290
Content: "question 1",
291
},
292
},
293
{
294
Message: api.Message{
295
Role: client.AssistantRole,
296
Content: "answer 1",
297
},
298
},
299
}
300
301
messages = createMessages(h, query)
302
factory.withHistory(h)
303
subject := factory.buildClientWithoutConfig()
304
305
body, err = createBody(messages, false)
306
Expect(err).NotTo(HaveOccurred())
307
308
testValidHTTPResponse(subject, body, false)
309
})
310
it("ignores history when configured to do so", func() {
	// client.New always pins the configured thread on the history store.
	mockHistoryStore.EXPECT().SetThread(config.Thread).Times(1)

	// Use a dedicated config (named cfg, not config) so the package-level
	// fixture is not shadowed inside this spec.
	cfg := MockConfig()
	cfg.OmitHistory = true

	// client.New returns no error, so there is nothing to assert here; the
	// previous Expect(err) only inspected a stale variable.
	subject := client.New(mockCallerFactory, mockHistoryStore, mockTimer, mockReader, mockWriter, cfg, commandLineMode)

	// With OmitHistory set, Read and Write must never be called on the history store.
	mockHistoryStore.EXPECT().Read().Times(0)
	mockHistoryStore.EXPECT().Write(gomock.Any()).Times(0)

	messages = createMessages(nil, query)

	body, err = createBody(messages, false)
	Expect(err).NotTo(HaveOccurred())

	testValidHTTPResponse(subject, body, true)
})
330
it("truncates the history as expected", func() {
331
hs := []history.History{
332
{
333
Message: api.Message{
334
Role: client.SystemRole,
335
Content: config.Role,
336
},
337
Timestamp: time.Time{},
338
},
339
{
340
Message: api.Message{
341
Role: client.UserRole,
342
Content: "question 1",
343
},
344
Timestamp: time.Time{},
345
},
346
{
347
Message: api.Message{
348
Role: client.AssistantRole,
349
Content: "answer 1",
350
},
351
Timestamp: time.Time{},
352
},
353
{
354
Message: api.Message{
355
Role: client.UserRole,
356
Content: "question 2",
357
},
358
Timestamp: time.Time{},
359
},
360
{
361
Message: api.Message{
362
Role: client.AssistantRole,
363
Content: "answer 2",
364
},
365
Timestamp: time.Time{},
366
},
367
{
368
Message: api.Message{
369
Role: client.UserRole,
370
Content: "question 3",
371
},
372
Timestamp: time.Time{},
373
},
374
{
375
Message: api.Message{
376
Role: client.AssistantRole,
377
Content: "answer 3",
378
},
379
Timestamp: time.Time{},
380
},
381
}
382
383
messages = createMessages(hs, query)
384
385
factory.withHistory(hs)
386
subject := factory.buildClientWithoutConfig()
387
388
// messages get truncated. Index 1+2 are cut out
389
messages = append(messages[:1], messages[3:]...)
390
391
body, err = createBody(messages, false)
392
Expect(err).NotTo(HaveOccurred())
393
394
testValidHTTPResponse(subject, body, false)
395
})
396
it("should skip the first message when the model starts with o1Prefix", func() {
397
factory.withHistory([]history.History{
398
{Message: api.Message{Role: client.SystemRole, Content: "First message"}},
399
{Message: api.Message{Role: client.UserRole, Content: "Second message"}},
400
})
401
402
o1Model := "o1-example-model"
403
config.Model = o1Model
404
405
subject := factory.buildClientWithoutConfig()
406
subject.Config.Model = o1Model
407
408
expectedBody, err := createBody([]api.Message{
409
{Role: client.UserRole, Content: "Second message"},
410
{Role: client.UserRole, Content: "test query"},
411
}, false)
412
Expect(err).NotTo(HaveOccurred())
413
414
mockTimer.EXPECT().Now().Return(time.Now()).AnyTimes()
415
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(nil, nil)
416
417
_, _, _ = subject.Query(context.Background(), "test query")
418
})
419
it("should include all messages when the model does not start with o1Prefix", func() {
420
const systemRole = "System role for this test"
421
422
factory.withHistory([]history.History{
423
{Message: api.Message{Role: client.SystemRole, Content: systemRole}},
424
{Message: api.Message{Role: client.UserRole, Content: "Second message"}},
425
})
426
427
regularModel := "gpt-4o"
428
config.Model = regularModel
429
430
subject := factory.buildClientWithoutConfig()
431
subject.Config.Model = regularModel
432
subject.Config.Role = systemRole
433
434
expectedBody, err := createBody([]api.Message{
435
{Role: client.SystemRole, Content: systemRole},
436
{Role: client.UserRole, Content: "Second message"},
437
{Role: client.UserRole, Content: "test query"},
438
}, false)
439
Expect(err).NotTo(HaveOccurred())
440
441
mockTimer.EXPECT().Now().Return(time.Now()).AnyTimes()
442
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(nil, nil)
443
444
_, _, _ = subject.Query(context.Background(), "test query")
445
})
446
it("should omit Temperature and TopP when the model matches SearchModelPattern", func() {
447
searchModel := "gpt-4o-search-preview"
448
config.Model = searchModel
449
config.Role = "role for search test"
450
451
factory.withHistory([]history.History{
452
{Message: api.Message{Role: client.SystemRole, Content: config.Role}},
453
})
454
455
subject := factory.buildClientWithoutConfig()
456
subject.Config.Model = searchModel
457
458
mockTimer.EXPECT().Now().Return(time.Now()).AnyTimes()
459
460
mockCaller.EXPECT().
461
Post(gomock.Any(), gomock.Any(), false).
462
DoAndReturn(func(_ string, body []byte, _ bool) ([]byte, error) {
463
var req map[string]interface{}
464
Expect(json.Unmarshal(body, &req)).To(Succeed())
465
466
// Should not include Temperature or TopP
467
Expect(req).NotTo(HaveKey("temperature"))
468
Expect(req).NotTo(HaveKey("top_p"))
469
470
return nil, nil
471
})
472
473
_, _, _ = subject.Query(context.Background(), "test query")
474
})
475
it("should include Temperature and TopP when the model does not match SearchModelPattern", func() {
476
regularModel := "gpt-4o"
477
config.Model = regularModel
478
config.Role = "regular model test"
479
480
factory.withHistory([]history.History{
481
{Message: api.Message{Role: client.SystemRole, Content: config.Role}},
482
})
483
484
subject := factory.buildClientWithoutConfig()
485
subject.Config.Model = regularModel
486
487
mockTimer.EXPECT().Now().Return(time.Now()).AnyTimes()
488
489
mockCaller.EXPECT().
490
Post(gomock.Any(), gomock.Any(), false).
491
DoAndReturn(func(_ string, body []byte, _ bool) ([]byte, error) {
492
var req map[string]interface{}
493
Expect(json.Unmarshal(body, &req)).To(Succeed())
494
495
Expect(req).To(HaveKeyWithValue("temperature", BeNumerically("==", config.Temperature)))
496
Expect(req).To(HaveKeyWithValue("top_p", BeNumerically("==", config.TopP)))
497
498
return nil, nil
499
})
500
501
_, _, _ = subject.Query(context.Background(), "test query")
502
})
503
})
504
505
when("an image is provided", func() {
506
const (
507
query = "test query"
508
systemRole = "System role for this test"
509
errorMessage = "error message"
510
image = "path/to/image.wrong"
511
website = "https://website.com"
512
)
513
514
it.Before(func() {
515
factory.withoutHistory()
516
})
517
518
it("should update a callout as expected when a valid image URL is provided", func() {
519
subject := factory.buildClientWithoutConfig()
520
521
subject.Config.Role = systemRole
522
523
ctx := context.Background()
524
ctx = context.WithValue(ctx, internal.ImagePathKey, website)
525
526
expectedBody, err := createBody([]api.Message{
527
{Role: client.SystemRole, Content: systemRole},
528
{Role: client.UserRole, Content: query},
529
{Role: client.UserRole, Content: []api.ImageContent{{
530
Type: "image_url",
531
ImageURL: struct {
532
URL string `json:"url"`
533
}{
534
URL: website,
535
},
536
}}},
537
}, false)
538
Expect(err).NotTo(HaveOccurred())
539
540
mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
541
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(nil, nil)
542
543
_, _, _ = subject.Query(ctx, query)
544
})
545
it("throws an error when the image mime type cannot be obtained due to an open-error", func() {
546
subject := factory.buildClientWithoutConfig()
547
subject.Config.Role = systemRole
548
549
ctx := context.Background()
550
ctx = context.WithValue(ctx, internal.ImagePathKey, image)
551
552
mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
553
mockReader.EXPECT().Open(image).Return(nil, errors.New(errorMessage))
554
555
_, _, err := subject.Query(ctx, query)
556
Expect(err).To(HaveOccurred())
557
Expect(err.Error()).To(Equal(errorMessage))
558
})
559
it("throws an error when the image mime type cannot be obtained due to a read-error", func() {
560
imageFile := &os.File{}
561
562
subject := factory.buildClientWithoutConfig()
563
subject.Config.Role = systemRole
564
565
ctx := context.Background()
566
ctx = context.WithValue(ctx, internal.ImagePathKey, image)
567
568
mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
569
mockReader.EXPECT().Open(image).Return(imageFile, nil)
570
mockReader.EXPECT().ReadBufferFromFile(imageFile).Return(nil, errors.New(errorMessage))
571
572
_, _, err := subject.Query(ctx, query)
573
Expect(err).To(HaveOccurred())
574
Expect(err.Error()).To(Equal(errorMessage))
575
})
576
it("throws an error when the image base64 encoded content cannot be obtained due to a read-error", func() {
577
imageFile := &os.File{}
578
579
subject := factory.buildClientWithoutConfig()
580
subject.Config.Role = systemRole
581
582
ctx := context.Background()
583
ctx = context.WithValue(ctx, internal.ImagePathKey, image)
584
585
mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
586
mockReader.EXPECT().Open(image).Return(imageFile, nil)
587
mockReader.EXPECT().ReadBufferFromFile(imageFile).Return(nil, nil)
588
mockReader.EXPECT().ReadFile(image).Return(nil, errors.New(errorMessage))
589
590
_, _, err := subject.Query(ctx, query)
591
Expect(err).To(HaveOccurred())
592
Expect(err.Error()).To(Equal(errorMessage))
593
})
594
it("should update a callout as expected when a valid local image is provided", func() {
595
imageFile := &os.File{}
596
597
subject := factory.buildClientWithoutConfig()
598
subject.Config.Role = systemRole
599
600
ctx := context.Background()
601
ctx = context.WithValue(ctx, internal.ImagePathKey, image)
602
603
mockReader.EXPECT().Open(image).Return(imageFile, nil)
604
mockReader.EXPECT().ReadBufferFromFile(imageFile).Return(nil, nil)
605
mockReader.EXPECT().ReadFile(image).Return(nil, nil)
606
607
expectedBody, err := createBody([]api.Message{
608
{Role: client.SystemRole, Content: systemRole},
609
{Role: client.UserRole, Content: query},
610
{Role: client.UserRole, Content: []api.ImageContent{{
611
Type: "image_url",
612
ImageURL: struct {
613
URL string `json:"url"`
614
}{
615
URL: "data:text/plain; charset=utf-8;base64,",
616
},
617
}}},
618
}, false)
619
Expect(err).NotTo(HaveOccurred())
620
621
mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
622
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(nil, nil)
623
624
_, _, _ = subject.Query(ctx, query)
625
})
626
})
627
628
when("an audio file is provided", func() {
629
const (
630
query = "transcribe this"
631
systemRole = "System role for audio test"
632
errorMessage = "error opening audio file"
633
audio = "path/to/audio.wav"
634
)
635
636
it.Before(func() {
637
factory.withoutHistory()
638
})
639
640
it("throws an error when the audio file cannot be opened", func() {
641
subject := factory.buildClientWithoutConfig()
642
subject.Config.Role = systemRole
643
644
ctx := context.Background()
645
ctx = context.WithValue(ctx, internal.AudioPathKey, audio)
646
647
mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
648
mockReader.EXPECT().Open(audio).Return(nil, errors.New(errorMessage))
649
650
_, _, err := subject.Query(ctx, query)
651
Expect(err).To(HaveOccurred())
652
Expect(err.Error()).To(Equal(errorMessage))
653
})
654
655
it("throws an error when the audio data cannot be read", func() {
656
audioFile := &os.File{}
657
subject := factory.buildClientWithoutConfig()
658
subject.Config.Role = systemRole
659
660
ctx := context.Background()
661
ctx = context.WithValue(ctx, internal.AudioPathKey, audio)
662
663
mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
664
mockReader.EXPECT().Open(audio).Return(audioFile, nil)
665
mockReader.EXPECT().ReadBufferFromFile(audioFile).Return([]byte("RIFFxxxxWAVE..."), nil)
666
mockReader.EXPECT().ReadFile(audio).Return(nil, errors.New(errorMessage))
667
668
_, _, err := subject.Query(ctx, query)
669
Expect(err).To(HaveOccurred())
670
Expect(err.Error()).To(Equal(errorMessage))
671
})
672
673
it("adds audio as input_audio type content when valid", func() {
674
audioFile := &os.File{}
675
subject := factory.buildClientWithoutConfig()
676
subject.Config.Role = systemRole
677
678
ctx := context.Background()
679
ctx = context.WithValue(ctx, internal.AudioPathKey, audio)
680
681
mockReader.EXPECT().Open(audio).Return(audioFile, nil)
682
mockReader.EXPECT().ReadBufferFromFile(audioFile).Return([]byte("RIFFxxxxWAVE..."), nil)
683
mockReader.EXPECT().ReadFile(audio).Return([]byte("audio-bytes"), nil)
684
685
expectedBody, err := createBody([]api.Message{
686
{Role: client.SystemRole, Content: systemRole},
687
{Role: client.UserRole, Content: query},
688
{Role: client.UserRole, Content: []api.AudioContent{{
689
Type: "input_audio",
690
InputAudio: struct {
691
Data string `json:"data"`
692
Format string `json:"format"`
693
}{
694
Data: "YXVkaW8tYnl0ZXM=", // base64 of "audio-bytes"
695
Format: "wav",
696
},
697
}}},
698
}, false)
699
Expect(err).NotTo(HaveOccurred())
700
701
mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
702
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(nil, nil)
703
704
_, _, _ = subject.Query(ctx, query)
705
})
706
})
707
708
when("the model is o1-pro or gpt-5", func() {
	models := []string{"o1-pro", "gpt-5"}

	for _, m := range models {
		m := m // capture
		when(fmt.Sprintf("the model is %s", m), func() {
			const (
				query       = "what's the weather"
				systemRole  = "you are helpful"
				totalTokens = 777
			)

			it.Before(func() {
				config.Model = m
				config.Role = systemRole
				factory.withoutHistory()
			})

			// newSubject builds a client for the model under test, using the
			// system role shared by every spec in this group. It must be called
			// inside a spec, after it.Before has reset the factory.
			newSubject := func() *client.Client {
				subject := factory.buildClientWithoutConfig()
				subject.Config.Model = m
				subject.Config.Role = systemRole
				return subject
			}

			// mustMarshal JSON-encodes v and fails the spec on an encoding error
			// instead of silently discarding it.
			mustMarshal := func(v interface{}) []byte {
				raw, err := json.Marshal(v)
				Expect(err).NotTo(HaveOccurred())
				return raw
			}

			// expectResponsesCall expects exactly one POST to /v1/responses
			// carrying the request the client should build for `query`, and
			// stubs it to return the JSON encoding of response.
			expectResponsesCall := func(subject *client.Client, response api.ResponsesResponse) {
				body := mustMarshal(api.ResponsesRequest{
					Model: subject.Config.Model,
					Input: []api.Message{
						{Role: client.SystemRole, Content: systemRole},
						{Role: client.UserRole, Content: query},
					},
					MaxOutputTokens: subject.Config.MaxTokens,
					Reasoning:       api.Reasoning{Effort: "low"},
					Stream:          false,
					Temperature:     subject.Config.Temperature,
					TopP:            subject.Config.TopP,
				})

				mockCaller.EXPECT().
					Post(subject.Config.URL+"/v1/responses", body, false).
					Return(mustMarshal(response), nil)
			}

			it("returns the output_text when present", func() {
				subject := newSubject()
				answer := "yes, it does"

				mockTimer.EXPECT().Now().Times(3)
				mockHistoryStore.EXPECT().Write(gomock.Any())

				expectResponsesCall(subject, api.ResponsesResponse{
					Output: []api.Output{{
						Type:    "message",
						Content: []api.Content{{Type: "output_text", Text: answer}},
					}},
					Usage: api.TokenUsage{TotalTokens: totalTokens},
				})

				text, tokens, err := subject.Query(context.Background(), query)
				Expect(err).NotTo(HaveOccurred())
				Expect(text).To(Equal(answer))
				Expect(tokens).To(Equal(totalTokens))
			})

			it("errors when no output blocks are present", func() {
				subject := newSubject()

				mockTimer.EXPECT().Now().Times(2)

				expectResponsesCall(subject, api.ResponsesResponse{
					Output: []api.Output{},
					Usage:  api.TokenUsage{TotalTokens: totalTokens},
				})

				_, _, err := subject.Query(context.Background(), query)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(Equal("no response returned"))
			})

			it("errors when message has no output_text", func() {
				subject := newSubject()

				mockTimer.EXPECT().Now().Times(2)

				// A "refusal" block is the only content, so there is no output_text to return.
				expectResponsesCall(subject, api.ResponsesResponse{
					Output: []api.Output{{
						Type:    "message",
						Content: []api.Content{{Type: "refusal", Text: "nope"}},
					}},
					Usage: api.TokenUsage{TotalTokens: totalTokens},
				})

				_, _, err := subject.Query(context.Background(), query)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(Equal("no response returned"))
			})
		})
	}
})
849
})
850
when("Stream()", func() {
851
var (
852
body []byte
853
messages []api.Message
854
err error
855
)
856
857
it("throws an error when the http callout fails", func() {
858
factory.withoutHistory()
859
subject := factory.buildClientWithoutConfig()
860
861
messages = createMessages(nil, query)
862
body, err = createBody(messages, true)
863
Expect(err).NotTo(HaveOccurred())
864
865
errorMsg := "error message"
866
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, body, true).Return(nil, errors.New(errorMsg))
867
868
mockTimer.EXPECT().Now().Return(time.Time{}).Times(2)
869
870
err := subject.Stream(context.Background(), query)
871
Expect(err).To(HaveOccurred())
872
Expect(err.Error()).To(Equal(errorMsg))
873
})
874
when("a valid http response is received", func() {
875
const answer = "answer"
876
877
// testValidHTTPResponse drives a successful Stream call.
// It stubs the streaming POST with expectedBody to return `answer`.
// It expects the history store to receive the prior history hs, plus the
// user query and the assistant's answer.
testValidHTTPResponse := func(subject *client.Client, hs []history.History, expectedBody []byte) {
	mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, true).Return([]byte(answer), nil)

	mockTimer.EXPECT().Now().Return(time.Time{}).AnyTimes()

	messages = createMessages(hs, query)

	// Build the expected history in a fresh slice rather than reusing hs.
	written := make([]history.History, 0, len(messages)+1)
	for _, message := range messages {
		written = append(written, history.History{Message: message})
	}
	written = append(written, history.History{
		Message: api.Message{
			Role:    client.AssistantRole,
			Content: answer,
		},
	})

	mockHistoryStore.EXPECT().Write(written)

	Expect(subject.Stream(context.Background(), query)).To(Succeed())
}
906
907
it("returns the expected result for an empty history", func() {
908
factory.withHistory(nil)
909
subject := factory.buildClientWithoutConfig()
910
911
messages = createMessages(nil, query)
912
body, err = createBody(messages, true)
913
Expect(err).NotTo(HaveOccurred())
914
915
testValidHTTPResponse(subject, nil, body)
916
})
917
it("returns the expected result for a non-empty history", func() {
918
h := []history.History{
919
{
920
Message: api.Message{
921
Role: client.SystemRole,
922
Content: config.Role,
923
},
924
},
925
{
926
Message: api.Message{
927
Role: client.UserRole,
928
Content: "question x",
929
},
930
},
931
{
932
Message: api.Message{
933
Role: client.AssistantRole,
934
Content: "answer x",
935
},
936
},
937
}
938
factory.withHistory(h)
939
subject := factory.buildClientWithoutConfig()
940
941
messages = createMessages(h, query)
942
body, err = createBody(messages, true)
943
Expect(err).NotTo(HaveOccurred())
944
945
testValidHTTPResponse(subject, h, body)
946
})
947
})
948
})
949
when("SynthesizeSpeech()", func() {
950
const (
951
inputText = "mock-input"
952
outputFile = "mock-output"
953
outputFileType = "mp3"
954
errorText = "mock error occurred"
955
)
956
957
var (
958
subject *client.Client
959
fileName = outputFile + "." + outputFileType
960
body []byte
961
response []byte
962
)
963
it.Before(func() {
964
subject = factory.buildClientWithoutConfig()
965
request := api.Speech{
966
Model: subject.Config.Model,
967
Voice: subject.Config.Voice,
968
Input: inputText,
969
ResponseFormat: outputFileType,
970
}
971
var err error
972
body, err = json.Marshal(request)
973
Expect(err).NotTo(HaveOccurred())
974
975
response = []byte("mock response")
976
})
977
it("throws an error when the http call fails", func() {
978
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.SpeechPath, body, false).Return(nil, errors.New(errorText))
979
980
err := subject.SynthesizeSpeech(inputText, fileName)
981
Expect(err).To(HaveOccurred())
982
Expect(err.Error()).To(ContainSubstring(errorText))
983
})
984
it("throws an error when a file cannot be created", func() {
985
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.SpeechPath, body, false).Return(response, nil)
986
mockWriter.EXPECT().Create(fileName).Return(nil, errors.New(errorText))
987
988
err := subject.SynthesizeSpeech(inputText, fileName)
989
Expect(err).To(HaveOccurred())
990
Expect(err.Error()).To(ContainSubstring(errorText))
991
})
992
it("throws an error when bytes cannot be written to the output file", func() {
993
file, err := os.Open(os.DevNull)
994
Expect(err).NotTo(HaveOccurred())
995
defer file.Close()
996
997
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.SpeechPath, body, false).Return(response, nil)
998
mockWriter.EXPECT().Create(fileName).Return(file, nil)
999
mockWriter.EXPECT().Write(file, response).Return(errors.New(errorText))
1000
1001
err = subject.SynthesizeSpeech(inputText, fileName)
1002
Expect(err).To(HaveOccurred())
1003
Expect(err.Error()).To(ContainSubstring(errorText))
1004
})
1005
it("succeeds when no errors occurred", func() {
1006
file, err := os.Open(os.DevNull)
1007
Expect(err).NotTo(HaveOccurred())
1008
defer file.Close()
1009
1010
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.SpeechPath, body, false).Return(response, nil)
1011
mockWriter.EXPECT().Create(fileName).Return(file, nil)
1012
mockWriter.EXPECT().Write(file, response).Return(nil)
1013
1014
err = subject.SynthesizeSpeech(inputText, fileName)
1015
Expect(err).NotTo(HaveOccurred())
1016
})
1017
})
1018
when("GenerateImage()", func() {
1019
const (
1020
inputText = "draw a happy dog"
1021
outputFile = "dog.png"
1022
errorText = "mock error occurred"
1023
)
1024
1025
var (
1026
subject *client.Client
1027
body []byte
1028
)
1029
1030
it.Before(func() {
1031
subject = factory.buildClientWithoutConfig()
1032
request := api.Draw{
1033
Model: subject.Config.Model,
1034
Prompt: inputText,
1035
}
1036
var err error
1037
body, err = json.Marshal(request)
1038
Expect(err).NotTo(HaveOccurred())
1039
})
1040
it("throws an error when the http call fails", func() {
1041
mockCaller.EXPECT().
1042
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
1043
Return(nil, errors.New(errorText))
1044
1045
err := subject.GenerateImage(inputText, outputFile)
1046
Expect(err).To(HaveOccurred())
1047
Expect(err.Error()).To(ContainSubstring(errorText))
1048
})
1049
it("throws an error when no image data is returned", func() {
1050
mockCaller.EXPECT().
1051
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
1052
Return([]byte(`{"data":[]}`), nil)
1053
1054
err := subject.GenerateImage(inputText, outputFile)
1055
Expect(err).To(HaveOccurred())
1056
Expect(err.Error()).To(ContainSubstring("no image data returned"))
1057
})
1058
it("throws an error when base64 is invalid", func() {
1059
mockCaller.EXPECT().
1060
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
1061
Return([]byte(`{"data":[{"b64_json":"!!notbase64!!"}]}`), nil)
1062
1063
err := subject.GenerateImage(inputText, outputFile)
1064
Expect(err).To(HaveOccurred())
1065
Expect(err.Error()).To(ContainSubstring("failed to decode base64 image"))
1066
})
1067
it("throws an error when a file cannot be created", func() {
1068
valid := base64.StdEncoding.EncodeToString([]byte("image-bytes"))
1069
1070
mockCaller.EXPECT().
1071
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
1072
Return([]byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, valid)), nil)
1073
1074
mockWriter.EXPECT().Create(outputFile).Return(nil, errors.New(errorText))
1075
1076
err := subject.GenerateImage(inputText, outputFile)
1077
Expect(err).To(HaveOccurred())
1078
Expect(err.Error()).To(ContainSubstring(errorText))
1079
})
1080
it("throws an error when bytes cannot be written to the file", func() {
1081
valid := base64.StdEncoding.EncodeToString([]byte("image-bytes"))
1082
file, err := os.Open(os.DevNull)
1083
Expect(err).NotTo(HaveOccurred())
1084
defer file.Close()
1085
1086
mockCaller.EXPECT().
1087
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
1088
Return([]byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, valid)), nil)
1089
1090
mockWriter.EXPECT().Create(outputFile).Return(file, nil)
1091
mockWriter.EXPECT().Write(file, []byte("image-bytes")).Return(errors.New(errorText))
1092
1093
err = subject.GenerateImage(inputText, outputFile)
1094
Expect(err).To(HaveOccurred())
1095
Expect(err.Error()).To(ContainSubstring(errorText))
1096
})
1097
it("succeeds when all steps complete", func() {
1098
valid := base64.StdEncoding.EncodeToString([]byte("image-bytes"))
1099
file, err := os.Open(os.DevNull)
1100
Expect(err).NotTo(HaveOccurred())
1101
defer file.Close()
1102
1103
mockCaller.EXPECT().
1104
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
1105
Return([]byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, valid)), nil)
1106
1107
mockWriter.EXPECT().Create(outputFile).Return(file, nil)
1108
mockWriter.EXPECT().Write(file, []byte("image-bytes")).Return(nil)
1109
1110
err = subject.GenerateImage(inputText, outputFile)
1111
Expect(err).NotTo(HaveOccurred())
1112
})
1113
})
1114
when("EditImage()", func() {
1115
const (
1116
inputText = "give the dog sunglasses"
1117
inputFile = "dog.png"
1118
outputFile = "dog_cool.png"
1119
errorText = "mock error occurred"
1120
)
1121
1122
var (
1123
subject *client.Client
1124
validB64 string
1125
imageBytes = []byte("image-bytes")
1126
respBytes []byte
1127
)
1128
1129
it.Before(func() {
1130
subject = factory.buildClientWithoutConfig()
1131
validB64 = base64.StdEncoding.EncodeToString(imageBytes)
1132
respBytes = []byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, validB64))
1133
})
1134
1135
it("returns error when input file can't be opened", func() {
1136
mockReader.EXPECT().Open(inputFile).Return(nil, errors.New(errorText))
1137
1138
err := subject.EditImage(inputText, inputFile, outputFile)
1139
Expect(err).To(HaveOccurred())
1140
Expect(err.Error()).To(ContainSubstring("failed to open input image"))
1141
})
1142
it("returns error on invalid mime type", func() {
1143
file := openDummy()
1144
mockReader.EXPECT().Open(inputFile).Return(file, nil).Times(2)
1145
mockReader.EXPECT().ReadBufferFromFile(file).Return([]byte("not an image"), nil)
1146
1147
err := subject.EditImage(inputText, inputFile, outputFile)
1148
Expect(err).To(HaveOccurred())
1149
Expect(err.Error()).To(ContainSubstring("unsupported MIME type"))
1150
})
1151
it("returns error when HTTP call fails", func() {
1152
mockReader.EXPECT().Open(inputFile).DoAndReturn(func(string) (*os.File, error) {
1153
return openDummy(), nil
1154
}).Times(2)
1155
1156
mockReader.EXPECT().
1157
ReadBufferFromFile(gomock.AssignableToTypeOf(&os.File{})).
1158
Return([]byte("\x89PNG\r\n\x1a\n"), nil)
1159
1160
mockCaller.EXPECT().
1161
PostWithHeaders(gomock.Any(), gomock.Any(), gomock.Any()).
1162
Return(nil, errors.New(errorText))
1163
1164
err := subject.EditImage(inputText, inputFile, outputFile)
1165
Expect(err).To(HaveOccurred())
1166
Expect(err.Error()).To(ContainSubstring("failed to edit image"))
1167
})
1168
it("returns error when base64 is invalid", func() {
1169
invalidResp := []byte(`{"data":[{"b64_json":"!notbase64"}]}`)
1170
1171
mockReader.EXPECT().Open(inputFile).DoAndReturn(func(string) (*os.File, error) {
1172
return openDummy(), nil
1173
}).Times(2)
1174
1175
mockReader.EXPECT().
1176
ReadBufferFromFile(gomock.AssignableToTypeOf(&os.File{})).
1177
Return([]byte("\x89PNG\r\n\x1a\n"), nil)
1178
1179
mockCaller.EXPECT().
1180
PostWithHeaders(gomock.Any(), gomock.Any(), gomock.Any()).
1181
Return(invalidResp, nil)
1182
1183
err := subject.EditImage(inputText, inputFile, outputFile)
1184
Expect(err).To(HaveOccurred())
1185
Expect(err.Error()).To(ContainSubstring("failed to decode base64 image"))
1186
})
1187
it("writes image when all steps succeed", func() {
1188
file := openDummy()
1189
mockReader.EXPECT().Open(inputFile).DoAndReturn(func(string) (*os.File, error) {
1190
return openDummy(), nil
1191
}).Times(2)
1192
1193
mockReader.EXPECT().
1194
ReadBufferFromFile(gomock.AssignableToTypeOf(&os.File{})).
1195
Return([]byte("\x89PNG\r\n\x1a\n"), nil)
1196
1197
mockCaller.EXPECT().
1198
PostWithHeaders(gomock.Any(), gomock.Any(), gomock.Any()).
1199
Return(respBytes, nil)
1200
1201
mockWriter.EXPECT().Create(outputFile).Return(file, nil)
1202
mockWriter.EXPECT().Write(file, imageBytes).Return(nil)
1203
1204
err := subject.EditImage(inputText, inputFile, outputFile)
1205
Expect(err).NotTo(HaveOccurred())
1206
})
1207
})
1208
when("Transcribe()", func() {
1209
const audioPath = "path/to/audio.wav"
1210
const transcribedText = "Hello, this is a test."
1211
1212
it("returns an error if the audio file cannot be opened", func() {
1213
subject := factory.buildClientWithoutConfig()
1214
1215
mockHistoryStore.EXPECT().Read().Return(nil, nil)
1216
mockTimer.EXPECT().Now().Times(1)
1217
1218
mockReader.EXPECT().Open(audioPath).Return(nil, errors.New("cannot open"))
1219
1220
_, err := subject.Transcribe(audioPath)
1221
Expect(err).To(HaveOccurred())
1222
Expect(err.Error()).To(ContainSubstring("cannot open"))
1223
})
1224
1225
it("returns an error if copying audio content fails", func() {
1226
subject := factory.buildClientWithoutConfig()
1227
1228
mockHistoryStore.EXPECT().Read().Return(nil, nil)
1229
mockTimer.EXPECT().Now().Times(1)
1230
1231
reader, writer, err := os.Pipe()
1232
Expect(err).NotTo(HaveOccurred())
1233
1234
// Immediately close writer so reader will return EOF
1235
_ = writer.Close()
1236
1237
mockReader.EXPECT().Open(audioPath).Return(reader, nil)
1238
1239
mockCaller.EXPECT().
1240
PostWithHeaders(subject.Config.URL+subject.Config.TranscriptionsPath, gomock.Any(), gomock.Any())
1241
1242
_, err = subject.Transcribe(audioPath)
1243
Expect(err).To(HaveOccurred())
1244
Expect(err.Error()).To(ContainSubstring("failed"))
1245
})
1246
1247
it("returns an error if the API call fails", func() {
1248
subject := factory.buildClientWithoutConfig()
1249
1250
mockHistoryStore.EXPECT().Read().Return(nil, nil)
1251
mockTimer.EXPECT().Now().Times(1)
1252
1253
file, err := os.Open(os.DevNull)
1254
Expect(err).NotTo(HaveOccurred())
1255
defer file.Close()
1256
1257
mockReader.EXPECT().Open(audioPath).Return(file, nil)
1258
1259
mockCaller.EXPECT().
1260
PostWithHeaders(subject.Config.URL+subject.Config.TranscriptionsPath, gomock.Any(), gomock.Any()).
1261
Return(nil, errors.New("network error"))
1262
1263
_, err = subject.Transcribe(audioPath)
1264
Expect(err).To(HaveOccurred())
1265
Expect(err.Error()).To(ContainSubstring("network error"))
1266
})
1267
1268
it("returns the transcribed text when successful", func() {
1269
subject := factory.buildClientWithoutConfig()
1270
1271
mockHistoryStore.EXPECT().Read().Return(nil, nil)
1272
1273
now := time.Now()
1274
mockTimer.EXPECT().Now().Return(now).Times(3)
1275
1276
file, err := os.Open(os.DevNull)
1277
Expect(err).NotTo(HaveOccurred())
1278
defer file.Close()
1279
1280
mockReader.EXPECT().Open(audioPath).Return(file, nil)
1281
1282
resp := []byte(`{"text": "Hello, this is a test."}`)
1283
mockCaller.EXPECT().
1284
PostWithHeaders(subject.Config.URL+subject.Config.TranscriptionsPath, gomock.Any(), gomock.Any()).
1285
Return(resp, nil)
1286
1287
expectedHistory := []history.History{
1288
{
1289
Message: api.Message{
1290
Role: client.SystemRole,
1291
Content: subject.Config.Role,
1292
},
1293
Timestamp: now,
1294
},
1295
{
1296
Message: api.Message{
1297
Role: client.UserRole,
1298
Content: "[transcribe] audio.wav",
1299
},
1300
Timestamp: now,
1301
},
1302
{
1303
Message: api.Message{
1304
Role: client.AssistantRole,
1305
Content: transcribedText,
1306
},
1307
Timestamp: now,
1308
},
1309
}
1310
1311
mockHistoryStore.EXPECT().Write(expectedHistory)
1312
1313
text, err := subject.Transcribe(audioPath)
1314
Expect(err).NotTo(HaveOccurred())
1315
Expect(text).To(Equal(transcribedText))
1316
})
1317
})
1318
when("ListModels()", func() {
1319
it("throws an error when the http callout fails", func() {
1320
subject := factory.buildClientWithoutConfig()
1321
1322
errorMsg := "error message"
1323
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return(nil, errors.New(errorMsg))
1324
1325
_, err := subject.ListModels()
1326
Expect(err).To(HaveOccurred())
1327
Expect(err.Error()).To(Equal(errorMsg))
1328
})
1329
it("throws an error when the response is empty", func() {
1330
subject := factory.buildClientWithoutConfig()
1331
1332
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return(nil, nil)
1333
1334
_, err := subject.ListModels()
1335
Expect(err).To(HaveOccurred())
1336
Expect(err.Error()).To(Equal("empty response"))
1337
})
1338
it("throws an error when the response is a malformed json", func() {
1339
subject := factory.buildClientWithoutConfig()
1340
1341
malformed := `{"invalid":"json"` // missing closing brace
1342
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return([]byte(malformed), nil)
1343
1344
_, err := subject.ListModels()
1345
Expect(err).To(HaveOccurred())
1346
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
1347
})
1348
it("filters gpt and o1 models as expected and puts them in alphabetical order", func() {
1349
subject := factory.buildClientWithoutConfig()
1350
1351
response, err := test.FileToBytes("models.json")
1352
Expect(err).NotTo(HaveOccurred())
1353
1354
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return(response, nil)
1355
1356
result, err := subject.ListModels()
1357
Expect(err).NotTo(HaveOccurred())
1358
Expect(result).NotTo(BeEmpty())
1359
Expect(result).To(HaveLen(5))
1360
Expect(result[0]).To(Equal("- gpt-3.5-env-model"))
1361
Expect(result[1]).To(Equal("* gpt-3.5-turbo (current)"))
1362
Expect(result[2]).To(Equal("- gpt-3.5-turbo-0301"))
1363
Expect(result[3]).To(Equal("- gpt-4o"))
1364
Expect(result[4]).To(Equal("- o1-mini"))
1365
})
1366
})
1367
when("ProvideContext()", func() {
1368
it("updates the history with the provided context", func() {
1369
subject := factory.buildClientWithoutConfig()
1370
1371
chatContext := "This is a story about a dog named Kya. Kya loves to play fetch and swim in the lake."
1372
mockHistoryStore.EXPECT().Read().Return(nil, nil).Times(1)
1373
1374
mockTimer.EXPECT().Now().Return(time.Time{}).AnyTimes()
1375
1376
subject.ProvideContext(chatContext)
1377
1378
Expect(len(subject.History)).To(Equal(2)) // The system message and the provided context
1379
1380
systemMessage := subject.History[0]
1381
Expect(systemMessage.Role).To(Equal(client.SystemRole))
1382
Expect(systemMessage.Content).To(Equal(config.Role))
1383
1384
contextMessage := subject.History[1]
1385
Expect(contextMessage.Role).To(Equal(client.UserRole))
1386
Expect(contextMessage.Content).To(Equal(chatContext))
1387
})
1388
it("behaves as expected with a non empty initial history", func() {
1389
subject := factory.buildClientWithoutConfig()
1390
1391
subject.History = []history.History{
1392
{
1393
Message: api.Message{
1394
Role: client.SystemRole,
1395
Content: "system message",
1396
},
1397
},
1398
{
1399
Message: api.Message{
1400
Role: client.UserRole,
1401
},
1402
},
1403
}
1404
1405
mockTimer.EXPECT().Now().Return(time.Time{}).AnyTimes()
1406
1407
chatContext := "test context"
1408
subject.ProvideContext(chatContext)
1409
1410
Expect(len(subject.History)).To(Equal(3))
1411
1412
contextMessage := subject.History[2]
1413
Expect(contextMessage.Role).To(Equal(client.UserRole))
1414
Expect(contextMessage.Content).To(Equal(chatContext))
1415
})
1416
})
1417
when("InjectMCPContext()", func() {
1418
var subject *client.Client
1419
1420
const (
1421
function = "mock-function"
1422
version = "mock-version"
1423
param = "mock-param"
1424
value = "mock-value"
1425
apifyKey = "mock-key"
1426
endpoint = client.ApifyURL + function + client.ApifyPath
1427
)
1428
1429
req := api.MCPRequest{
1430
Provider: utils.ApifyProvider,
1431
Function: function,
1432
Version: version,
1433
Params: map[string]interface{}{
1434
param: value,
1435
},
1436
}
1437
1438
it.Before(func() {
1439
subject = factory.buildClientWithoutConfig()
1440
subject.Config.ApifyAPIKey = apifyKey
1441
})
1442
1443
it("throws an error when the apify API key is missing and the apify provider is used", func() {
1444
subject.Config.ApifyAPIKey = ""
1445
1446
err := subject.InjectMCPContext(req)
1447
Expect(err).To(HaveOccurred())
1448
Expect(err.Error()).To(ContainSubstring(utils.ApifyProvider))
1449
})
1450
it("is not reliant on specific provider casing", func() {
1451
subject.Config.ApifyAPIKey = ""
1452
1453
req.Provider = "ApIfY"
1454
1455
err := subject.InjectMCPContext(req)
1456
Expect(err).To(HaveOccurred())
1457
Expect(err.Error()).To(ContainSubstring(utils.ApifyProvider))
1458
})
1459
it("throws an error when history tracking is disabled", func() {
1460
subject.Config.OmitHistory = true
1461
1462
err := subject.InjectMCPContext(req)
1463
Expect(err).To(HaveOccurred())
1464
Expect(err).To(MatchError(client.ErrHistoryTracking))
1465
})
1466
it("throws an error when the provider is not supported", func() {
1467
req.Provider = "not-supported"
1468
1469
err := subject.InjectMCPContext(req)
1470
Expect(err).To(HaveOccurred())
1471
Expect(err).To(MatchError(client.ErrUnsupportedProvider))
1472
})
1473
it("throws an error when the http call fails", func() {
1474
msg := "error message"
1475
1476
mockCaller.EXPECT().
1477
PostWithHeaders(endpoint, gomock.Any(), gomock.Any()).Return(nil, errors.New(msg))
1478
1479
err := subject.InjectMCPContext(req)
1480
Expect(err).To(HaveOccurred())
1481
Expect(err).To(MatchError(msg))
1482
})
1483
it("throws an error when history writing fails", func() {
1484
mockCaller.EXPECT().
1485
PostWithHeaders(endpoint, gomock.Any(), gomock.Any()).Return([]byte(`{"key":"value"}`), nil)
1486
1487
mockHistoryStore.EXPECT().Read().Times(1)
1488
mockTimer.EXPECT().Now().Times(2)
1489
1490
msg := "error message"
1491
mockHistoryStore.EXPECT().Write(gomock.Any()).Return(errors.New(msg))
1492
1493
err := subject.InjectMCPContext(req)
1494
Expect(err).To(HaveOccurred())
1495
Expect(err).To(MatchError(msg))
1496
})
1497
it("adds the formatted MCP response to history (array data)", func() {
1498
mockCaller.EXPECT().
1499
PostWithHeaders(endpoint, gomock.Any(), gomock.Any()).
1500
Return([]byte(`[{"temperature":"15C","condition":"Sunny"}]`), nil)
1501
1502
mockHistoryStore.EXPECT().Read().Times(1)
1503
mockTimer.EXPECT().Now().Times(2)
1504
1505
mockHistoryStore.EXPECT().Write(gomock.Any()).
1506
DoAndReturn(func(h []history.History) error {
1507
Expect(len(h)).To(Equal(2))
1508
last := h[len(h)-1]
1509
Expect(last.Message.Role).To(Equal("function"))
1510
Expect(last.Message.Name).To(Equal("mock-function"))
1511
Expect(last.Message.Content).To(ContainSubstring("Temperature: 15C"))
1512
Expect(last.Message.Content).To(ContainSubstring("Condition: Sunny"))
1513
Expect(last.Message.Content).To(ContainSubstring("[MCP: mock-function]"))
1514
return nil
1515
})
1516
1517
err := subject.InjectMCPContext(req)
1518
Expect(err).NotTo(HaveOccurred())
1519
})
1520
it("adds the formatted MCP response to history (single object)", func() {
1521
mockCaller.EXPECT().
1522
PostWithHeaders(endpoint, gomock.Any(), gomock.Any()).
1523
Return([]byte(`{"foo":"bar","baz":"qux"}`), nil)
1524
1525
mockHistoryStore.EXPECT().Read().Times(1)
1526
mockTimer.EXPECT().Now().Times(2)
1527
1528
mockHistoryStore.EXPECT().Write(gomock.Any()).
1529
DoAndReturn(func(h []history.History) error {
1530
Expect(len(h)).To(Equal(2))
1531
Expect(h[len(h)-1].Message.Content).To(ContainSubstring("Foo: bar"))
1532
Expect(h[len(h)-1].Message.Content).To(ContainSubstring("Baz: qux"))
1533
Expect(h[len(h)-1].Message.Role).To(Equal("function"))
1534
Expect(h[len(h)-1].Message.Name).To(Equal("mock-function"))
1535
return nil
1536
})
1537
1538
err := subject.InjectMCPContext(req)
1539
Expect(err).NotTo(HaveOccurred())
1540
})
1541
it("adds fallback message when array response is empty", func() {
1542
mockCaller.EXPECT().
1543
PostWithHeaders(endpoint, gomock.Any(), gomock.Any()).
1544
Return([]byte(`[]`), nil)
1545
1546
mockHistoryStore.EXPECT().Read().Times(1)
1547
mockTimer.EXPECT().Now().Times(2)
1548
1549
mockHistoryStore.EXPECT().Write(gomock.Any()).
1550
DoAndReturn(func(h []history.History) error {
1551
Expect(h[len(h)-1].Message.Content).To(ContainSubstring("no data returned"))
1552
return nil
1553
})
1554
1555
err := subject.InjectMCPContext(req)
1556
Expect(err).NotTo(HaveOccurred())
1557
})
1558
it("adds fallback message when array contains non-object items", func() {
1559
mockCaller.EXPECT().
1560
PostWithHeaders(endpoint, gomock.Any(), gomock.Any()).
1561
Return([]byte(`[42, true, "string"]`), nil)
1562
1563
mockHistoryStore.EXPECT().Read().Times(1)
1564
mockTimer.EXPECT().Now().Times(2)
1565
1566
mockHistoryStore.EXPECT().Write(gomock.Any()).
1567
DoAndReturn(func(h []history.History) error {
1568
Expect(h[len(h)-1].Message.Content).To(ContainSubstring("unexpected response format"))
1569
return nil
1570
})
1571
1572
err := subject.InjectMCPContext(req)
1573
Expect(err).NotTo(HaveOccurred())
1574
})
1575
it("adds fallback message when response is invalid JSON", func() {
1576
mockCaller.EXPECT().
1577
PostWithHeaders(endpoint, gomock.Any(), gomock.Any()).
1578
Return([]byte(`{invalid json}`), nil)
1579
1580
mockHistoryStore.EXPECT().Read().Times(1)
1581
mockTimer.EXPECT().Now().Times(2)
1582
1583
mockHistoryStore.EXPECT().Write(gomock.Any()).
1584
DoAndReturn(func(h []history.History) error {
1585
Expect(h[len(h)-1].Message.Content).To(ContainSubstring("failed to decode response"))
1586
return nil
1587
})
1588
1589
err := subject.InjectMCPContext(req)
1590
Expect(err).NotTo(HaveOccurred())
1591
})
1592
it("adds fallback message when top-level JSON is a string", func() {
1593
mockCaller.EXPECT().
1594
PostWithHeaders(endpoint, gomock.Any(), gomock.Any()).
1595
Return([]byte(`"hello world"`), nil)
1596
1597
mockHistoryStore.EXPECT().Read().Times(1)
1598
mockTimer.EXPECT().Now().Times(2)
1599
1600
mockHistoryStore.EXPECT().Write(gomock.Any()).
1601
DoAndReturn(func(h []history.History) error {
1602
Expect(h[len(h)-1].Message.Content).To(ContainSubstring("unexpected response format"))
1603
return nil
1604
})
1605
1606
err := subject.InjectMCPContext(req)
1607
Expect(err).NotTo(HaveOccurred())
1608
})
1609
})
1610
}
1611
1612
// openDummy returns the read end of an OS pipe that yields the 8-byte PNG
// signature ("\x89PNG\r\n\x1a\n") followed by EOF. It gives tests a real
// *os.File (the type the FileReader mock's Open must return) without creating
// a file on disk.
//
// The caller owns the returned file and should Close it when finished.
// openDummy panics if the pipe cannot be created: returning a nil *os.File
// would otherwise surface later as a confusing nil-pointer failure inside the
// code under test.
func openDummy() *os.File {
	r, w, err := os.Pipe()
	if err != nil {
		panic("openDummy: os.Pipe: " + err.Error())
	}

	// Write from a goroutine so openDummy never blocks on the pipe buffer;
	// closing the write end is what lets readers observe EOF.
	go func() {
		_, _ = io.Copy(w, bytes.NewBuffer([]byte("\x89PNG\r\n\x1a\n")))
		_ = w.Close()
	}()

	return r
}
1621
1622
func createBody(messages []api.Message, stream bool) ([]byte, error) {
1623
req := api.CompletionsRequest{
1624
Model: config.Model,
1625
Messages: messages,
1626
Stream: stream,
1627
Temperature: config.Temperature,
1628
TopP: config.TopP,
1629
FrequencyPenalty: config.FrequencyPenalty,
1630
MaxTokens: config.MaxTokens,
1631
PresencePenalty: config.PresencePenalty,
1632
Seed: config.Seed,
1633
}
1634
1635
return json.Marshal(req)
1636
}
1637
1638
func createMessages(historyEntries []history.History, query string) []api.Message {
1639
var messages []api.Message
1640
1641
if len(historyEntries) == 0 {
1642
messages = append(messages, api.Message{
1643
Role: client.SystemRole,
1644
Content: config.Role,
1645
})
1646
} else {
1647
for _, entry := range historyEntries {
1648
messages = append(messages, entry.Message)
1649
}
1650
}
1651
1652
messages = append(messages, api.Message{
1653
Role: client.UserRole,
1654
Content: query,
1655
})
1656
1657
return messages
1658
}
1659
1660
type clientFactory struct {
1661
mockHistoryStore *MockStore
1662
}
1663
1664
func newClientFactory(mhs *MockStore) *clientFactory {
1665
return &clientFactory{
1666
mockHistoryStore: mhs,
1667
}
1668
}
1669
1670
// buildClientWithoutConfig creates a client from MockConfig() in command-line
// mode. The client uses the package-level caller factory, timer, reader and
// writer mocks, and its context window is set to config.ContextWindow.
//
// Exactly one SetThread(config.Thread) call is expected on the history store
// before client.New runs; presumably New makes that call.
// NOTE(review): the thread and context window come from the package var
// `config`, while the client itself is built from MockConfig(). The two must
// stay in sync or the SetThread expectation will not match.
func (f *clientFactory) buildClientWithoutConfig() *client.Client {
	f.mockHistoryStore.EXPECT().SetThread(config.Thread).Times(1)

	return client.
		New(mockCallerFactory, f.mockHistoryStore, mockTimer, mockReader, mockWriter, MockConfig(), commandLineMode).
		WithContextWindow(config.ContextWindow)
}
1677
1678
// withoutHistory expects exactly one Read on the history store and makes it
// return a nil history with a nil error.
func (f *clientFactory) withoutHistory() {
	read := f.mockHistoryStore.EXPECT().Read()
	read.Times(1).Return(nil, nil)
}
1681
1682
// withHistory expects exactly one Read on the history store and makes it
// return entries with a nil error.
//
// The parameter is deliberately not called `history`. That name would shadow
// the imported history package inside the body, which linters flag and which
// breaks any later reference to the package here.
func (f *clientFactory) withHistory(entries []history.History) {
	f.mockHistoryStore.EXPECT().Read().Return(entries, nil).Times(1)
}
1685
1686
// mockCallerFactory is the caller factory passed to client.New in tests. It
// ignores the config it is given and always returns the shared mockCaller.
func mockCallerFactory(config2.Config) http.Caller {
	return mockCaller
}
1689
1690
func MockConfig() config2.Config {
1691
return config2.Config{
1692
Name: "mock-openai",
1693
APIKey: "mock-api-key",
1694
Model: "gpt-3.5-turbo",
1695
MaxTokens: 100,
1696
ContextWindow: 50,
1697
Role: "You are a test assistant.",
1698
Temperature: 0.7,
1699
TopP: 0.9,
1700
FrequencyPenalty: 0.1,
1701
PresencePenalty: 0.2,
1702
Thread: "mock-thread",
1703
OmitHistory: false,
1704
URL: "https://api.mock-openai.com",
1705
CompletionsPath: "/v1/test/completions",
1706
ModelsPath: "/v1/test/models",
1707
AuthHeader: "MockAuthorization",
1708
AuthTokenPrefix: "MockBearer ",
1709
CommandPrompt: "[mock-datetime] [Q%counter] [%usage]",
1710
OutputPrompt: "[mock-output]",
1711
AutoCreateNewThread: true,
1712
TrackTokenUsage: true,
1713
SkipTLSVerify: false,
1714
Seed: 1,
1715
Effort: "low",
1716
ResponsesPath: "/v1/responses",
1717
Voice: "mock-voice",
1718
TranscriptionsPath: "/v1/test/transcriptions",
1719
SpeechPath: "/v1/test/speech",
1720
}
1721
}
1722
1723