Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/api/http/http.go
2649 views
1
package http
2
3
import (
4
"bufio"
5
"bytes"
6
"crypto/tls"
7
"encoding/json"
8
"fmt"
9
"github.com/kardolus/chatgpt-cli/api"
10
"github.com/kardolus/chatgpt-cli/config"
11
"go.uber.org/zap"
12
"io"
13
"net/http"
14
"os"
15
"strings"
16
)
17
18
// Literals shared by request construction and error reporting.
const (
	// headerContentType is the request header carrying the body's media type.
	headerContentType = "Content-Type"
	// contentType is the media type newRequest sets on every request.
	contentType = "application/json"

	// Wrapping formats for request/transport failures; each expects an error for %w.
	errFailedToCreateRequest = "failed to create request: %w"
	errFailedToMakeRequest   = "failed to make request: %w"
	errFailedToRead          = "failed to read response: %w"

	// Formats for non-2xx responses: errHTTP takes the status code and the
	// server-provided message, errHTTPStatus takes only the status code.
	errHTTP       = "http status %d: %s"
	errHTTPStatus = "http status: %d"
)
27
28
// Caller is the set of HTTP operations implemented by RestCaller.
type Caller interface {
	// Get issues a GET request to url and returns the response body.
	Get(url string) ([]byte, error)

	// Post sends body to url. In RestCaller, stream=true processes the
	// response incrementally and returns the accumulated text.
	Post(url string, body []byte, stream bool) ([]byte, error)

	// PostWithHeaders sends body to url with the supplied request headers.
	PostWithHeaders(url string, body []byte, headers map[string]string) ([]byte, error)
}
33
34
type RestCaller struct {
35
client *http.Client
36
config config.Config
37
}
38
39
// Ensure RestCaller implements Caller interface
40
var _ Caller = &RestCaller{}
41
42
// New builds a RestCaller that sends requests according to cfg.
//
// When cfg.SkipTLSVerify is set, the client uses a clone of
// http.DefaultTransport with certificate verification disabled. Cloning,
// rather than starting from a zero http.Transport, keeps the default
// transport's behaviour: proxy settings from the environment
// (HTTP_PROXY/HTTPS_PROXY/NO_PROXY), dial and TLS-handshake timeouts,
// idle-connection limits, and HTTP/2 attempts. Without cloning, those would
// all be lost silently whenever verification was turned off.
//
// No whole-request Client.Timeout is set. Such a deadline would also cut off
// long-running streamed responses (doRequest's stream path).
func New(cfg config.Config) *RestCaller {
	client := &http.Client{}

	if cfg.SkipTLSVerify {
		var transport *http.Transport
		if def, ok := http.DefaultTransport.(*http.Transport); ok {
			transport = def.Clone()
		} else {
			// DefaultTransport was replaced with a non-*http.Transport;
			// fall back to a plain transport as before.
			transport = &http.Transport{}
		}
		// Verification is disabled only because the configuration asked for it.
		transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
		client.Transport = transport
	}

	return &RestCaller{
		client: client,
		config: cfg,
	}
}
60
61
type CallerFactory func(cfg config.Config) Caller
62
63
func RealCallerFactory(cfg config.Config) Caller {
64
return New(cfg)
65
}
66
67
func (r *RestCaller) Get(url string) ([]byte, error) {
68
return r.doRequest(http.MethodGet, url, nil, false)
69
}
70
71
// Post sends body to url as a POST request, with the headers added by
// newRequest. When stream is true, the response is passed to ProcessResponse,
// which writes text to standard output, and the accumulated text is returned.
// Otherwise the raw response body is returned.
func (r *RestCaller) Post(url string, body []byte, stream bool) ([]byte, error) {
	method := http.MethodPost
	return r.doRequest(method, url, body, stream)
}
74
75
// PostWithHeaders sends body to url as a POST request carrying only the
// supplied headers. Unlike Post, it does not go through newRequest, so neither
// the configured auth header nor Content-Type is added automatically. Callers
// must pass any they need in headers.
//
// For a non-2xx status it returns an error. If the error body is JSON that
// decodes into api.ErrorResponse, the raw body is returned alongside an error
// containing the server's message. If the body cannot be read or decoded, only
// the status code is reported.
//
// A failure reading a successful response is wrapped with errFailedToRead and
// yields no data, matching doRequest.
func (r *RestCaller) PostWithHeaders(url string, body []byte, headers map[string]string) ([]byte, error) {
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(body))
	if err != nil {
		return nil, fmt.Errorf(errFailedToCreateRequest, err)
	}

	for name, value := range headers {
		req.Header.Set(name, value)
	}

	resp, err := r.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf(errFailedToMakeRequest, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		errorResponse, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf(errHTTPStatus, resp.StatusCode)
		}

		var errorData api.ErrorResponse
		if err := json.Unmarshal(errorResponse, &errorData); err != nil {
			return nil, fmt.Errorf(errHTTPStatus, resp.StatusCode)
		}

		return errorResponse, fmt.Errorf(errHTTP, resp.StatusCode, errorData.Error.Message)
	}

	// Previously the read error was returned raw, possibly with partial data.
	// Wrap it and drop the partial body, consistent with doRequest.
	result, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf(errFailedToRead, err)
	}
	return result, nil
}
108
109
// ProcessResponse consumes a streamed response from reader, writes the
// generated text to writer as it is parsed, and returns the accumulated text.
//
// Endpoints whose URL contains the configured ResponsesPath go to the
// event-typed parser, processResponsesSSE. All others go to processLegacy,
// which reads choices[].delta.content chunks.
//
// NOTE: strings.Contains reports true for an empty substring, so an empty
// ResponsesPath routes every endpoint to processResponsesSSE.
func (r *RestCaller) ProcessResponse(reader io.Reader, writer io.Writer, endpoint string) []byte {
	isResponsesEndpoint := strings.Contains(endpoint, r.config.ResponsesPath)
	if !isResponsesEndpoint {
		return r.processLegacy(reader, writer)
	}
	return r.processResponsesSSE(reader, writer)
}
115
116
// processLegacy parses a streamed response made of "data: <json>" lines whose
// JSON carries choices[].delta.content fragments. Each fragment is written to
// writer and appended to the returned slice. A "data: [DONE]" line ends the
// stream and contributes a trailing newline.
//
// Malformed JSON payloads are reported to writer as "Error: ..." and skipped.
// If reading fails, including on a line longer than the 1 MiB scanner limit,
// the error is reported to writer the same way. The text gathered so far is
// still returned.
//
// When debug logging is enabled, raw lines are logged and not parsed.
func (r *RestCaller) processLegacy(reader io.Reader, writer io.Writer) []byte {
	var result []byte
	sugar := zap.S()
	sugar.Debugln("\nResponse\n")

	scanner := bufio.NewScanner(reader)
	// Match processResponsesSSE: accept lines up to 1 MiB instead of the
	// 64 KiB default, so a large chunk does not end the scan early.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	for scanner.Scan() {
		line := scanner.Text()

		if zap.L().Core().Enabled(zap.DebugLevel) {
			sugar.Debugln(line)
			continue
		}

		if !strings.HasPrefix(line, "data:") {
			continue
		}

		// The space after "data:" is optional, and a bare "data:" line is
		// legal. The former fixed line[6:] slice panicked on such short lines
		// and clipped a byte of payload when the space was missing.
		payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if payload == "" {
			continue
		}

		if payload == "[DONE]" {
			_, _ = writer.Write([]byte("\n"))
			result = append(result, '\n')
			break
		}

		var data api.Data
		if err := json.Unmarshal([]byte(payload), &data); err != nil {
			_, _ = fmt.Fprintf(writer, "Error: %s\n", err.Error())
			continue
		}
		for _, choice := range data.Choices {
			if content, ok := choice.Delta["content"].(string); ok {
				_, _ = writer.Write([]byte(content))
				result = append(result, content...)
			}
		}
	}

	// Surface read failures, such as bufio.ErrTooLong or a dropped
	// connection, instead of ending the output silently.
	if err := scanner.Err(); err != nil {
		_, _ = fmt.Fprintf(writer, "Error: %s\n", err.Error())
	}
	return result
}
155
156
// processResponsesSSE parses a server-sent-events stream from reader, writes
// the generated text to writer as it is parsed, and returns everything written.
//
// Two payload shapes are accepted:
//   - While no "event:" line has been seen, "data:" payloads are treated as
//     legacy chunks. choices[].delta.content is emitted, and "[DONE]" ends the
//     stream with a trailing newline.
//   - After an "event:" line (the event name is remembered and never reset),
//     payloads are decoded as typed events. "response.output_text.delta"
//     emits its delta. "response.completed" ends the stream, adding a newline
//     unless the output already ends with one. Other types are ignored.
//
// Lines starting with ":" (SSE comments) and bare "data:" lines are skipped.
// JSON decode failures and read errors, including a line over 1 MiB, are
// reported to writer as "Error: ..."; the text gathered so far is still
// returned.
//
// When debug logging is enabled, raw lines are logged and not parsed.
func (r *RestCaller) processResponsesSSE(reader io.Reader, writer io.Writer) []byte {
	var (
		result   []byte
		curEvent string
		sugar    = zap.S()
	)

	sugar.Debugln("\nResponse\n")

	scanner := bufio.NewScanner(reader)
	// Allow lines up to 1 MiB (the default limit is 64 KiB). Longer lines stop
	// the scan with bufio.ErrTooLong, which is reported after the loop.
	buf := make([]byte, 0, 64*1024)
	scanner.Buffer(buf, 1024*1024)

scan:
	for scanner.Scan() {
		line := scanner.Text()

		if zap.L().Core().Enabled(zap.DebugLevel) {
			sugar.Debugln(line)
			continue
		}

		// SSE comment line.
		if strings.HasPrefix(line, ":") {
			continue
		}

		switch {
		case strings.HasPrefix(line, "event:"):
			curEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:"))

		case strings.HasPrefix(line, "data:"):
			payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			// A bare "data:" line is legal SSE and has nothing to decode;
			// previously it produced a spurious JSON error message.
			if payload == "" {
				continue
			}

			if curEvent == "" {
				// No event type seen: legacy choices[].delta chunk.
				if payload == "[DONE]" {
					_, _ = writer.Write([]byte("\n"))
					result = append(result, '\n')
					break scan
				}
				var legacy struct {
					Choices []struct {
						Delta map[string]any `json:"delta"`
					} `json:"choices"`
				}
				if err := json.Unmarshal([]byte(payload), &legacy); err != nil {
					_, _ = fmt.Fprintf(writer, "Error: %s\n", err.Error())
					continue
				}
				for _, ch := range legacy.Choices {
					if s, ok := ch.Delta["content"].(string); ok && s != "" {
						_, _ = writer.Write([]byte(s))
						result = append(result, s...)
					}
				}
				continue
			}

			var env struct {
				Type     string `json:"type"`
				Delta    string `json:"delta"` // response.output_text.delta
				Text     string `json:"text"`  // response.output_text.done/content_part.done (optional)
				Response struct {
					Status string `json:"status"`
				} `json:"response"`
			}
			if err := json.Unmarshal([]byte(payload), &env); err != nil {
				_, _ = fmt.Fprintf(writer, "Error: %s\n", err.Error())
				continue
			}

			switch env.Type {
			case "response.output_text.delta":
				if env.Delta != "" {
					_, _ = writer.Write([]byte(env.Delta))
					result = append(result, env.Delta...)
				}
			case "response.completed":
				// HasSuffix is false for an empty result, so an empty output
				// still gets its newline.
				if !bytes.HasSuffix(result, []byte("\n")) {
					_, _ = writer.Write([]byte("\n"))
					result = append(result, '\n')
				}
				break scan
			default:
				// ignore other SSE types
			}
		}
	}

	// Report read failures instead of ending the output silently.
	if err := scanner.Err(); err != nil {
		_, _ = fmt.Fprintf(writer, "Error: %s\n", err.Error())
	}
	return result
}
251
// doRequest builds a request with newRequest, sends it with r.client, and
// interprets the response:
//   - request construction and transport failures are wrapped with context;
//   - a non-2xx status yields an error (see below);
//   - with stream set, the body is handed to ProcessResponse, which writes to
//     os.Stdout, and the accumulated text is returned;
//   - otherwise the complete body is returned.
//
// For a non-2xx status, a body that decodes into api.ErrorResponse is
// returned together with an error carrying the server's message. If the body
// cannot be read or decoded, only the status code is reported.
func (r *RestCaller) doRequest(method, url string, body []byte, stream bool) ([]byte, error) {
	req, buildErr := r.newRequest(method, url, body)
	if buildErr != nil {
		return nil, fmt.Errorf(errFailedToCreateRequest, buildErr)
	}

	resp, sendErr := r.client.Do(req)
	if sendErr != nil {
		return nil, fmt.Errorf(errFailedToMakeRequest, sendErr)
	}
	defer resp.Body.Close()

	status := resp.StatusCode
	isSuccess := status >= 200 && status < 300
	if !isSuccess {
		raw, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return nil, fmt.Errorf(errHTTPStatus, status)
		}

		var apiErr api.ErrorResponse
		if decodeErr := json.Unmarshal(raw, &apiErr); decodeErr != nil {
			return nil, fmt.Errorf(errHTTPStatus, status)
		}
		return raw, fmt.Errorf(errHTTP, status, apiErr.Error.Message)
	}

	if stream {
		return r.ProcessResponse(resp.Body, os.Stdout, url), nil
	}

	payload, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, fmt.Errorf(errFailedToRead, readErr)
	}
	return payload, nil
}
288
289
// newRequest creates an HTTP request for method and url with body as its
// payload. When an API key is configured, it sets the configured auth header
// to AuthTokenPrefix+APIKey. It then always sets Content-Type to
// application/json. The order matters only if AuthHeader is itself
// "Content-Type", in which case the JSON media type wins.
func (r *RestCaller) newRequest(method, url string, body []byte) (*http.Request, error) {
	req, err := http.NewRequest(method, url, bytes.NewBuffer(body))
	if err != nil {
		return nil, err
	}

	cfg := r.config
	if key := cfg.APIKey; key != "" {
		req.Header.Set(cfg.AuthHeader, cfg.AuthTokenPrefix+key)
	}
	req.Header.Set(headerContentType, contentType)

	return req, nil
}
302
303