Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/api/client/mcp.go
3431 views
1
package client
2
3
import (
4
"bufio"
5
"bytes"
6
"encoding/json"
7
"errors"
8
"fmt"
9
"io"
10
"net/url"
11
"os/exec"
12
"strings"
13
"sync"
14
"time"
15
16
"github.com/google/uuid"
17
18
"github.com/kardolus/chatgpt-cli/api"
19
"github.com/kardolus/chatgpt-cli/api/http"
20
"github.com/kardolus/chatgpt-cli/history"
21
)
22
23
/* =========================
24
Client entrypoint
25
========================= */
26
27
func (c *Client) InjectMCPContext(mcp api.MCPRequest) error {
28
if c.Config.OmitHistory {
29
return errors.New(ErrHistoryTracking)
30
}
31
if mcp.Endpoint == "" {
32
return fmt.Errorf("mcp endpoint is required")
33
}
34
if mcp.Tool == "" {
35
return fmt.Errorf("mcp tool is required")
36
}
37
38
req, err := c.buildMCPMessage(mcp)
39
if err != nil {
40
return err
41
}
42
43
if rawReq, err := json.Marshal(req); err == nil {
44
c.printRequestDebugInfo(mcp.Endpoint, rawReq, buildMCPHeaders(mcp.Headers))
45
}
46
47
resp, err := c.transport.Call(mcp.Endpoint, req, mcp.Headers)
48
if err != nil {
49
return err
50
}
51
52
if rawResp, err := json.Marshal(resp.Message); err == nil {
53
c.printResponseDebugInfo(rawResp)
54
}
55
56
formatted := formatMCPResponse(resp.Message.Result, mcp.Tool)
57
58
c.initHistory()
59
c.History = append(c.History, history.History{
60
Message: api.Message{
61
Role: AssistantRole,
62
Content: formatted,
63
},
64
Timestamp: c.timer.Now(),
65
})
66
c.truncateHistory()
67
68
return c.historyStore.Write(c.History)
69
}
70
71
/* =========================
72
MCP message building
73
========================= */
74
75
func (c *Client) buildMCPMessage(mcp api.MCPRequest) (api.MCPMessage, error) {
76
rawParams, err := json.Marshal(map[string]any{
77
"name": mcp.Tool,
78
"arguments": mcp.Params,
79
})
80
if err != nil {
81
return api.MCPMessage{}, fmt.Errorf("failed to marshal mcp params: %w", err)
82
}
83
84
return api.MCPMessage{
85
JSONRPC: "2.0",
86
ID: uuid.NewString(),
87
Method: "tools/call",
88
Params: rawParams,
89
}, nil
90
}
91
92
/* =========================
93
Transport interfaces
94
========================= */
95
96
type MCPTransport interface {
97
Call(endpoint string, req api.MCPMessage, headers map[string]string) (api.MCPResponse, error)
98
}
99
100
// SessionStore keeps MCP session IDs keyed by endpoint. SessionTransport uses
// it to reuse a session across calls and to drop one the server has rejected.
type SessionStore interface {
	// GetSessionID returns the session ID recorded for endpoint.
	GetSessionID(endpoint string) (string, error)

	// SetSessionID records sessionID for endpoint.
	SetSessionID(endpoint string, sessionID string) error

	// DeleteSessionID removes the session ID recorded for endpoint.
	DeleteSessionID(endpoint string) error
}
105
106
/* =========================
107
Session transport
108
========================= */
109
110
type SessionTransport struct {
111
inner MCPTransport
112
store SessionStore
113
}
114
115
// NewSessionTransport returns a SessionTransport that forwards calls to inner
// and keeps session IDs in store.
func NewSessionTransport(inner MCPTransport, store SessionStore) *SessionTransport {
	st := new(SessionTransport)
	st.inner = inner
	st.store = store
	return st
}
118
119
// Call forwards req to the inner transport, attaching an Mcp-Session-Id header
// for HTTP(S) endpoints.
//
// The session ID is chosen as follows:
//   - non-HTTP(S) endpoints and calls that already carry a session header are
//     passed through untouched;
//   - otherwise a session ID cached in the store is tried first;
//   - if there is none, or the server's error looks like a rejected session,
//     a new session is obtained via initialize and the call is sent with it.
//
// After a successful call, a session ID found in the response headers is
// written back to the store.
func (t *SessionTransport) Call(endpoint string, req api.MCPMessage, headers map[string]string) (api.MCPResponse, error) {
	// Session headers only apply to HTTP(S); stdio and other schemes have no
	// headers. When the endpoint does not parse, the session logic below
	// still runs.
	if parsed, parseErr := url.Parse(endpoint); parseErr == nil {
		switch parsed.Scheme {
		case "http", "https":
			// handled below
		default:
			return t.inner.Call(endpoint, req, headers)
		}
	}

	// A session header supplied by the caller is used as-is.
	if _, explicit := headerGet(headers, "mcp-session-id"); explicit {
		return t.inner.Call(endpoint, req, headers)
	}

	// First attempt: reuse a cached session when one exists.
	cached, lookupErr := t.store.GetSessionID(endpoint)
	if lookupErr == nil && strings.TrimSpace(cached) != "" {
		withSession := cloneHeaders(headers)
		withSession["Mcp-Session-Id"] = cached

		resp, callErr := t.inner.Call(endpoint, req, withSession)
		if callErr == nil {
			t.maybeStoreSession(endpoint, resp)
			return resp, nil
		}
		if !looksLikeInvalidSession(callErr) {
			return resp, callErr
		}

		// The server rejected the cached session: forget it (best-effort) and
		// fall through to a fresh initialize.
		_ = t.store.DeleteSessionID(endpoint)
	}

	fresh, err := t.initialize(endpoint, headers)
	if err != nil {
		return api.MCPResponse{}, err
	}

	withSession := cloneHeaders(headers)
	withSession["Mcp-Session-Id"] = fresh

	resp, err := t.inner.Call(endpoint, req, withSession)
	if err != nil {
		return resp, err
	}
	t.maybeStoreSession(endpoint, resp)
	return resp, nil
}
166
167
func (t *SessionTransport) initialize(endpoint string, headers map[string]string) (string, error) {
168
raw, err := json.Marshal(map[string]any{
169
"protocolVersion": "2024-11-05",
170
"capabilities": map[string]any{},
171
"clientInfo": map[string]any{
172
"name": "chatgpt-cli",
173
"version": "dev",
174
},
175
})
176
if err != nil {
177
return "", err
178
}
179
180
resp, err := t.inner.Call(endpoint, api.MCPMessage{
181
JSONRPC: "2.0",
182
ID: uuid.NewString(),
183
Method: "initialize",
184
Params: raw,
185
}, headers)
186
if err != nil {
187
return "", err
188
}
189
190
sid, ok := headerGet(resp.Headers, "mcp-session-id")
191
if !ok || strings.TrimSpace(sid) == "" {
192
return "", fmt.Errorf("mcp initialize did not return session id")
193
}
194
195
_ = t.store.SetSessionID(endpoint, sid)
196
return sid, nil
197
}
198
199
// maybeStoreSession saves the session ID from resp's mcp-session-id header
// (matched case-insensitively) for endpoint. It does nothing when the header
// is absent or blank, and ignores store errors.
func (t *SessionTransport) maybeStoreSession(endpoint string, resp api.MCPResponse) {
	sessionID, found := headerGet(resp.Headers, "mcp-session-id")
	if !found || strings.TrimSpace(sessionID) == "" {
		return
	}

	// Best-effort: a store failure does not affect the call that just succeeded.
	_ = t.store.SetSessionID(endpoint, sessionID)
}
204
205
func NewMCPTransport(endpoint string, caller http.Caller, headers map[string]string) (MCPTransport, error) {
206
u, err := url.Parse(endpoint)
207
if err != nil {
208
return nil, err
209
}
210
211
switch u.Scheme {
212
case "http", "https":
213
return NewMCPHTTPTransport(endpoint, caller, headers)
214
case "stdio":
215
return NewMCPStdioTransport(endpoint)
216
default:
217
return nil, fmt.Errorf("unsupported mcp transport: %s", u.Scheme)
218
}
219
}
220
221
/* =========================
222
HTTP transport
223
========================= */
224
225
type MCPHTTPTransport struct {
226
caller http.Caller
227
headers map[string]string
228
}
229
230
func NewMCPHTTPTransport(endpoint string, caller http.Caller, headers map[string]string) (*MCPHTTPTransport, error) {
231
u, err := url.Parse(endpoint)
232
if err != nil {
233
return nil, err
234
}
235
236
switch u.Scheme {
237
case "http", "https":
238
// ok
239
default:
240
return nil, fmt.Errorf("unsupported mcp http transport: %s", u.Scheme)
241
}
242
243
// Defensive copy so callers can reuse/modify their input map safely.
244
h := map[string]string{}
245
for k, v := range headers {
246
h[k] = v
247
}
248
249
return &MCPHTTPTransport{
250
caller: caller,
251
headers: h,
252
}, nil
253
}
254
255
func (t *MCPHTTPTransport) Call(endpoint string, req api.MCPMessage, extra map[string]string) (api.MCPResponse, error) {
256
body, err := json.Marshal(req)
257
if err != nil {
258
return api.MCPResponse{}, fmt.Errorf("failed to marshal mcp request: %w", err)
259
}
260
261
merged := map[string]string{}
262
for k, v := range t.headers {
263
merged[k] = v
264
}
265
for k, v := range extra {
266
merged[k] = v
267
}
268
269
httpResp, postErr := t.caller.PostWithHeadersResponse(endpoint, body, buildMCPHeaders(merged))
270
271
out := api.MCPResponse{
272
Headers: httpResp.Headers,
273
Status: httpResp.Status,
274
}
275
276
// Even on non-2xx, parse body if possible so SessionTransport can reason about it.
277
if len(httpResp.Body) > 0 {
278
var msg api.MCPMessage
279
if err := json.Unmarshal(httpResp.Body, &msg); err == nil {
280
out.Message = msg
281
} else if dataJSON, ok := extractFirstSSEDataJSON(httpResp.Body); ok {
282
if err := json.Unmarshal(dataJSON, &msg); err == nil {
283
out.Message = msg
284
}
285
}
286
}
287
288
// Prefer JSON-RPC error if present.
289
if out.Message.Error != nil {
290
return out, out.Message.Error
291
}
292
293
// Otherwise propagate HTTP-layer error.
294
if postErr != nil {
295
return out, postErr
296
}
297
298
return out, nil
299
}
300
301
/* =========================
302
STDIO transport
303
========================= */
304
305
type MCPStdioTransport struct {
306
// Endpoint is "stdio:<cmdline...>"
307
endpoint string
308
309
mu sync.Mutex
310
cmd *exec.Cmd
311
stdin io.WriteCloser
312
stdout io.ReadCloser
313
stderr io.ReadCloser
314
315
initialized bool
316
317
// response routing
318
pending map[string]chan api.MCPMessage
319
done chan struct{}
320
}
321
322
// NewMCPStdioTransport returns a stdio transport bound to endpoint, which must
// start with "stdio:". No process is started here; that happens on the first
// Call.
func NewMCPStdioTransport(endpoint string) (*MCPStdioTransport, error) {
	if strings.HasPrefix(endpoint, "stdio:") {
		return &MCPStdioTransport{endpoint: endpoint}, nil
	}
	return nil, fmt.Errorf("invalid stdio endpoint: %s", endpoint)
}
328
329
func (t *MCPStdioTransport) Call(endpoint string, req api.MCPMessage, headers map[string]string) (api.MCPResponse, error) {
330
// headers are ignored for stdio
331
if endpoint != t.endpoint {
332
return api.MCPResponse{}, fmt.Errorf("stdio transport called with unexpected endpoint")
333
}
334
335
if err := t.ensureStarted(); err != nil {
336
return api.MCPResponse{}, err
337
}
338
if err := t.ensureInitialized(); err != nil {
339
return api.MCPResponse{}, err
340
}
341
342
if strings.TrimSpace(req.ID) == "" {
343
req.ID = uuid.NewString()
344
}
345
346
msg, err := t.roundTrip(req, 30*time.Second)
347
out := api.MCPResponse{
348
Message: msg,
349
Status: 0,
350
Headers: nil,
351
}
352
return out, err
353
}
354
355
// ensureStarted launches the server process named by the endpoint's command
// line unless one has already been started. On success it stores the process
// and its pipes, creates the response-routing state, and starts the goroutines
// that read stdout and drain stderr.
//
// It returns an error if the command line is empty or has an unterminated
// quote, if a pipe cannot be created, or if the process fails to start.
func (t *MCPStdioTransport) ensureStarted() error {
	t.mu.Lock()
	defer t.mu.Unlock()

	if t.cmd != nil {
		return nil // already started
	}

	spec := strings.TrimPrefix(t.endpoint, "stdio:")
	spec = strings.TrimSpace(spec)
	if spec == "" {
		return fmt.Errorf("stdio endpoint missing command: %s", t.endpoint)
	}

	args, err := splitCommandLine(spec)
	if err != nil {
		return err
	}
	if len(args) == 0 {
		return fmt.Errorf("stdio endpoint missing command: %s", t.endpoint)
	}

	proc := exec.Command(args[0], args[1:]...) // #nosec G204

	in, err := proc.StdinPipe()
	if err != nil {
		return err
	}
	out, err := proc.StdoutPipe()
	if err != nil {
		return err
	}
	errOut, err := proc.StderrPipe()
	if err != nil {
		return err
	}

	if err := proc.Start(); err != nil {
		return fmt.Errorf("failed to start mcp stdio server: %w", err)
	}

	t.cmd, t.stdin, t.stdout, t.stderr = proc, in, out, errOut
	t.pending = make(map[string]chan api.MCPMessage)
	t.done = make(chan struct{})

	go t.readLoop()
	go t.drainStderr()

	return nil
}
406
407
func (t *MCPStdioTransport) ensureInitialized() error {
408
t.mu.Lock()
409
if t.initialized {
410
t.mu.Unlock()
411
return nil
412
}
413
t.mu.Unlock()
414
415
initParams, _ := json.Marshal(map[string]any{
416
"protocolVersion": "2024-11-05",
417
"capabilities": map[string]any{},
418
"clientInfo": map[string]any{
419
"name": "chatgpt-cli",
420
"version": "dev",
421
},
422
})
423
424
initReq := api.MCPMessage{
425
JSONRPC: "2.0",
426
ID: uuid.NewString(),
427
Method: "initialize",
428
Params: initParams,
429
}
430
431
if _, err := t.roundTrip(initReq, 10*time.Second); err != nil {
432
return err
433
}
434
435
// notifications/initialized (no response expected)
436
notif := api.MCPMessage{
437
JSONRPC: "2.0",
438
Method: "notifications/initialized",
439
}
440
if err := t.sendOneWay(notif); err != nil {
441
return err
442
}
443
444
t.mu.Lock()
445
t.initialized = true
446
t.mu.Unlock()
447
return nil
448
}
449
450
// sendOneWay writes msg to the server's stdin as one JSON line and does not
// wait for a reply. It returns the marshal or write error, if any.
func (t *MCPStdioTransport) sendOneWay(msg api.MCPMessage) error {
	payload, err := json.Marshal(msg)
	if err != nil {
		return err
	}
	line := append(payload, '\n')

	// Hold the lock for the write so lines from concurrent senders cannot
	// interleave.
	t.mu.Lock()
	defer t.mu.Unlock()

	_, err = t.stdin.Write(line)
	return err
}
460
461
func (t *MCPStdioTransport) roundTrip(req api.MCPMessage, timeout time.Duration) (api.MCPMessage, error) {
462
ch := make(chan api.MCPMessage, 1)
463
464
t.mu.Lock()
465
t.pending[req.ID] = ch
466
467
b, err := json.Marshal(req)
468
if err == nil {
469
_, err = t.stdin.Write(append(b, '\n'))
470
}
471
t.mu.Unlock()
472
473
if err != nil {
474
t.mu.Lock()
475
delete(t.pending, req.ID)
476
t.mu.Unlock()
477
return api.MCPMessage{}, fmt.Errorf("failed to write to mcp stdio: %w", err)
478
}
479
480
select {
481
case msg, ok := <-ch:
482
if !ok {
483
return api.MCPMessage{}, fmt.Errorf("mcp stdio server closed")
484
}
485
if msg.Error != nil {
486
return msg, msg.Error
487
}
488
return msg, nil
489
case <-time.After(timeout):
490
t.mu.Lock()
491
delete(t.pending, req.ID)
492
t.mu.Unlock()
493
return api.MCPMessage{}, fmt.Errorf("mcp stdio call timed out")
494
}
495
}
496
497
// readLoop reads the server's stdout line by line and hands each response to
// the caller waiting on its ID. Blank lines, lines that are not valid JSON
// messages, messages without an ID (notifications), and responses nobody is
// waiting for are skipped.
//
// When scanning stops, every remaining waiter's channel is closed so the
// waiters unblock, and done is closed on return. Lines longer than 5 MiB
// exceed the scanner's buffer limit and also end the loop.
func (t *MCPStdioTransport) readLoop() {
	defer close(t.done)

	lines := bufio.NewScanner(t.stdout)
	lines.Buffer(make([]byte, 0, 64*1024), 5*1024*1024)

	for lines.Scan() {
		text := strings.TrimSpace(lines.Text())
		if text == "" {
			continue
		}

		var msg api.MCPMessage
		if json.Unmarshal([]byte(text), &msg) != nil {
			continue
		}

		// No ID means a notification; nobody is waiting for those.
		if strings.TrimSpace(msg.ID) == "" {
			continue
		}

		t.mu.Lock()
		waiter := t.pending[msg.ID]
		if waiter == nil {
			t.mu.Unlock()
			continue
		}
		delete(t.pending, msg.ID)
		t.mu.Unlock()

		// The channel has capacity 1 and receives exactly one message, so this
		// send does not block.
		waiter <- msg
		close(waiter)
	}

	// Scanning ended: release everyone still waiting.
	t.mu.Lock()
	for id, waiter := range t.pending {
		delete(t.pending, id)
		close(waiter)
	}
	t.mu.Unlock()
}
541
542
// drainStderr reads and discards everything the server writes to stderr,
// returning when the stream ends or a read fails. Presumably this keeps the
// child from stalling on a full stderr pipe.
//
// The output is not logged.
func (t *MCPStdioTransport) drainStderr() {
	// io.Copy streams through a fixed-size buffer. The previous line-based
	// read accumulated a whole line in memory, which is unbounded when the
	// server writes a very long line or never writes a newline.
	_, _ = io.Copy(io.Discard, t.stderr)
}
552
553
// splitCommandLine splits s into arguments using minimal shell-like rules:
//   - unquoted spaces, tabs and newlines separate arguments;
//   - single quotes '...' and double quotes "..." group text, whitespace
//     included, and the quote characters themselves are removed;
//   - a quote character inside the other kind of quotes is literal;
//   - quoted and unquoted pieces that touch form one argument
//     (a"b c"d becomes "ab cd");
//   - an empty pair of quotes yields an empty argument.
//
// Backslash escapes are not supported. An unterminated quote is an error.
func splitCommandLine(s string) ([]string, error) {
	var (
		out      []string
		cur      strings.Builder
		inSingle bool
		inDouble bool

		// started is true once the current argument has begun, even if it is
		// still empty (as after ""). cur.Len() alone cannot tell "no argument"
		// from "empty argument", which used to drop empty quoted arguments.
		started bool
	)

	flush := func() {
		if !started {
			return
		}
		out = append(out, cur.String())
		cur.Reset()
		started = false
	}

	// Byte-wise iteration is safe for UTF-8: all delimiters are ASCII and
	// every other byte is copied through unchanged.
	for i := 0; i < len(s); i++ {
		ch := s[i]

		switch {
		case ch == '\'' && !inDouble:
			inSingle = !inSingle
			started = true
		case ch == '"' && !inSingle:
			inDouble = !inDouble
			started = true
		case (ch == ' ' || ch == '\t' || ch == '\n') && !inSingle && !inDouble:
			flush()
		default:
			cur.WriteByte(ch)
			started = true
		}
	}

	if inSingle || inDouble {
		return nil, fmt.Errorf("unterminated quote in stdio command")
	}

	flush()
	return out, nil
}
606
607
/* =========================
608
Helpers
609
========================= */
610
611
// looksLikeInvalidSession reports whether err reads like a server rejecting a
// session ID. It is a text heuristic: the lowercased error message must
// contain "session" together with one of "missing", "invalid", "no valid",
// "expired" or "unknown". A nil error is never a session error.
func looksLikeInvalidSession(err error) bool {
	// Guard against nil so callers cannot trigger a panic on err.Error().
	if err == nil {
		return false
	}

	msg := strings.ToLower(err.Error())
	if !strings.Contains(msg, "session") {
		return false
	}

	for _, hint := range []string{"missing", "invalid", "no valid", "expired", "unknown"} {
		if strings.Contains(msg, hint) {
			return true
		}
	}
	return false
}
623
624
// cloneHeaders returns a new map holding the same entries as in. The result
// is never nil, even for a nil input, so callers can write to it directly.
func cloneHeaders(in map[string]string) map[string]string {
	dup := make(map[string]string, len(in))
	for name, value := range in {
		dup[name] = value
	}
	return dup
}
631
632
// headerGet looks up key in h ignoring case and returns the value of a
// matching entry plus true, or "" and false when no entry matches. If several
// keys differ only by case, which one is returned depends on map iteration
// order.
func headerGet(h map[string]string, key string) (string, bool) {
	for name, value := range h {
		if !strings.EqualFold(name, key) {
			continue
		}
		return value, true
	}
	return "", false
}
640
641
// headerDelCI removes from h every entry whose key equals key ignoring case.
// It modifies h in place.
func headerDelCI(h map[string]string, key string) {
	for name := range h {
		if !strings.EqualFold(name, key) {
			continue
		}
		delete(h, name)
	}
}
648
649
func buildMCPHeaders(in map[string]string) map[string]string {
650
h := cloneHeaders(in)
651
652
if _, ok := headerGet(h, "Content-Type"); !ok {
653
h["Content-Type"] = "application/json"
654
}
655
if _, ok := headerGet(h, "Accept"); !ok {
656
h["Accept"] = "application/json, text/event-stream"
657
}
658
659
// Canonicalize mcp-session-id → Mcp-Session-Id
660
if v, ok := headerGet(h, "mcp-session-id"); ok {
661
headerDelCI(h, "mcp-session-id")
662
h["Mcp-Session-Id"] = v
663
}
664
665
return h
666
}
667
668
// extractFirstSSEDataJSON returns the data payload of the first server-sent
// event in raw that has one. The event's "data:" lines are trimmed of
// surrounding whitespace and joined with "\n". The second result is false
// when raw contains no "data:" line.
//
// Events are separated by blank lines. Scanning stops at the blank line that
// ends the first event with data, so payloads of later events are not
// appended to it (which would produce invalid JSON). Blank lines and other
// fields before the first "data:" line are skipped. CRLF line endings are
// accepted.
func extractFirstSSEDataJSON(raw []byte) ([]byte, bool) {
	text := strings.ReplaceAll(string(raw), "\r\n", "\n")

	var data []string
	for _, line := range strings.Split(text, "\n") {
		if line == "" {
			if len(data) > 0 {
				break // end of the first event that carried data
			}
			continue
		}
		if strings.HasPrefix(line, "data:") {
			data = append(data, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
		}
	}

	if len(data) == 0 {
		return nil, false
	}
	return []byte(strings.Join(data, "\n")), true
}
681
682
func formatMCPResponse(raw json.RawMessage, tool string) string {
683
if len(raw) == 0 {
684
return fmt.Sprintf("[MCP: %s] (empty result)", tool)
685
}
686
687
type contentBlock struct {
688
Type string `json:"type"`
689
Text string `json:"text,omitempty"`
690
}
691
type toolResult struct {
692
Content []contentBlock `json:"content,omitempty"`
693
}
694
695
var r toolResult
696
if err := json.Unmarshal(raw, &r); err == nil && len(r.Content) > 0 {
697
var parts []string
698
for _, b := range r.Content {
699
if strings.EqualFold(b.Type, "text") && strings.TrimSpace(b.Text) != "" {
700
parts = append(parts, normalizeMaybeJSON(b.Text))
701
}
702
}
703
if len(parts) > 0 {
704
return fmt.Sprintf("[MCP: %s]\n%s", tool, strings.Join(parts, "\n\n"))
705
}
706
}
707
708
return fmt.Sprintf("[MCP: %s]\n%s", tool, prettyJSONOrRaw(raw))
709
}
710
711
// normalizeMaybeJSON trims s and, if the result is exactly one JSON value,
// returns it re-indented; otherwise it returns the trimmed text unchanged.
//
// Numeric literals are reproduced exactly as written, so large integers and
// forms such as "1.50" are not altered.
func normalizeMaybeJSON(s string) string {
	txt := strings.TrimSpace(s)
	if txt == "" {
		return txt
	}

	// UseNumber keeps numbers as their original literals. Decoding into
	// float64 would round integers above 2^53 (IDs, timestamps) and rewrite
	// other literals.
	dec := json.NewDecoder(strings.NewReader(txt))
	dec.UseNumber()

	var v any
	if err := dec.Decode(&v); err != nil {
		return txt
	}
	// Decode stops after the first value; anything but EOF afterwards means
	// the text as a whole is not a single JSON document.
	if _, err := dec.Token(); !errors.Is(err, io.EOF) {
		return txt
	}

	if b, err := json.MarshalIndent(v, "", " "); err == nil {
		return string(b)
	}
	return txt
}
725
726
// prettyJSONOrRaw returns raw re-indented when it is exactly one JSON value,
// and raw unchanged (as a string) otherwise.
//
// Numeric literals are reproduced exactly as written, so large integers are
// not rounded.
func prettyJSONOrRaw(raw []byte) string {
	// UseNumber keeps numbers as their original literals. Decoding into
	// float64 would round integers above 2^53 and rewrite other literals.
	dec := json.NewDecoder(bytes.NewReader(raw))
	dec.UseNumber()

	var v any
	if err := dec.Decode(&v); err != nil {
		return string(raw)
	}
	// Decode stops after the first value; anything but EOF afterwards means
	// raw as a whole is not a single JSON document.
	if _, err := dec.Token(); !errors.Is(err, io.EOF) {
		return string(raw)
	}

	if b, err := json.MarshalIndent(v, "", " "); err == nil {
		return string(bytes.TrimSpace(b))
	}
	return string(raw)
}
736
737