GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/api/client/llm_test.go

package client_test

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"testing"
	"time"

	"github.com/golang/mock/gomock"
	"github.com/kardolus/chatgpt-cli/api"
	"github.com/kardolus/chatgpt-cli/api/client"
	"github.com/kardolus/chatgpt-cli/history"
	"github.com/kardolus/chatgpt-cli/test"

	. "github.com/onsi/gomega"
	"github.com/sclevine/spec"
)

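// testLLM exercises the client's Query, Stream, and ListModels flows against
// mocked caller, timer, and history-store collaborators.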
func testLLM(t *testing.T, when spec.G, it spec.S) {
	const query = "test query"

	when("LLM()", func() {
		when("Query()", func() {
			var (
				body []byte
				messages []api.Message
				err error
			)

			type TestCase struct {
				description string
				setupPostReturn func() ([]byte, error)
				postError error
				expectedError string
			}

			tests := []TestCase{
				{
					description: "throws an error when the http callout fails",
					setupPostReturn: func() ([]byte, error) { return nil, nil },
					postError: errors.New("error message"),
					expectedError: "error message",
				},
				{
					description: "throws an error when the response is empty",
					setupPostReturn: func() ([]byte, error) { return nil, nil },
					postError: nil,
					expectedError: "empty response",
				},
				{
					description: "throws an error when the response is malformed JSON",
					setupPostReturn: func() ([]byte, error) {
						malformed := `{"invalid":"json"` // missing closing brace
						return []byte(malformed), nil
					},
					postError: nil,
					expectedError: "failed to decode response:",
				},
				{
					description: "throws an error when the response is missing Choices",
					setupPostReturn: func() ([]byte, error) {
						response := &api.CompletionsResponse{
							ID: "id",
							Object: "object",
							Created: 0,
							Model: "model",
							Choices: []api.Choice{},
						}

						respBytes, err := json.Marshal(response)
						return respBytes, err
					},
					postError: nil,
					expectedError: "no responses returned",
				},
				{
					description: "throws an error when the response cannot be cast to a string",
					setupPostReturn: func() ([]byte, error) {
						response := &api.CompletionsResponse{
							ID: "id",
							Object: "object",
							Created: 0,
							Model: "model",
							Choices: []api.Choice{
								{
									Message: api.Message{
										Role: client.AssistantRole,
										Content: 123, // cannot be converted to a string
									},
									FinishReason: "",
									Index: 0,
								},
							},
						}

						respBytes, err := json.Marshal(response)
						return respBytes, err
					},
					postError: nil,
					expectedError: "response cannot be converted to a string",
				},
			}

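			// Each failure case drives Query() through a mocked Post call and
			// asserts on a substring of the returned error.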
			for _, tt := range tests {
				tt := tt
				it(tt.description, func() {
					factory.withoutHistory()
					subject := factory.buildClientWithoutConfig()

					messages = createMessages(nil, query)
					body, err = createBody(messages, false)
					Expect(err).NotTo(HaveOccurred())

					respBytes, err := tt.setupPostReturn()
					Expect(err).NotTo(HaveOccurred())

					mockCaller.EXPECT().
						Post(subject.Config.URL+subject.Config.CompletionsPath, body, false).
						Return(respBytes, tt.postError)

					mockTimer.EXPECT().Now().Return(time.Time{}).Times(2)

					_, _, err = subject.Query(context.Background(), query)
					Expect(err).To(HaveOccurred())
					Expect(err.Error()).To(ContainSubstring(tt.expectedError))
				})
			}

			it("errors when the model is realtime (no HTTP call is made)", func() {
				factory.withoutHistory()
				subject := factory.buildClientWithoutConfig()

				realtimeModel := "gpt-realtime"
				subject.Config.Model = realtimeModel
				config.Model = realtimeModel

				mockCaller.EXPECT().
					Post(gomock.Any(), gomock.Any(), gomock.Any()).
					Times(0)

				mockTimer.EXPECT().Now().Return(time.Time{}).Times(2)

				_, _, err := subject.Query(context.Background(), query)
				Expect(err).To(HaveOccurred())

				Expect(err.Error()).To(ContainSubstring("realtime"))
			})

			it("errors when web is enabled for a non-gpt5 model", func() {
				factory.withoutHistory()
				subject := factory.buildClientWithoutConfig()

				subject.Config.Model = "gpt-4o"
				subject.Config.Web = true

				mockCaller.EXPECT().
					Post(gomock.Any(), gomock.Any(), gomock.Any()).
					Times(0)

				mockTimer.EXPECT().Now().Return(time.Time{}).Times(2)

				_, _, err := subject.Query(context.Background(), query)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(ContainSubstring("web search"))
			})

			it("errors when web is enabled for gpt-5-search", func() {
				factory.withoutHistory()
				subject := factory.buildClientWithoutConfig()

				subject.Config.Model = "gpt-5-search"
				subject.Config.Web = true

				mockCaller.EXPECT().
					Post(gomock.Any(), gomock.Any(), gomock.Any()).
					Times(0)

				mockTimer.EXPECT().Now().Return(time.Time{}).Times(2)

				_, _, err := subject.Query(context.Background(), query)
				Expect(err).To(HaveOccurred())
				Expect(err.Error()).To(ContainSubstring("web search"))
			})

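			// The helper below stubs a successful completions response and, unless
			// omitHistory is set, expects the exchange to be written to the history store.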
when("a valid http response is received", func() {
188
testValidHTTPResponse := func(subject *client.Client, expectedBody []byte, omitHistory bool) {
189
const (
190
answer = "content"
191
tokens = 789
192
)
193
194
choice := api.Choice{
195
Message: api.Message{
196
Role: client.AssistantRole,
197
Content: answer,
198
},
199
FinishReason: "",
200
Index: 0,
201
}
202
response := &api.CompletionsResponse{
203
ID: "id",
204
Object: "object",
205
Created: 0,
206
Model: subject.Config.Model,
207
Choices: []api.Choice{choice},
208
Usage: api.Usage{
209
PromptTokens: 123,
210
CompletionTokens: 456,
211
TotalTokens: tokens,
212
},
213
}
214
215
respBytes, err := json.Marshal(response)
216
Expect(err).NotTo(HaveOccurred())
217
218
mockCaller.EXPECT().
219
Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).
220
Return(respBytes, nil)
221
222
var request api.CompletionsRequest
223
err = json.Unmarshal(expectedBody, &request)
224
Expect(err).NotTo(HaveOccurred())
225
226
mockTimer.EXPECT().Now().Return(time.Time{}).AnyTimes()
227
228
var h []history.History
229
if !omitHistory {
230
for _, msg := range request.Messages {
231
h = append(h, history.History{Message: msg})
232
}
233
234
mockHistoryStore.EXPECT().Write(append(h, history.History{
235
Message: api.Message{
236
Role: client.AssistantRole,
237
Content: answer,
238
},
239
}))
240
}
241
242
result, usage, err := subject.Query(context.Background(), query)
243
Expect(err).NotTo(HaveOccurred())
244
Expect(result).To(Equal(answer))
245
Expect(usage).To(Equal(tokens))
246
}
247
248
it("returns the expected result for a non-empty history", func() {
249
h := []history.History{
250
{Message: api.Message{Role: client.SystemRole, Content: config.Role}},
251
{Message: api.Message{Role: client.UserRole, Content: "question 1"}},
252
{Message: api.Message{Role: client.AssistantRole, Content: "answer 1"}},
253
}
254
255
messages = createMessages(h, query)
256
factory.withHistory(h)
257
subject := factory.buildClientWithoutConfig()
258
259
body, err = createBody(messages, false)
260
Expect(err).NotTo(HaveOccurred())
261
262
testValidHTTPResponse(subject, body, false)
263
})
264
265
it("ignores history when configured to do so", func() {
266
cfg := MockConfig()
267
cfg.OmitHistory = true
268
269
subject := client.New(
270
mockCallerFactory,
271
mockHistoryStore,
272
mockTimer,
273
mockReader,
274
mockWriter,
275
cfg,
276
)
277
278
// History should never be read or written
279
mockHistoryStore.EXPECT().Read().Times(0)
280
mockHistoryStore.EXPECT().Write(gomock.Any()).Times(0)
281
282
var capturedBody []byte
283
284
validHTTPResponseBytes := []byte(`{
285
"id": "chatcmpl_test",
286
"object": "chat.completion",
287
"created": 0,
288
"model": "gpt-4o",
289
"choices": [
290
{
291
"index": 0,
292
"message": { "role": "assistant", "content": "ok" },
293
"finish_reason": "stop"
294
}
295
],
296
"usage": { "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2 }
297
}`)
298
299
mockTimer.EXPECT().Now().Return(time.Time{}).Times(2)
300
301
mockCaller.EXPECT().
302
Post(gomock.Any(), gomock.Any(), gomock.Any()).
303
DoAndReturn(func(endpoint string, body []byte, stream bool) ([]byte, error) {
304
capturedBody = body
305
return validHTTPResponseBytes, nil
306
})
307
308
_, _, err := subject.Query(context.Background(), query)
309
Expect(err).NotTo(HaveOccurred())
310
311
Expect(capturedBody).NotTo(BeNil())
312
313
var req api.CompletionsRequest
314
err = json.Unmarshal(capturedBody, &req)
315
Expect(err).NotTo(HaveOccurred())
316
317
Expect(req.Messages).To(HaveLen(1))
318
Expect(req.Messages[0].Role).To(Equal("user"))
319
Expect(req.Messages[0].Content).To(Equal(query))
320
})
321
322
it("truncates the history as expected", func() {
323
hs := []history.History{
324
{Message: api.Message{Role: client.SystemRole, Content: config.Role}, Timestamp: time.Time{}},
325
{Message: api.Message{Role: client.UserRole, Content: "question 1"}, Timestamp: time.Time{}},
326
{Message: api.Message{Role: client.AssistantRole, Content: "answer 1"}, Timestamp: time.Time{}},
327
{Message: api.Message{Role: client.UserRole, Content: "question 2"}, Timestamp: time.Time{}},
328
{Message: api.Message{Role: client.AssistantRole, Content: "answer 2"}, Timestamp: time.Time{}},
329
{Message: api.Message{Role: client.UserRole, Content: "question 3"}, Timestamp: time.Time{}},
330
{Message: api.Message{Role: client.AssistantRole, Content: "answer 3"}, Timestamp: time.Time{}},
331
}
332
333
messages = createMessages(hs, query)
334
335
factory.withHistory(hs)
336
subject := factory.buildClientWithoutConfig()
337
338
// messages get truncated. Index 1+2 are cut out
339
messages = append(messages[:1], messages[3:]...)
340
341
body, err = createBody(messages, false)
342
Expect(err).NotTo(HaveOccurred())
343
344
testValidHTTPResponse(subject, body, false)
345
})
346
347
it("should skip the first message when the model starts with o1Prefix", func() {
348
factory.withHistory([]history.History{
349
{Message: api.Message{Role: client.SystemRole, Content: "First message"}},
350
{Message: api.Message{Role: client.UserRole, Content: "Second message"}},
351
})
352
353
o1Model := "o1-example-model"
354
config.Model = o1Model
355
356
subject := factory.buildClientWithoutConfig()
357
subject.Config.Model = o1Model
358
359
expectedBody, err := createBody([]api.Message{
360
{Role: client.UserRole, Content: "Second message"},
361
{Role: client.UserRole, Content: "test query"},
362
}, false)
363
Expect(err).NotTo(HaveOccurred())
364
365
mockTimer.EXPECT().Now().Return(time.Now()).AnyTimes()
366
mockCaller.EXPECT().
367
Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).
368
Return(nil, nil)
369
370
_, _, _ = subject.Query(context.Background(), "test query")
371
})
372
373
it("should include all messages when the model does not start with o1Prefix", func() {
374
const systemRole = "System role for this test"
375
376
factory.withHistory([]history.History{
377
{Message: api.Message{Role: client.SystemRole, Content: systemRole}},
378
{Message: api.Message{Role: client.UserRole, Content: "Second message"}},
379
})
380
381
regularModel := "gpt-4o"
382
config.Model = regularModel
383
384
subject := factory.buildClientWithoutConfig()
385
subject.Config.Model = regularModel
386
subject.Config.Role = systemRole
387
388
expectedBody, err := createBody([]api.Message{
389
{Role: client.SystemRole, Content: systemRole},
390
{Role: client.UserRole, Content: "Second message"},
391
{Role: client.UserRole, Content: "test query"},
392
}, false)
393
Expect(err).NotTo(HaveOccurred())
394
395
mockTimer.EXPECT().Now().Return(time.Now()).AnyTimes()
396
mockCaller.EXPECT().
397
Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).
398
Return(nil, nil)
399
400
_, _, _ = subject.Query(context.Background(), "test query")
401
})
402
403
it("should omit Temperature and TopP when the model matches SearchModelPattern", func() {
404
searchModel := "gpt-4o-search-preview"
405
config.Model = searchModel
406
config.Role = "role for search test"
407
408
factory.withHistory([]history.History{
409
{Message: api.Message{Role: client.SystemRole, Content: config.Role}},
410
})
411
412
subject := factory.buildClientWithoutConfig()
413
subject.Config.Model = searchModel
414
415
mockTimer.EXPECT().Now().Return(time.Now()).AnyTimes()
416
417
mockCaller.EXPECT().
418
Post(gomock.Any(), gomock.Any(), false).
419
DoAndReturn(func(_ string, body []byte, _ bool) ([]byte, error) {
420
var req map[string]interface{}
421
Expect(json.Unmarshal(body, &req)).To(Succeed())
422
Expect(req).NotTo(HaveKey("temperature"))
423
Expect(req).NotTo(HaveKey("top_p"))
424
return nil, nil
425
})
426
427
_, _, _ = subject.Query(context.Background(), "test query")
428
})
429
430
it("should include Temperature and TopP when the model does not match SearchModelPattern", func() {
431
regularModel := "gpt-4o"
432
config.Model = regularModel
433
config.Role = "regular model test"
434
435
factory.withHistory([]history.History{
436
{Message: api.Message{Role: client.SystemRole, Content: config.Role}},
437
})
438
439
subject := factory.buildClientWithoutConfig()
440
subject.Config.Model = regularModel
441
442
mockTimer.EXPECT().Now().Return(time.Now()).AnyTimes()
443
444
mockCaller.EXPECT().
445
Post(gomock.Any(), gomock.Any(), false).
446
DoAndReturn(func(_ string, body []byte, _ bool) ([]byte, error) {
447
var req map[string]interface{}
448
Expect(json.Unmarshal(body, &req)).To(Succeed())
449
450
Expect(req).To(HaveKeyWithValue("temperature", BeNumerically("==", config.Temperature)))
451
Expect(req).To(HaveKeyWithValue("top_p", BeNumerically("==", config.TopP)))
452
return nil, nil
453
})
454
455
_, _, _ = subject.Query(context.Background(), "test query")
456
})
457
458
it("forces Responses API when web is enabled", func() {
459
factory.withoutHistory()
460
subject := factory.buildClientWithoutConfig()
461
462
subject.Config.Model = "gpt-5"
463
subject.Config.Web = true
464
subject.Config.WebContextSize = "low"
465
466
mockTimer.EXPECT().Now().Times(3)
467
mockHistoryStore.EXPECT().Write(gomock.Any())
468
469
response := api.ResponsesResponse{
470
Output: []api.Output{{
471
Type: "message",
472
Content: []api.Content{{Type: "output_text", Text: "hi"}},
473
}},
474
Usage: api.TokenUsage{TotalTokens: 1},
475
}
476
raw, _ := json.Marshal(response)
477
478
mockCaller.EXPECT().
479
Post(subject.Config.URL+"/v1/responses", gomock.Any(), false).
480
DoAndReturn(func(_ string, body []byte, _ bool) ([]byte, error) {
481
var req map[string]any
482
Expect(json.Unmarshal(body, &req)).To(Succeed())
483
Expect(req).To(HaveKey("tools"))
484
return raw, nil
485
})
486
487
_, _, err := subject.Query(context.Background(), query)
488
Expect(err).NotTo(HaveOccurred())
489
})
490
491
it("adds web_search tool when web is enabled", func() {
492
factory.withoutHistory()
493
subject := factory.buildClientWithoutConfig()
494
495
subject.Config.Model = "gpt-5"
496
subject.Config.Web = true
497
subject.Config.WebContextSize = "low"
498
499
mockTimer.EXPECT().Now().Times(3)
500
mockHistoryStore.EXPECT().Write(gomock.Any())
501
502
response := api.ResponsesResponse{
503
Output: []api.Output{{
504
Type: "message",
505
Content: []api.Content{{Type: "output_text", Text: "ok"}},
506
}},
507
Usage: api.TokenUsage{TotalTokens: 1},
508
}
509
raw, _ := json.Marshal(response)
510
511
mockCaller.EXPECT().
512
Post(subject.Config.URL+"/v1/responses", gomock.Any(), false).
513
DoAndReturn(func(_ string, body []byte, _ bool) ([]byte, error) {
514
var req map[string]any
515
Expect(json.Unmarshal(body, &req)).To(Succeed())
516
517
tools := req["tools"].([]any)
518
Expect(tools).To(HaveLen(1))
519
520
tool := tools[0].(map[string]any)
521
Expect(tool).To(HaveKeyWithValue("type", "web_search"))
522
Expect(tool).To(HaveKeyWithValue("search_context_size", "low"))
523
524
return raw, nil
525
})
526
527
_, _, err := subject.Query(context.Background(), query)
528
Expect(err).NotTo(HaveOccurred())
529
})
530
})
531
532
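			// o1-pro and gpt-5 are routed through the Responses API; the shared helper
			// checks the request shape, including capability-driven temperature/top_p.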
when("the model is o1-pro or gpt-5", func() {
533
models := []string{"o1-pro", "gpt-5"}
534
535
for _, m := range models {
536
m := m
537
when(fmt.Sprintf("the model is %s", m), func() {
538
const (
539
query = "what's the weather"
540
systemRole = "you are helpful"
541
totalTokens = 777
542
)
543
544
it.Before(func() {
545
config.Model = m
546
config.Role = systemRole
547
factory.withoutHistory()
548
})
549
550
assertResponsesRequest := func(body []byte) {
551
var req map[string]any
552
Expect(json.Unmarshal(body, &req)).To(Succeed())
553
554
Expect(req).To(HaveKeyWithValue("model", m))
555
Expect(req).To(HaveKey("input"))
556
Expect(req).To(HaveKey("max_output_tokens"))
557
Expect(req).To(HaveKey("reasoning"))
558
Expect(req).To(HaveKeyWithValue("stream", false))
559
560
// Validate reasoning.effort
561
reasoning, ok := req["reasoning"].(map[string]any)
562
Expect(ok).To(BeTrue())
563
Expect(reasoning).To(HaveKeyWithValue("effort", "low"))
564
565
// Validate input messages
566
input, ok := req["input"].([]any)
567
Expect(ok).To(BeTrue())
568
Expect(input).To(HaveLen(2))
569
570
msg0, ok := input[0].(map[string]any)
571
Expect(ok).To(BeTrue())
572
Expect(msg0).To(HaveKeyWithValue("role", client.SystemRole))
573
Expect(msg0).To(HaveKeyWithValue("content", systemRole))
574
575
msg1, ok := input[1].(map[string]any)
576
Expect(ok).To(BeTrue())
577
Expect(msg1).To(HaveKeyWithValue("role", client.UserRole))
578
Expect(msg1).To(HaveKeyWithValue("content", query))
579
580
// Temperature / top_p assertions are capability-driven now
581
caps := client.GetCapabilities(m)
582
583
if caps.SupportsTemperature {
584
Expect(req).To(HaveKeyWithValue("temperature", BeNumerically("==", config.Temperature)))
585
} else {
586
Expect(req).NotTo(HaveKey("temperature"))
587
}
588
589
if caps.SupportsTopP {
590
Expect(req).To(HaveKeyWithValue("top_p", BeNumerically("==", config.TopP)))
591
} else {
592
Expect(req).NotTo(HaveKey("top_p"))
593
}
594
}
595
596
it("returns the output_text when present", func() {
597
subject := factory.buildClientWithoutConfig()
598
subject.Config.Model = m
599
subject.Config.Role = systemRole
600
601
answer := "yes, it does"
602
603
mockTimer.EXPECT().Now().Times(3)
604
mockHistoryStore.EXPECT().Write(gomock.Any())
605
606
response := api.ResponsesResponse{
607
Output: []api.Output{{
608
Type: "message",
609
Content: []api.Content{{Type: "output_text", Text: answer}},
610
}},
611
Usage: api.TokenUsage{TotalTokens: 42},
612
}
613
raw, _ := json.Marshal(response)
614
615
mockCaller.EXPECT().
616
Post(subject.Config.URL+"/v1/responses", gomock.Any(), false).
617
DoAndReturn(func(_ string, body []byte, _ bool) ([]byte, error) {
618
assertResponsesRequest(body)
619
return raw, nil
620
})
621
622
text, tokens, err := subject.Query(context.Background(), query)
623
Expect(err).NotTo(HaveOccurred())
624
Expect(text).To(Equal(answer))
625
Expect(tokens).To(Equal(42))
626
})
627
628
it("errors when no output blocks are present", func() {
629
subject := factory.buildClientWithoutConfig()
630
subject.Config.Model = m
631
subject.Config.Role = systemRole
632
633
mockTimer.EXPECT().Now().Times(2)
634
635
response := api.ResponsesResponse{
636
Output: []api.Output{},
637
Usage: api.TokenUsage{TotalTokens: totalTokens},
638
}
639
raw, _ := json.Marshal(response)
640
641
mockCaller.EXPECT().
642
Post(subject.Config.URL+"/v1/responses", gomock.Any(), false).
643
DoAndReturn(func(_ string, body []byte, _ bool) ([]byte, error) {
644
assertResponsesRequest(body)
645
return raw, nil
646
})
647
648
_, _, err := subject.Query(context.Background(), query)
649
Expect(err).To(HaveOccurred())
650
Expect(err.Error()).To(Equal("no response returned"))
651
})
652
653
it("errors when message has no output_text", func() {
654
subject := factory.buildClientWithoutConfig()
655
subject.Config.Model = m
656
subject.Config.Role = systemRole
657
658
mockTimer.EXPECT().Now().Times(2)
659
660
response := api.ResponsesResponse{
661
Output: []api.Output{{
662
Type: "message",
663
Content: []api.Content{{Type: "refusal", Text: "nope"}},
664
}},
665
Usage: api.TokenUsage{TotalTokens: totalTokens},
666
}
667
raw, _ := json.Marshal(response)
668
669
mockCaller.EXPECT().
670
Post(subject.Config.URL+"/v1/responses", gomock.Any(), false).
671
DoAndReturn(func(_ string, body []byte, _ bool) ([]byte, error) {
672
assertResponsesRequest(body)
673
return raw, nil
674
})
675
676
_, _, err := subject.Query(context.Background(), query)
677
Expect(err).To(HaveOccurred())
678
Expect(err.Error()).To(Equal("no response returned"))
679
})
680
})
681
}
682
})
683
})
684
685
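		// Stream() is exercised the same way: the realtime guard, a failing callout,
		// and history writes after a successful streamed response (posted with stream=true).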
when("Stream()", func() {
686
var (
687
body []byte
688
messages []api.Message
689
err error
690
)
691
692
it("throws an error when the http callout fails", func() {
693
factory.withoutHistory()
694
subject := factory.buildClientWithoutConfig()
695
696
messages = createMessages(nil, query)
697
body, err = createBody(messages, true)
698
Expect(err).NotTo(HaveOccurred())
699
700
errorMsg := "error message"
701
mockCaller.EXPECT().
702
Post(subject.Config.URL+subject.Config.CompletionsPath, body, true).
703
Return(nil, errors.New(errorMsg))
704
705
mockTimer.EXPECT().Now().Return(time.Time{}).Times(2)
706
707
err := subject.Stream(context.Background(), query)
708
Expect(err).To(HaveOccurred())
709
Expect(err.Error()).To(Equal(errorMsg))
710
})
711
712
it("errors when the model is realtime (no HTTP call is made)", func() {
713
factory.withoutHistory()
714
subject := factory.buildClientWithoutConfig()
715
716
realtimeModel := "gpt-4o-realtime-preview"
717
subject.Config.Model = realtimeModel
718
config.Model = realtimeModel
719
720
mockCaller.EXPECT().
721
Post(gomock.Any(), gomock.Any(), gomock.Any()).
722
Times(0)
723
724
mockTimer.EXPECT().Now().Return(time.Time{}).Times(2)
725
726
err := subject.Stream(context.Background(), query)
727
Expect(err).To(HaveOccurred())
728
Expect(err.Error()).To(ContainSubstring("realtime"))
729
})
730
731
when("a valid http response is received", func() {
732
const answer = "answer"
733
734
testValidHTTPResponse := func(subject *client.Client, hs []history.History, expectedBody []byte) {
735
messages = createMessages(nil, query)
736
body, err = createBody(messages, true)
737
Expect(err).NotTo(HaveOccurred())
738
739
mockCaller.EXPECT().
740
Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, true).
741
Return([]byte(answer), nil)
742
743
mockTimer.EXPECT().Now().Return(time.Time{}).AnyTimes()
744
745
messages = createMessages(hs, query)
746
747
var out []history.History
748
for _, message := range messages {
749
out = append(out, history.History{Message: message})
750
}
751
752
mockHistoryStore.EXPECT().Write(append(out, history.History{
753
Message: api.Message{
754
Role: client.AssistantRole,
755
Content: answer,
756
},
757
}))
758
759
err := subject.Stream(context.Background(), query)
760
Expect(err).NotTo(HaveOccurred())
761
}
762
763
it("returns the expected result for an empty history", func() {
764
factory.withHistory(nil)
765
subject := factory.buildClientWithoutConfig()
766
767
messages = createMessages(nil, query)
768
body, err = createBody(messages, true)
769
Expect(err).NotTo(HaveOccurred())
770
771
testValidHTTPResponse(subject, nil, body)
772
})
773
774
it("returns the expected result for a non-empty history", func() {
775
h := []history.History{
776
{Message: api.Message{Role: client.SystemRole, Content: config.Role}},
777
{Message: api.Message{Role: client.UserRole, Content: "question x"}},
778
{Message: api.Message{Role: client.AssistantRole, Content: "answer x"}},
779
}
780
factory.withHistory(h)
781
subject := factory.buildClientWithoutConfig()
782
783
messages = createMessages(h, query)
784
body, err = createBody(messages, true)
785
Expect(err).NotTo(HaveOccurred())
786
787
testValidHTTPResponse(subject, h, body)
788
})
789
})
790
})
791
792
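		// ListModels() mirrors the Query() failure modes (failed callout, empty body,
		// malformed JSON) and checks the formatted, sorted model listing.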
when("ListModels()", func() {
793
it("throws an error when the http callout fails", func() {
794
subject := factory.buildClientWithoutConfig()
795
796
errorMsg := "error message"
797
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).
798
Return(nil, errors.New(errorMsg))
799
800
_, err := subject.ListModels()
801
Expect(err).To(HaveOccurred())
802
Expect(err.Error()).To(Equal(errorMsg))
803
})
804
805
it("throws an error when the response is empty", func() {
806
subject := factory.buildClientWithoutConfig()
807
808
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).Return(nil, nil)
809
810
_, err := subject.ListModels()
811
Expect(err).To(HaveOccurred())
812
Expect(err.Error()).To(Equal("empty response"))
813
})
814
815
it("throws an error when the response is a malformed json", func() {
816
subject := factory.buildClientWithoutConfig()
817
818
malformed := `{"invalid":"json"` // missing closing brace
819
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).
820
Return([]byte(malformed), nil)
821
822
_, err := subject.ListModels()
823
Expect(err).To(HaveOccurred())
824
Expect(err.Error()).Should(HavePrefix("failed to decode response:"))
825
})
826
827
it("filters gpt and o1 models as expected and puts them in alphabetical order", func() {
828
subject := factory.buildClientWithoutConfig()
829
830
response, err := test.FileToBytes("models.json")
831
Expect(err).NotTo(HaveOccurred())
832
833
mockCaller.EXPECT().Get(subject.Config.URL+subject.Config.ModelsPath).
834
Return(response, nil)
835
836
result, err := subject.ListModels()
837
Expect(err).NotTo(HaveOccurred())
838
Expect(result).NotTo(BeEmpty())
839
Expect(result).To(HaveLen(5))
840
Expect(result[0]).To(Equal("- gpt-3.5-env-model"))
841
Expect(result[1]).To(Equal("* gpt-3.5-turbo (current)"))
842
Expect(result[2]).To(Equal("- gpt-3.5-turbo-0301"))
843
Expect(result[3]).To(Equal("- gpt-4o"))
844
Expect(result[4]).To(Equal("- o1-mini"))
845
})
846
})
847
})
848
}
849
850
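// testCapabilities verifies the per-model capability matrix returned by
// client.GetCapabilities.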
func testCapabilities(t *testing.T, when spec.G, it spec.S) {
	when("GetCapabilities()", func() {
		type tc struct {
			model string

			supportsTemp bool
			supportsTopP bool
			usesResponses bool
			omitFirstSystem bool
			supportsStreaming bool
			isRealtime bool
			supportsWebSearch bool
		}

		tests := []tc{
			{
				model: "gpt-4o",
				supportsTemp: true,
				supportsTopP: true,
				usesResponses: false,
				omitFirstSystem: false,
				supportsStreaming: true,
				isRealtime: false,
				supportsWebSearch: false,
			},
			{
				model: "gpt-4o-search-preview",
				supportsTemp: false,
				supportsTopP: false,
				usesResponses: false, // still completions path
				omitFirstSystem: false,
				supportsStreaming: true,
				isRealtime: false,
				supportsWebSearch: false,
			},
			{
				model: "gpt-realtime",
				isRealtime: true,
			},
			{
				model: "gpt-5",
				supportsTemp: true,
				supportsTopP: false,
				usesResponses: true,
				omitFirstSystem: false,
				supportsStreaming: true,
				isRealtime: false,
				supportsWebSearch: true,
			},
			{
				model: "gpt-5-search",
				supportsTemp: false,
				supportsTopP: false,
				usesResponses: true,
				omitFirstSystem: false,
				supportsStreaming: true,
				isRealtime: false,
				supportsWebSearch: false,
			},
			{
				model: "gpt-5.2",
				supportsTemp: true,
				supportsTopP: false,
				usesResponses: true,
				omitFirstSystem: false,
				supportsStreaming: true,
				isRealtime: false,
				supportsWebSearch: true,
			},
			{
				model: "gpt-5.2-pro",
				supportsTemp: true,
				supportsTopP: false,
				usesResponses: true,
				omitFirstSystem: false,
				supportsStreaming: true,
				isRealtime: false,
				supportsWebSearch: true,
			},
			{
				model: "o1-mini",
				supportsTemp: true,
				supportsTopP: true,
				usesResponses: false,
				omitFirstSystem: true,
				supportsStreaming: true,
				isRealtime: false,
				supportsWebSearch: false,
			},
			{
				model: "o1-pro",
				supportsTemp: true,
				supportsTopP: true,
				usesResponses: true,
				omitFirstSystem: false,
				supportsStreaming: false,
				isRealtime: false,
				supportsWebSearch: false,
			},
		}

		for _, tt := range tests {
			tt := tt
			it(tt.model, func() {
				RegisterTestingT(t)

				c := client.GetCapabilities(tt.model)

				Expect(c.IsRealtime).To(Equal(tt.isRealtime))

				// Only assert these for non-realtime models.
				if !tt.isRealtime {
					Expect(c.SupportsTemperature).To(Equal(tt.supportsTemp))
					Expect(c.SupportsTopP).To(Equal(tt.supportsTopP))
					Expect(c.UsesResponsesAPI).To(Equal(tt.usesResponses))
					Expect(c.OmitFirstSystemMsg).To(Equal(tt.omitFirstSystem))
					Expect(c.SupportsStreaming).To(Equal(tt.supportsStreaming))
					Expect(c.SupportsWebSearch).To(Equal(tt.supportsWebSearch))
				}
			})
		}
	})
}

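// createBody builds the completions request body the client is expected to send,
// using the shared test config.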
func createBody(messages []api.Message, stream bool) ([]byte, error) {
	req := api.CompletionsRequest{
		Model: config.Model,
		Messages: messages,
		Stream: stream,
		Temperature: config.Temperature,
		TopP: config.TopP,
		FrequencyPenalty: config.FrequencyPenalty,
		MaxTokens: config.MaxTokens,
		PresencePenalty: config.PresencePenalty,
		Seed: config.Seed,
	}

	return json.Marshal(req)
}

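// createMessages mirrors the message list the client should assemble: the history
// entries (or the system role when history is empty) followed by the user query.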
func createMessages(historyEntries []history.History, query string) []api.Message {
	var messages []api.Message

	if len(historyEntries) == 0 {
		messages = append(messages, api.Message{
			Role: client.SystemRole,
			Content: config.Role,
		})
	} else {
		for _, entry := range historyEntries {
			messages = append(messages, entry.Message)
		}
	}

	messages = append(messages, api.Message{
		Role: client.UserRole,
		Content: query,
	})

	return messages
}