Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/api/http/http.go
3434 views
1
package http
2
3
import (
4
"bufio"
5
"bytes"
6
"crypto/tls"
7
"encoding/json"
8
"fmt"
9
"io"
10
"net/http"
11
"os"
12
"strings"
13
"time"
14
15
"github.com/kardolus/chatgpt-cli/api"
16
"github.com/kardolus/chatgpt-cli/config"
17
"github.com/kardolus/chatgpt-cli/internal"
18
"go.uber.org/zap"
19
)
20
21
// Format strings used to build the errors returned by the request helpers
// in this file. The ones ending in %w wrap an underlying error.
const (
	errFailedToRead          = "failed to read response: %w"
	errFailedToCreateRequest = "failed to create request: %w"
	errFailedToMakeRequest   = "failed to make request: %w"
	errHTTP                  = "http status %d: %s"
	errHTTPStatus            = "http status: %d"
)

// defaultHTTPTimeout is the timeout given to the http.Client created by New.
const defaultHTTPTimeout = 60 * time.Second
29
30
type Caller interface {
31
Post(url string, body []byte, stream bool) ([]byte, error)
32
PostWithHeaders(url string, body []byte, headers map[string]string) ([]byte, error)
33
Get(url string) ([]byte, error)
34
PostWithHeadersResponse(url string, body []byte, headers map[string]string) (api.HTTPResponse, error)
35
}
36
37
type RestCaller struct {
38
client *http.Client
39
config config.Config
40
}
41
42
// Ensure RestCaller implements Caller interface
43
var _ Caller = &RestCaller{}
44
45
func New(cfg config.Config) *RestCaller {
46
client := &http.Client{Timeout: defaultHTTPTimeout}
47
48
if cfg.SkipTLSVerify {
49
transport := http.DefaultTransport.(*http.Transport).Clone()
50
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
51
client.Transport = transport
52
}
53
54
return &RestCaller{
55
client: client,
56
config: cfg,
57
}
58
}
59
60
type CallerFactory func(cfg config.Config) Caller
61
62
func RealCallerFactory(cfg config.Config) Caller {
63
return New(cfg)
64
}
65
66
// Get issues a GET request to url via doRequest, with no request body and
// streaming disabled, and returns what doRequest returns.
func (r *RestCaller) Get(url string) ([]byte, error) {
	const stream = false
	return r.doRequest(http.MethodGet, url, nil, stream)
}
69
70
// Post issues a POST request carrying body to url via doRequest. The stream
// flag is forwarded unchanged to doRequest.
func (r *RestCaller) Post(url string, body []byte, stream bool) ([]byte, error) {
	const method = http.MethodPost
	return r.doRequest(method, url, body, stream)
}
73
74
// PostWithHeaders sends body to url with a POST request that carries only
// the supplied headers, and returns the response body.
//
// For a non-2xx status the body is decoded as an api.ErrorResponse. If that
// succeeds, the raw body is returned together with an error holding the
// status code and the decoded message. If the body cannot be read or
// decoded, only the status code is reported and the returned bytes are nil.
func (r *RestCaller) PostWithHeaders(url string, body []byte, headers map[string]string) ([]byte, error) {
	// A zero-value RestCaller has no client; fall back to the default one
	// instead of panicking, as PostWithHeadersResponse does.
	if r.client == nil {
		r.client = http.DefaultClient
	}

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(body))
	if err != nil {
		return nil, fmt.Errorf(errFailedToCreateRequest, err)
	}

	// Add custom headers
	for k, v := range headers {
		req.Header.Set(k, v)
	}

	resp, err := r.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf(errFailedToMakeRequest, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		errorResponse, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf(errHTTPStatus, resp.StatusCode)
		}

		var errorData api.ErrorResponse
		if err := json.Unmarshal(errorResponse, &errorData); err != nil {
			return nil, fmt.Errorf(errHTTPStatus, resp.StatusCode)
		}

		return errorResponse, fmt.Errorf(errHTTP, resp.StatusCode, errorData.Error.Message)
	}

	// Wrap read failures the same way doRequest does, and do not hand back
	// a partially read body next to a non-nil error.
	result, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf(errFailedToRead, err)
	}

	return result, nil
}
107
108
func (r *RestCaller) PostWithHeadersResponse(url string, body []byte, headers map[string]string) (api.HTTPResponse, error) {
109
// tests construct RestCaller{} (nil client) — avoid panic
110
if r.client == nil {
111
r.client = http.DefaultClient
112
}
113
114
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(body))
115
if err != nil {
116
return api.HTTPResponse{}, fmt.Errorf(errFailedToCreateRequest, err)
117
}
118
119
for k, v := range headers {
120
req.Header.Set(k, v)
121
}
122
123
resp, err := r.client.Do(req)
124
if err != nil {
125
return api.HTTPResponse{}, fmt.Errorf(errFailedToMakeRequest, err)
126
}
127
defer resp.Body.Close()
128
129
respBody, readErr := io.ReadAll(resp.Body)
130
if readErr != nil {
131
return api.HTTPResponse{}, readErr
132
}
133
134
outHeaders := map[string]string{}
135
for k, vals := range resp.Header {
136
if len(vals) == 0 {
137
continue
138
}
139
outHeaders[k] = strings.Join(vals, ", ")
140
}
141
142
out := api.HTTPResponse{
143
Status: resp.StatusCode,
144
Headers: outHeaders,
145
Body: respBody,
146
}
147
148
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
149
// Try OpenAI error shape first
150
var errorData api.ErrorResponse
151
if err := json.Unmarshal(respBody, &errorData); err == nil && errorData.Error.Message != "" {
152
return out, fmt.Errorf(errHTTP, resp.StatusCode, errorData.Error.Message)
153
}
154
155
// Otherwise include raw body so you can debug MCP server errors
156
msg := strings.TrimSpace(string(respBody))
157
if msg == "" {
158
return out, fmt.Errorf(errHTTPStatus, resp.StatusCode)
159
}
160
return out, fmt.Errorf("http status %d: %s", resp.StatusCode, msg)
161
}
162
163
return out, nil
164
}
165
166
// ProcessResponse consumes a streamed response from reader, echoes the
// extracted text to writer, and returns that text.
//
// Endpoints containing the configured ResponsesPath are handled by
// processResponsesSSE; every other endpoint goes through processLegacy.
func (r *RestCaller) ProcessResponse(reader io.Reader, writer io.Writer, endpoint string) []byte {
	if !strings.Contains(endpoint, r.config.ResponsesPath) {
		return r.processLegacy(reader, writer)
	}
	return r.processResponsesSSE(reader, writer)
}
172
173
// processLegacy reads a line-oriented stream of "data:" records from reader,
// writes each content delta to writer, and returns everything it wrote.
//
// Lines without the "data:" prefix are ignored. Payloads shorter than six
// bytes are skipped. The payload "[DONE]" emits a final newline and stops
// processing. Any other payload is decoded as api.Data, and the string
// "content" entry of each choice's delta is emitted. Decode failures are
// reported on writer and the stream continues.
//
// While debug logging is enabled, each line is logged and not processed
// further.
func (r *RestCaller) processLegacy(reader io.Reader, writer io.Writer) []byte {
	var result []byte
	sugar := zap.S()
	sugar.Debugln("\nResponse\n")

	scanner := bufio.NewScanner(reader)
	for scanner.Scan() {
		line := scanner.Text()

		if zap.L().Core().Enabled(zap.DebugLevel) {
			sugar.Debugln(line)
			continue
		}

		if !strings.HasPrefix(line, "data:") {
			continue
		}

		// Strip the field name and surrounding whitespace rather than
		// slicing at a fixed offset: a bare "data:" line is only five bytes
		// long, so line[6:] would panic, and "data:[DONE]" without a space
		// would lose its first byte.
		payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if len(payload) < 6 {
			continue
		}
		if payload == "[DONE]" {
			_, _ = writer.Write([]byte("\n"))
			result = append(result, '\n')
			break
		}

		var data api.Data
		if err := json.Unmarshal([]byte(payload), &data); err != nil {
			_, _ = fmt.Fprintf(writer, "Error: %s\n", err.Error())
			continue
		}
		for _, choice := range data.Choices {
			if content, ok := choice.Delta["content"].(string); ok {
				_, _ = writer.Write([]byte(content))
				result = append(result, content...)
			}
		}
	}
	return result
}
212
213
func (r *RestCaller) processResponsesSSE(reader io.Reader, writer io.Writer) []byte {
214
var (
215
result []byte
216
curEvent string
217
done bool
218
sugar = zap.S()
219
)
220
221
sugar.Debugln("\nResponse\n")
222
223
scanner := bufio.NewScanner(reader)
224
buf := make([]byte, 0, 64*1024)
225
scanner.Buffer(buf, 1024*1024)
226
227
for scanner.Scan() {
228
line := scanner.Text()
229
230
if zap.L().Core().Enabled(zap.DebugLevel) {
231
sugar.Debugln(line)
232
continue
233
}
234
235
if strings.HasPrefix(line, ":") {
236
continue
237
}
238
239
switch {
240
case strings.HasPrefix(line, "event:"):
241
curEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
242
continue
243
244
case strings.HasPrefix(line, "data:"):
245
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
246
247
if curEvent == "" {
248
if payload == "[DONE]" {
249
_, _ = writer.Write([]byte("\n"))
250
result = append(result, '\n')
251
done = true
252
break
253
}
254
var legacy struct {
255
Choices []struct {
256
Delta map[string]any `json:"delta"`
257
} `json:"choices"`
258
}
259
if err := json.Unmarshal([]byte(payload), &legacy); err != nil {
260
_, _ = fmt.Fprintf(writer, "Error: %s\n", err.Error())
261
continue
262
}
263
for _, ch := range legacy.Choices {
264
if s, ok := ch.Delta["content"].(string); ok && s != "" {
265
_, _ = writer.Write([]byte(s))
266
result = append(result, s...)
267
}
268
}
269
continue
270
}
271
272
var env struct {
273
Type string `json:"type"`
274
Delta string `json:"delta"` // response.output_text.delta
275
Text string `json:"text"` // response.output_text.done/content_part.done (optional)
276
Response struct {
277
Status string `json:"status"`
278
} `json:"response"`
279
}
280
if err := json.Unmarshal([]byte(payload), &env); err != nil {
281
_, _ = fmt.Fprintf(writer, "Error: %s\n", err.Error())
282
continue
283
}
284
285
switch env.Type {
286
case "response.output_text.delta":
287
if env.Delta != "" {
288
_, _ = writer.Write([]byte(env.Delta))
289
result = append(result, env.Delta...)
290
}
291
case "response.completed":
292
if len(result) == 0 || !bytes.HasSuffix(result, []byte("\n")) {
293
_, _ = writer.Write([]byte("\n"))
294
result = append(result, '\n')
295
}
296
done = true
297
default:
298
// ignore other SSE types
299
}
300
}
301
302
if done {
303
break
304
}
305
}
306
return result
307
}
308
309
// doRequest builds a request with newRequest, executes it, and returns the
// response body.
//
// For a non-2xx status the body is decoded as an api.ErrorResponse. If that
// succeeds, the raw body is returned with an error holding the status code
// and decoded message; if the body cannot be read or decoded, only the
// status code is reported. For a 2xx status with stream set, the body is
// passed to ProcessResponse with os.Stdout as the writer; otherwise the body
// is read in full.
func (r *RestCaller) doRequest(method, url string, body []byte, stream bool) ([]byte, error) {
	req, err := r.newRequest(method, url, body)
	if err != nil {
		return nil, fmt.Errorf(errFailedToCreateRequest, err)
	}

	resp, err := r.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf(errFailedToMakeRequest, err)
	}
	defer resp.Body.Close()

	status := resp.StatusCode
	succeeded := status >= 200 && status < 300

	switch {
	case !succeeded:
		raw, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return nil, fmt.Errorf(errHTTPStatus, status)
		}

		var parsed api.ErrorResponse
		if json.Unmarshal(raw, &parsed) != nil {
			return nil, fmt.Errorf(errHTTPStatus, status)
		}

		return raw, fmt.Errorf(errHTTP, status, parsed.Error.Message)

	case stream:
		return r.ProcessResponse(resp.Body, os.Stdout, url), nil
	}

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf(errFailedToRead, err)
	}

	return payload, nil
}
346
347
// newRequest creates a request for method and url carrying body, then applies
// the configured headers: the auth header (only when an API key is set), the
// content type, the user agent, and finally the custom headers. Custom
// headers are set last and therefore replace any earlier header with the
// same key.
func (r *RestCaller) newRequest(method, url string, body []byte) (*http.Request, error) {
	payload := bytes.NewBuffer(body)

	req, err := http.NewRequest(method, url, payload)
	if err != nil {
		return nil, err
	}

	cfg := r.config
	hdr := req.Header

	if cfg.APIKey != "" {
		token := cfg.AuthTokenPrefix + cfg.APIKey
		hdr.Set(cfg.AuthHeader, token)
	}
	hdr.Set(internal.HeaderContentTypeKey, internal.HeaderContentTypeValue)
	hdr.Set(internal.HeaderUserAgentKey, cfg.UserAgent)

	for name, val := range cfg.CustomHeaders {
		hdr.Set(name, val)
	}

	return req, nil
}
365
366