Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/api/client/mcp_test.go
3431 views
1
package client_test
2
3
import (
4
"errors"
5
"github.com/kardolus/chatgpt-cli/api/http"
6
"strings"
7
"testing"
8
"time"
9
10
"github.com/golang/mock/gomock"
11
"github.com/kardolus/chatgpt-cli/api"
12
"github.com/kardolus/chatgpt-cli/api/client"
13
"github.com/kardolus/chatgpt-cli/history"
14
15
. "github.com/onsi/gomega"
16
"github.com/sclevine/spec"
17
)
18
19
// This file only contains MCP-specific tests.
20
// It relies on the shared test setup (mocks/factory/config) that lives in client_test.go.
21
22
func testMCP(t *testing.T, when spec.G, it spec.S) {
23
when("InjectMCPContext()", func() {
24
var (
25
subject *client.Client
26
mockMCPTransport *MockMCPTransport
27
)
28
29
const (
30
tool = "mock-tool"
31
endpoint = "https://example.com/mcp"
32
)
33
34
newReq := func() api.MCPRequest {
35
return api.MCPRequest{
36
Endpoint: endpoint,
37
Tool: tool,
38
Headers: map[string]string{},
39
Params: map[string]interface{}{
40
"mock-param": "mock-value",
41
},
42
}
43
}
44
45
it.Before(func() {
46
subject = factory.buildClientWithoutConfig()
47
48
mockMCPTransport = NewMockMCPTransport(mockCtrl)
49
subject = subject.WithTransport(mockMCPTransport)
50
subject = subject.WithContextWindow(1000)
51
})
52
53
it("throws an error when history tracking is disabled", func() {
54
subject.Config.OmitHistory = true
55
56
err := subject.InjectMCPContext(newReq())
57
Expect(err).To(HaveOccurred())
58
Expect(err).To(MatchError(client.ErrHistoryTracking))
59
})
60
61
it("throws an error when mcp endpoint is missing", func() {
62
r := newReq()
63
r.Endpoint = ""
64
65
err := subject.InjectMCPContext(r)
66
Expect(err).To(HaveOccurred())
67
Expect(err.Error()).To(ContainSubstring("mcp endpoint is required"))
68
})
69
70
it("throws an error when mcp tool is missing", func() {
71
r := newReq()
72
r.Tool = ""
73
74
err := subject.InjectMCPContext(r)
75
Expect(err).To(HaveOccurred())
76
Expect(err.Error()).To(ContainSubstring("mcp tool is required"))
77
})
78
79
it("throws an error when the transport call fails", func() {
80
r := newReq()
81
msg := "transport error"
82
83
mockMCPTransport.EXPECT().
84
Call(endpoint, gomock.Any(), r.Headers).
85
Return(api.MCPResponse{}, errors.New(msg))
86
87
err := subject.InjectMCPContext(r)
88
Expect(err).To(HaveOccurred())
89
Expect(err).To(MatchError(msg))
90
})
91
92
it("throws an error when history writing fails", func() {
93
r := newReq()
94
95
resp := api.MCPMessage{
96
JSONRPC: "2.0",
97
ID: "2",
98
Result: []byte(`{"content":[{"type":"text","text":"ok"}]}`),
99
}
100
101
mockMCPTransport.EXPECT().
102
Call(endpoint, gomock.Any(), gomock.Any()).
103
Return(api.MCPResponse{Message: resp}, nil)
104
105
mockHistoryStore.EXPECT().Read().Times(1)
106
mockTimer.EXPECT().Now().Times(2)
107
108
msg := "write error"
109
mockHistoryStore.EXPECT().Write(gomock.Any()).Return(errors.New(msg))
110
111
err := subject.InjectMCPContext(r)
112
Expect(err).To(HaveOccurred())
113
Expect(err).To(MatchError(msg))
114
})
115
116
it("adds the formatted MCP response to history (single text block, JSON string gets pretty-printed)", func() {
117
r := newReq()
118
119
resp := api.MCPMessage{
120
JSONRPC: "2.0",
121
ID: "2",
122
Result: []byte(`{
123
"content": [
124
{"type": "text", "text": "[{\"temperature\":\"15C\",\"condition\":\"Sunny\"}]"}
125
]
126
}`),
127
}
128
129
mockMCPTransport.EXPECT().
130
Call(endpoint, gomock.Any(), gomock.Any()).
131
Return(api.MCPResponse{Message: resp}, nil)
132
133
mockHistoryStore.EXPECT().Read().Times(1)
134
mockTimer.EXPECT().Now().Times(2)
135
136
mockHistoryStore.EXPECT().Write(gomock.Any()).
137
DoAndReturn(func(h []history.History) error {
138
Expect(h).NotTo(BeEmpty())
139
last := h[len(h)-1]
140
141
Expect(last.Message.Role).To(Equal(client.AssistantRole))
142
Expect(last.Message.Name).To(BeEmpty()) // name is no longer set
143
Expect(last.Message.Content).To(ContainSubstring("[MCP: " + tool + "]"))
144
145
// Pretty-printed JSON produced by normalizeMaybeJSON()
146
Expect(last.Message.Content).To(ContainSubstring(`"temperature": "15C"`))
147
Expect(last.Message.Content).To(ContainSubstring(`"condition": "Sunny"`))
148
149
return nil
150
})
151
152
err := subject.InjectMCPContext(r)
153
Expect(err).NotTo(HaveOccurred())
154
})
155
156
it("joins multiple text blocks with a blank line between them", func() {
157
r := newReq()
158
159
resp := api.MCPMessage{
160
JSONRPC: "2.0",
161
ID: "2",
162
Result: []byte(`{
163
"content": [
164
{"type":"text","text":"first"},
165
{"type":"text","text":"second"}
166
]
167
}`),
168
}
169
170
mockMCPTransport.EXPECT().
171
Call(endpoint, gomock.Any(), gomock.Any()).
172
Return(api.MCPResponse{Message: resp}, nil)
173
174
mockHistoryStore.EXPECT().Read().Times(1)
175
mockTimer.EXPECT().Now().Times(2)
176
177
mockHistoryStore.EXPECT().Write(gomock.Any()).
178
DoAndReturn(func(h []history.History) error {
179
last := h[len(h)-1].Message.Content
180
Expect(last).To(ContainSubstring("[MCP: " + tool + "]"))
181
Expect(last).To(ContainSubstring("first\n\nsecond"))
182
return nil
183
})
184
185
err := subject.InjectMCPContext(r)
186
Expect(err).NotTo(HaveOccurred())
187
})
188
189
it("falls back to '(empty result)' when resp.Result is empty", func() {
190
r := newReq()
191
192
resp := api.MCPMessage{
193
JSONRPC: "2.0",
194
ID: "2",
195
Result: nil,
196
}
197
198
mockMCPTransport.EXPECT().
199
Call(endpoint, gomock.Any(), gomock.Any()).
200
Return(api.MCPResponse{Message: resp}, nil)
201
202
mockHistoryStore.EXPECT().Read().Times(1)
203
mockTimer.EXPECT().Now().Times(2)
204
205
mockHistoryStore.EXPECT().Write(gomock.Any()).
206
DoAndReturn(func(h []history.History) error {
207
Expect(h[len(h)-1].Message.Content).To(ContainSubstring("(empty result)"))
208
return nil
209
})
210
211
err := subject.InjectMCPContext(r)
212
Expect(err).NotTo(HaveOccurred())
213
})
214
215
_ = time.Time{}
216
})
217
}
218
219
func testSessionTransport(t *testing.T, when spec.G, it spec.S) {
220
when("SessionTransport.Call()", func() {
221
var (
222
endpoint string
223
store *fakeSessionStore
224
inner *fakeMCPTransport
225
subject *client.SessionTransport
226
)
227
228
it.Before(func() {
229
RegisterTestingT(t)
230
231
endpoint = "https://example.com/mcp"
232
store = newFakeSessionStore()
233
inner = &fakeMCPTransport{}
234
subject = client.NewSessionTransport(inner, store)
235
})
236
237
it("passthrough when caller explicitly sets mcp-session-id header (no store access)", func() {
238
req := api.MCPMessage{JSONRPC: "2.0", ID: "1", Method: "tools/call", Params: []byte(`{}`)}
239
240
headers := map[string]string{"Mcp-Session-Id": "explicit-sid"}
241
242
inner.handler = func(ep string, r api.MCPMessage, h map[string]string) (api.MCPResponse, error) {
243
Expect(ep).To(Equal(endpoint))
244
Expect(r.Method).To(Equal("tools/call"))
245
Expect(headerGetCI(h, "mcp-session-id")).To(Equal("explicit-sid"))
246
return api.MCPResponse{Status: 200, Headers: map[string]string{}}, nil
247
}
248
249
_, err := subject.Call(endpoint, req, headers)
250
Expect(err).NotTo(HaveOccurred())
251
252
Expect(store.getCalls).To(Equal(0))
253
Expect(store.setCalls).To(Equal(0))
254
Expect(store.delCalls).To(Equal(0))
255
})
256
257
it("attaches cached session id when caller did not provide one", func() {
258
store.sessions[endpoint] = "cached-sid"
259
260
req := api.MCPMessage{JSONRPC: "2.0", ID: "1", Method: "tools/call", Params: []byte(`{}`)}
261
headers := map[string]string{}
262
263
inner.handler = func(ep string, r api.MCPMessage, h map[string]string) (api.MCPResponse, error) {
264
Expect(headerGetCI(h, "mcp-session-id")).To(Equal("cached-sid"))
265
return api.MCPResponse{Status: 200, Headers: map[string]string{}}, nil
266
}
267
268
_, err := subject.Call(endpoint, req, headers)
269
Expect(err).NotTo(HaveOccurred())
270
271
Expect(store.getCalls).To(Equal(1))
272
})
273
274
it("stores rotated session id when server returns mcp-session-id header", func() {
275
store.sessions[endpoint] = "cached-sid"
276
277
req := api.MCPMessage{JSONRPC: "2.0", ID: "1", Method: "tools/call", Params: []byte(`{}`)}
278
headers := map[string]string{}
279
280
inner.handler = func(ep string, r api.MCPMessage, h map[string]string) (api.MCPResponse, error) {
281
Expect(headerGetCI(h, "mcp-session-id")).To(Equal("cached-sid"))
282
return api.MCPResponse{
283
Status: 200,
284
Headers: map[string]string{"mcp-session-id": "rotated-sid"},
285
}, nil
286
}
287
288
_, err := subject.Call(endpoint, req, headers)
289
Expect(err).NotTo(HaveOccurred())
290
291
Expect(store.sessions[endpoint]).To(Equal("rotated-sid"))
292
Expect(store.setCalls).To(Equal(1))
293
})
294
295
it("on invalid session: deletes cached session, initializes, retries once with new session", func() {
296
store.sessions[endpoint] = "old-sid"
297
298
origReq := api.MCPMessage{JSONRPC: "2.0", ID: "orig", Method: "tools/call", Params: []byte(`{}`)}
299
headers := map[string]string{}
300
301
callCount := 0
302
303
inner.handler = func(ep string, r api.MCPMessage, h map[string]string) (api.MCPResponse, error) {
304
callCount++
305
306
switch callCount {
307
case 1:
308
// First attempt uses cached sid and fails with "invalid session"
309
Expect(r.Method).To(Equal("tools/call"))
310
Expect(headerGetCI(h, "mcp-session-id")).To(Equal("old-sid"))
311
return api.MCPResponse{
312
Status: 400,
313
Headers: map[string]string{},
314
Message: api.MCPMessage{
315
JSONRPC: "2.0",
316
ID: "server-error",
317
Error: &api.MCPError{
318
Message: "Bad Request: No valid session ID provided",
319
Code: "-32600",
320
},
321
},
322
}, errors.New("Bad Request: No valid session ID provided")
323
case 2:
324
// initialize call should happen next (no session header)
325
Expect(r.Method).To(Equal("initialize"))
326
_, ok := headerGetCIok(h, "mcp-session-id")
327
Expect(ok).To(BeFalse())
328
329
return api.MCPResponse{
330
Status: 200,
331
Headers: map[string]string{"mcp-session-id": "new-sid"},
332
Message: api.MCPMessage{
333
JSONRPC: "2.0",
334
ID: r.ID,
335
Result: []byte(`{}`),
336
},
337
}, nil
338
case 3:
339
// retry original request with new session id
340
Expect(r.Method).To(Equal("tools/call"))
341
Expect(headerGetCI(h, "mcp-session-id")).To(Equal("new-sid"))
342
343
return api.MCPResponse{
344
Status: 200,
345
Headers: map[string]string{},
346
Message: api.MCPMessage{
347
JSONRPC: "2.0",
348
ID: "ok",
349
Result: []byte(`{"content":[{"type":"text","text":"ok"}]}`),
350
},
351
}, nil
352
default:
353
return api.MCPResponse{}, errors.New("unexpected extra call")
354
}
355
}
356
357
resp, err := subject.Call(endpoint, origReq, headers)
358
Expect(err).NotTo(HaveOccurred())
359
Expect(resp.Status).To(Equal(200))
360
361
Expect(store.delCalls).To(Equal(1))
362
Expect(store.sessions[endpoint]).To(Equal("new-sid"))
363
Expect(callCount).To(Equal(3))
364
})
365
366
it("errors if initialize succeeds but does not return mcp-session-id header", func() {
367
// no cached session
368
req := api.MCPMessage{JSONRPC: "2.0", ID: "orig", Method: "tools/call", Params: []byte(`{}`)}
369
headers := map[string]string{}
370
371
callCount := 0
372
inner.handler = func(ep string, r api.MCPMessage, h map[string]string) (api.MCPResponse, error) {
373
callCount++
374
if callCount == 1 {
375
Expect(r.Method).To(Equal("initialize"))
376
return api.MCPResponse{
377
Status: 200,
378
Headers: map[string]string{}, // missing session header
379
Message: api.MCPMessage{JSONRPC: "2.0", ID: r.ID, Result: []byte(`{}`)},
380
}, nil
381
}
382
return api.MCPResponse{}, errors.New("should not reach retry")
383
}
384
385
_, err := subject.Call(endpoint, req, headers)
386
Expect(err).To(HaveOccurred())
387
Expect(err.Error()).To(ContainSubstring("did not return"))
388
})
389
})
390
}
391
392
func testSessionTransportNonHTTP(t *testing.T, when spec.G, it spec.S) {
393
when("SessionTransport.Call() with non-http scheme", func() {
394
var (
395
endpoint string
396
store *fakeSessionStore
397
inner *fakeMCPTransport
398
subject *client.SessionTransport
399
)
400
401
it.Before(func() {
402
RegisterTestingT(t)
403
404
endpoint = "stdio:python test/mcp/stdio/mcp_stdio_server.py"
405
store = newFakeSessionStore()
406
inner = &fakeMCPTransport{}
407
subject = client.NewSessionTransport(inner, store)
408
})
409
410
it("bypasses session logic and does not touch the session store", func() {
411
req := api.MCPMessage{JSONRPC: "2.0", ID: "1", Method: "tools/call", Params: []byte(`{}`)}
412
headers := map[string]string{}
413
414
inner.handler = func(ep string, r api.MCPMessage, h map[string]string) (api.MCPResponse, error) {
415
Expect(ep).To(Equal(endpoint))
416
Expect(r.Method).To(Equal("tools/call"))
417
// No session header should be injected for non-http transports.
418
_, ok := headerGetCIok(h, "mcp-session-id")
419
Expect(ok).To(BeFalse())
420
421
return api.MCPResponse{
422
Status: 0,
423
Headers: nil,
424
Message: api.MCPMessage{JSONRPC: "2.0", ID: r.ID, Result: []byte(`{}`)},
425
}, nil
426
}
427
428
_, err := subject.Call(endpoint, req, headers)
429
Expect(err).NotTo(HaveOccurred())
430
431
Expect(store.getCalls).To(Equal(0))
432
Expect(store.setCalls).To(Equal(0))
433
Expect(store.delCalls).To(Equal(0))
434
})
435
})
436
}
437
438
// testNewMCPTransport checks that NewMCPTransport picks the transport
// implementation from the endpoint scheme: http/https → MCPHTTPTransport,
// stdio → MCPStdioTransport, anything else → an "unsupported mcp transport"
// error.
func testNewMCPTransport(t *testing.T, when spec.G, it spec.S) {
	when("NewMCPTransport()", func() {
		it.Before(func() {
			RegisterTestingT(t)
		})

		it("returns MCPHTTPTransport for http/https endpoints", func() {
			// Only routing is under test, so the zero-value (nil) caller is
			// never invoked.
			var caller http.Caller

			for _, ep := range []string{"https://example.com/mcp", "http://example.com/mcp"} {
				tr, err := client.NewMCPTransport(ep, caller, map[string]string{})
				Expect(err).NotTo(HaveOccurred())
				Expect(tr).To(BeAssignableToTypeOf(&client.MCPHTTPTransport{}))
			}
		})

		it("returns MCPStdioTransport for stdio endpoints", func() {
			var caller http.Caller

			tr, err := client.NewMCPTransport(
				"stdio:python test/mcp/stdio/mcp_stdio_server.py",
				caller,
				map[string]string{},
			)
			Expect(err).NotTo(HaveOccurred())
			Expect(tr).To(BeAssignableToTypeOf(&client.MCPStdioTransport{}))
		})

		it("errors for unsupported schemes", func() {
			var caller http.Caller

			_, err := client.NewMCPTransport("ftp://example.com/mcp", caller, map[string]string{})
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("unsupported mcp transport"))
		})
	})
}
474
475
/* =========================
476
Fakes
477
========================= */
478
479
// fakeSessionStore is an in-memory session store for tests. It maps an
// endpoint to its session id and counts how many times each operation was
// invoked, so tests can assert whether the store was touched.
type fakeSessionStore struct {
	// sessions holds endpoint → session id.
	sessions map[string]string

	// Invocation counters for GetSessionID, SetSessionID and DeleteSessionID.
	getCalls, setCalls, delCalls int
}
485
486
// newFakeSessionStore returns an empty fakeSessionStore whose map is already
// allocated (writing to a nil map would panic) and whose counters are zero.
func newFakeSessionStore() *fakeSessionStore {
	s := new(fakeSessionStore)
	s.sessions = make(map[string]string)
	return s
}
489
490
// GetSessionID counts the lookup and returns the cached session id for
// endpoint, or "" when none is cached. It never returns an error.
func (s *fakeSessionStore) GetSessionID(endpoint string) (string, error) {
	s.getCalls++
	sid := s.sessions[endpoint]
	return sid, nil
}
494
495
// SetSessionID counts the write and caches sessionID for endpoint, replacing
// any previous value. It never returns an error.
func (s *fakeSessionStore) SetSessionID(endpoint string, sessionID string) error {
	s.setCalls++
	s.sessions[endpoint] = sessionID
	return nil
}
500
501
// DeleteSessionID counts the deletion and drops any cached session id for
// endpoint. Deleting a missing entry is a no-op. It never returns an error.
func (s *fakeSessionStore) DeleteSessionID(endpoint string) error {
	s.delCalls++
	delete(s.sessions, endpoint)
	return nil
}
506
507
type fakeMCPTransport struct {
508
handler func(endpoint string, req api.MCPMessage, headers map[string]string) (api.MCPResponse, error)
509
}
510
511
// Call forwards the request to the installed handler. The handler receives a
// copy of headers (always non-nil), so mutations on either side cannot leak
// between calls. If no handler is set, Call returns an error.
func (t *fakeMCPTransport) Call(endpoint string, req api.MCPMessage, headers map[string]string) (api.MCPResponse, error) {
	handle := t.handler
	if handle == nil {
		return api.MCPResponse{}, errors.New("fakeMCPTransport.handler is nil")
	}

	copied := make(map[string]string, len(headers))
	for key, val := range headers {
		copied[key] = val
	}
	return handle(endpoint, req, copied)
}
522
523
/* =========================
524
Header helpers (case-insensitive)
525
========================= */
526
527
func headerGetCI(h map[string]string, key string) string {
528
v, _ := headerGetCIok(h, key)
529
return v
530
}
531
532
// headerGetCIok looks up key in h case-insensitively, mirroring HTTP header
// name semantics. An exact-case match wins. Otherwise the first
// case-insensitive match found is returned. ok reports whether any key
// matched; a nil or empty map simply yields ("", false).
func headerGetCIok(h map[string]string, key string) (string, bool) {
	// Exact match first: O(1), and deterministic when h holds several case
	// variants of the same name (map iteration order is randomized).
	if v, ok := h[key]; ok {
		return v, true
	}
	for k, v := range h {
		if strings.EqualFold(k, key) {
			return v, true
		}
	}
	return "", false
}
540
541