Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/api/client/media.go
3431 views
1
package client
2
3
import (
4
"bytes"
5
"context"
6
"encoding/base64"
7
"encoding/json"
8
"fmt"
9
"github.com/kardolus/chatgpt-cli/api"
10
"github.com/kardolus/chatgpt-cli/history"
11
"github.com/kardolus/chatgpt-cli/internal"
12
"io"
13
"mime/multipart"
14
stdhttp "net/http"
15
"net/textproto"
16
"net/url"
17
"path/filepath"
18
"strings"
19
)
20
21
// Content-part identifiers and templates used when embedding media in chat messages.
const (
	// audioType is the content-part type for base64-encoded audio input.
	audioType = "input_audio"
	// imageURLType is the content-part type for image references (remote URL or data URL).
	imageURLType = "image_url"
	// imageContent is the data-URL template; it takes the MIME type and the base64 payload.
	imageContent = "data:%s;base64,%s"
)

// URL schemes recognized by isValidURL.
const (
	httpScheme  = "http"
	httpsScheme = "https"
)
28
29
// EditImage edits an input image using a text prompt and writes the modified image to the specified output path.
30
//
31
// This method sends a multipart/form-data POST request to the image editing endpoint
32
// (typically OpenAI's /v1/images/edits). The request includes:
33
// - The image file to edit.
34
// - A text prompt describing how the image should be modified.
35
// - The model ID (e.g., gpt-image-1).
36
//
37
// The response is expected to contain a base64-encoded image, which is decoded and written to the outputPath.
38
//
39
// Parameters:
40
// - inputText: A text prompt describing the desired modifications to the image.
41
// - inputPath: The file path to the source image (must be a supported format: PNG, JPEG, or WebP).
42
// - outputPath: The file path where the edited image will be saved.
43
//
44
// Returns:
45
// - An error if any step of the process fails: reading the file, building the request, sending it,
46
// decoding the response, or writing the output image.
47
//
48
// Example:
49
//
50
// err := client.EditImage("Add a rainbow in the sky", "input.png", "output.png")
51
// if err != nil {
52
// log.Fatal(err)
53
// }
54
func (c *Client) EditImage(inputText, inputPath, outputPath string) error {
55
endpoint := c.getEndpoint(c.Config.ImageEditsPath)
56
57
file, err := c.reader.Open(inputPath)
58
if err != nil {
59
return fmt.Errorf("failed to open input image: %w", err)
60
}
61
defer file.Close()
62
63
var buf bytes.Buffer
64
writer := multipart.NewWriter(&buf)
65
66
mimeType, err := c.getMimeTypeFromFileContent(inputPath)
67
if err != nil {
68
return fmt.Errorf("failed to detect MIME type: %w", err)
69
}
70
if !strings.HasPrefix(mimeType, "image/") {
71
return fmt.Errorf("unsupported MIME type: %s", mimeType)
72
}
73
74
header := make(textproto.MIMEHeader)
75
header.Set("Content-Disposition", fmt.Sprintf(`form-data; name="image"; filename="%s"`, filepath.Base(inputPath)))
76
header.Set("Content-Type", mimeType)
77
78
part, err := writer.CreatePart(header)
79
if err != nil {
80
return fmt.Errorf("failed to create image part: %w", err)
81
}
82
if _, err := io.Copy(part, file); err != nil {
83
return fmt.Errorf("failed to copy image data: %w", err)
84
}
85
86
if err := writer.WriteField("prompt", inputText); err != nil {
87
return fmt.Errorf("failed to add prompt: %w", err)
88
}
89
if err := writer.WriteField("model", c.Config.Model); err != nil {
90
return fmt.Errorf("failed to add model: %w", err)
91
}
92
93
if err := writer.Close(); err != nil {
94
return fmt.Errorf("failed to close multipart writer: %w", err)
95
}
96
97
c.printRequestDebugInfo(endpoint, buf.Bytes(), map[string]string{
98
"Content-Type": writer.FormDataContentType(),
99
})
100
101
respBytes, err := c.Caller.PostWithHeaders(endpoint, buf.Bytes(), map[string]string{
102
c.Config.AuthHeader: fmt.Sprintf("%s %s", c.Config.AuthTokenPrefix, c.Config.APIKey),
103
internal.HeaderContentTypeKey: writer.FormDataContentType(),
104
})
105
if err != nil {
106
return fmt.Errorf("failed to edit image: %w", err)
107
}
108
109
// Parse the JSON and extract b64_json
110
var response struct {
111
Data []struct {
112
B64 string `json:"b64_json"`
113
} `json:"data"`
114
}
115
if err := json.Unmarshal(respBytes, &response); err != nil {
116
return fmt.Errorf("failed to decode response: %w", err)
117
}
118
if len(response.Data) == 0 {
119
return fmt.Errorf("no image data returned")
120
}
121
122
imgBytes, err := base64.StdEncoding.DecodeString(response.Data[0].B64)
123
if err != nil {
124
return fmt.Errorf("failed to decode base64 image: %w", err)
125
}
126
127
outFile, err := c.writer.Create(outputPath)
128
if err != nil {
129
return fmt.Errorf("failed to create output file: %w", err)
130
}
131
defer outFile.Close()
132
133
if err := c.writer.Write(outFile, imgBytes); err != nil {
134
return fmt.Errorf("failed to write image: %w", err)
135
}
136
137
c.printResponseDebugInfo([]byte(fmt.Sprintf("[image] %d bytes written to %s", len(imgBytes), outputPath)))
138
return nil
139
}
140
141
// GenerateImage sends a prompt to the configured image generation model (e.g., gpt-image-1)
142
// and writes the resulting image to the specified output path.
143
//
144
// The method performs the following steps:
145
// 1. Sends a POST request to the image generation endpoint with the provided prompt.
146
// 2. Parses the response and extracts the base64-encoded image data.
147
// 3. Decodes the image bytes and writes them to the given outputPath.
148
// 4. Logs the number of bytes written using debug output.
149
//
150
// Parameters:
151
// - inputText: The prompt describing the image to be generated.
152
// - outputPath: The file path where the generated image (e.g., .png) will be saved.
153
//
154
// Returns:
155
// - An error if any part of the request, decoding, or file writing fails.
156
func (c *Client) GenerateImage(inputText, outputPath string) error {
157
req := api.Draw{
158
Model: c.Config.Model,
159
Prompt: inputText,
160
}
161
162
return c.postAndWriteBinaryOutput(
163
c.getEndpoint(c.Config.ImageGenerationsPath),
164
req,
165
outputPath,
166
"image",
167
func(respBytes []byte) ([]byte, error) {
168
var response struct {
169
Data []struct {
170
B64 string `json:"b64_json"`
171
} `json:"data"`
172
}
173
if err := json.Unmarshal(respBytes, &response); err != nil {
174
return nil, fmt.Errorf("failed to decode response: %w", err)
175
}
176
if len(response.Data) == 0 {
177
return nil, fmt.Errorf("no image data returned")
178
}
179
decoded, err := base64.StdEncoding.DecodeString(response.Data[0].B64)
180
if err != nil {
181
return nil, fmt.Errorf("failed to decode base64 image: %w", err)
182
}
183
return decoded, nil
184
},
185
)
186
}
187
188
// SynthesizeSpeech converts the given input text into speech using the configured TTS model,
189
// and writes the resulting audio to the specified output file.
190
//
191
// The audio format is inferred from the output file's extension (e.g., "mp3", "wav") and sent
192
// as the "response_format" in the request to the OpenAI speech synthesis endpoint.
193
//
194
// Parameters:
195
// - inputText: The text to synthesize into speech.
196
// - outputPath: The path to the output audio file. The file extension determines the response format.
197
//
198
// Returns an error if the request fails, the response cannot be written, or the file cannot be created.
199
func (c *Client) SynthesizeSpeech(inputText, outputPath string) error {
200
req := api.Speech{
201
Model: c.Config.Model,
202
Voice: c.Config.Voice,
203
Input: inputText,
204
ResponseFormat: getExtension(outputPath),
205
}
206
return c.postAndWriteBinaryOutput(c.getEndpoint(c.Config.SpeechPath), req, outputPath, "binary", nil)
207
}
208
209
// Transcribe uploads an audio file to the OpenAI transcription endpoint and returns the transcribed text.
210
//
211
// It reads the audio file from the provided `audioPath`, creates a multipart/form-data request with the model name
212
// and the audio file, and sends it to the endpoint defined by the `TranscriptionsPath` in the client config.
213
// The method expects a JSON response containing a "text" field with the transcription result.
214
//
215
// Parameters:
216
// - audioPath: The local file path to the audio file to be transcribed.
217
//
218
// Returns:
219
// - string: The transcribed text from the audio file.
220
// - error: An error if the file can't be read, the request fails, or the response is invalid.
221
//
222
// This method supports formats like mp3, mp4, mpeg, mpga, m4a, wav, and webm, depending on API compatibility.
223
func (c *Client) Transcribe(audioPath string) (string, error) {
224
c.initHistory()
225
226
file, err := c.reader.Open(audioPath)
227
if err != nil {
228
return "", fmt.Errorf("failed to open audio file: %w", err)
229
}
230
defer file.Close()
231
232
var buf bytes.Buffer
233
writer := multipart.NewWriter(&buf)
234
235
_ = writer.WriteField("model", c.Config.Model)
236
237
part, err := writer.CreateFormFile("file", filepath.Base(audioPath))
238
if err != nil {
239
return "", err
240
}
241
if _, err := io.Copy(part, file); err != nil {
242
return "", err
243
}
244
245
if err := writer.Close(); err != nil {
246
return "", err
247
}
248
249
endpoint := c.getEndpoint(c.Config.TranscriptionsPath)
250
headers := map[string]string{
251
internal.HeaderContentTypeKey: writer.FormDataContentType(),
252
c.Config.AuthHeader: fmt.Sprintf("%s %s", c.Config.AuthTokenPrefix, c.Config.APIKey),
253
}
254
255
c.printRequestDebugInfo(endpoint, buf.Bytes(), headers)
256
257
raw, err := c.Caller.PostWithHeaders(endpoint, buf.Bytes(), headers)
258
if err != nil {
259
return "", err
260
}
261
262
c.printResponseDebugInfo(raw)
263
264
var res struct {
265
Text string `json:"text"`
266
}
267
if err := json.Unmarshal(raw, &res); err != nil {
268
return "", fmt.Errorf("failed to parse transcription: %w", err)
269
}
270
271
c.History = append(c.History, history.History{
272
Message: api.Message{
273
Role: UserRole,
274
Content: fmt.Sprintf("[transcribe] %s", filepath.Base(audioPath)),
275
},
276
Timestamp: c.timer.Now(),
277
})
278
279
c.History = append(c.History, history.History{
280
Message: api.Message{
281
Role: AssistantRole,
282
Content: res.Text,
283
},
284
Timestamp: c.timer.Now(),
285
})
286
287
c.truncateHistory()
288
289
if !c.Config.OmitHistory {
290
_ = c.historyStore.Write(c.History)
291
}
292
293
return res.Text, nil
294
}
295
296
func (c *Client) appendMediaMessages(ctx context.Context, messages []api.Message) ([]api.Message, error) {
297
if data, ok := ctx.Value(internal.BinaryDataKey).([]byte); ok {
298
content, err := c.createImageContentFromBinary(data)
299
if err != nil {
300
return nil, err
301
}
302
messages = append(messages, api.Message{
303
Role: UserRole,
304
Content: []api.ImageContent{content},
305
})
306
} else if path, ok := ctx.Value(internal.ImagePathKey).(string); ok {
307
content, err := c.createImageContentFromURLOrFile(path)
308
if err != nil {
309
return nil, err
310
}
311
messages = append(messages, api.Message{
312
Role: UserRole,
313
Content: []api.ImageContent{content},
314
})
315
} else if path, ok := ctx.Value(internal.AudioPathKey).(string); ok {
316
content, err := c.createAudioContentFromFile(path)
317
if err != nil {
318
return nil, err
319
}
320
messages = append(messages, api.Message{
321
Role: UserRole,
322
Content: []api.AudioContent{content},
323
})
324
}
325
return messages, nil
326
}
327
328
// base64Encode reads the whole file at path through c.reader and returns its contents
// encoded with standard, padded base64. Read errors are returned unchanged.
func (c *Client) base64Encode(path string) (string, error) {
	raw, err := c.reader.ReadFile(path)
	if err != nil {
		return "", err
	}
	encoded := base64.StdEncoding.EncodeToString(raw)
	return encoded, nil
}
336
337
func (c *Client) createAudioContentFromFile(audio string) (api.AudioContent, error) {
338
339
format, err := c.detectAudioFormat(audio)
340
if err != nil {
341
return api.AudioContent{}, err
342
}
343
344
encodedAudio, err := c.base64Encode(audio)
345
if err != nil {
346
return api.AudioContent{}, err
347
}
348
349
return api.AudioContent{
350
Type: audioType,
351
InputAudio: api.InputAudio{
352
Data: encodedAudio,
353
Format: format,
354
},
355
}, nil
356
}
357
358
func (c *Client) createImageContentFromBinary(binary []byte) (api.ImageContent, error) {
359
mime, err := getMimeTypeFromBytes(binary)
360
if err != nil {
361
return api.ImageContent{}, err
362
}
363
364
encoded := base64.StdEncoding.EncodeToString(binary)
365
content := api.ImageContent{
366
Type: imageURLType,
367
ImageURL: struct {
368
URL string `json:"url"`
369
}{
370
URL: fmt.Sprintf(imageContent, mime, encoded),
371
},
372
}
373
374
return content, nil
375
}
376
377
func (c *Client) createImageContentFromURLOrFile(image string) (api.ImageContent, error) {
378
var content api.ImageContent
379
380
if isValidURL(image) {
381
content = api.ImageContent{
382
Type: imageURLType,
383
ImageURL: struct {
384
URL string `json:"url"`
385
}{
386
URL: image,
387
},
388
}
389
} else {
390
mime, err := c.getMimeTypeFromFileContent(image)
391
if err != nil {
392
return content, err
393
}
394
395
encodedImage, err := c.base64Encode(image)
396
if err != nil {
397
return content, err
398
}
399
400
content = api.ImageContent{
401
Type: imageURLType,
402
ImageURL: struct {
403
URL string `json:"url"`
404
}{
405
URL: fmt.Sprintf(imageContent, mime, encodedImage),
406
},
407
}
408
}
409
410
return content, nil
411
}
412
413
// detectAudioFormat sniffs the audio container format of the file at path from the leading
// bytes returned by c.reader.ReadBufferFromFile.
//
// It returns one of "wav", "mp3", "flac", "ogg", "m4a", "mp4", or "unknown" when no
// signature matches. Errors come only from opening or reading the file; a short or empty
// file yields "unknown" rather than panicking.
func (c *Client) detectAudioFormat(path string) (string, error) {
	file, err := c.reader.Open(path)
	if err != nil {
		return "", err
	}
	defer file.Close()

	buf, err := c.reader.ReadBufferFromFile(file)
	if err != nil {
		return "", err
	}

	return sniffAudioFormat(buf), nil
}

// sniffAudioFormat matches buf against known audio file signatures. Every comparison is
// bounds-checked against len(buf), so it never reads beyond the bytes actually present.
func sniffAudioFormat(buf []byte) string {
	// hasAt reports whether sig appears in buf starting at offset off.
	hasAt := func(off int, sig string) bool {
		return len(buf) >= off+len(sig) && string(buf[off:off+len(sig)]) == sig
	}

	switch {
	case hasAt(0, "RIFF") && hasAt(8, "WAVE"):
		return "wav"
	// MP3: an ID3 tag, or frame sync bits (0xFF followed by the top three bits set).
	case hasAt(0, "ID3") || (len(buf) >= 2 && buf[0] == 0xFF && buf[1]&0xE0 == 0xE0):
		return "mp3"
	case hasAt(0, "fLaC"):
		return "flac"
	case hasAt(0, "OggS"):
		return "ogg"
	case hasAt(4, "ftyp"):
		// NOTE(review): "isom" and "mp42" are generic MP4 brands; they are mapped to m4a
		// here, presumably because audio input is expected. Confirm this is intended.
		if hasAt(8, "M4A ") || hasAt(8, "isom") || hasAt(8, "mp42") {
			return "m4a"
		}
		return "mp4"
	}
	return "unknown"
}
455
456
// getMimeTypeFromFileContent sniffs the MIME type of the file at path from the leading
// bytes returned by c.reader.ReadBufferFromFile, using net/http's content-type detection.
// That detection falls back to "application/octet-stream" when nothing more specific matches.
// Open and read errors are returned unchanged.
func (c *Client) getMimeTypeFromFileContent(path string) (string, error) {
	f, err := c.reader.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close()

	head, err := c.reader.ReadBufferFromFile(f)
	if err != nil {
		return "", err
	}
	return stdhttp.DetectContentType(head), nil
}
472
473
// postAndWriteBinaryOutput marshals requestBody to JSON and POSTs it to endpoint through
// c.Caller. It optionally post-processes the raw response with transform, then writes the
// resulting bytes to outputPath via c.writer.
//
// A nil transform writes the response body unchanged. debugLabel names the payload in the
// write-error message and in the final debug line ("[<label>] N bytes written to <path>").
// Errors from transform are returned unwrapped; all other failures are wrapped with context.
func (c *Client) postAndWriteBinaryOutput(endpoint string, requestBody interface{}, outputPath, debugLabel string, transform func([]byte) ([]byte, error)) error {
	payload, err := json.Marshal(requestBody)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	c.printRequestDebugInfo(endpoint, payload, nil)

	raw, err := c.Caller.Post(endpoint, payload, false)
	if err != nil {
		return fmt.Errorf("API request failed: %w", err)
	}

	output := raw
	if transform != nil {
		if output, err = transform(raw); err != nil {
			return err
		}
	}

	dst, err := c.writer.Create(outputPath)
	if err != nil {
		return fmt.Errorf("failed to create output file: %w", err)
	}
	defer dst.Close()

	if err := c.writer.Write(dst, output); err != nil {
		return fmt.Errorf("failed to write %s: %w", debugLabel, err)
	}

	summary := fmt.Sprintf("[%s] %d bytes written to %s", debugLabel, len(output), outputPath)
	c.printResponseDebugInfo([]byte(summary))
	return nil
}
506
507
// getExtension returns the extension of path without its leading dot
// (e.g. "speech.mp3" -> "mp3"). It returns "" when path has no extension
// or ends in a bare dot.
func getExtension(path string) string {
	return strings.TrimPrefix(filepath.Ext(path), ".")
}
514
515
// getMimeTypeFromBytes sniffs the MIME type of data using net/http's content-type
// detection, which considers at most the first 512 bytes. The error result is
// currently always nil.
func getMimeTypeFromBytes(data []byte) (string, error) {
	return stdhttp.DetectContentType(data), nil
}
520
521
func isValidURL(input string) bool {
522
parsedURL, err := url.ParseRequestURI(input)
523
if err != nil {
524
return false
525
}
526
527
// Ensure that the URL has a valid scheme
528
schemes := []string{httpScheme, httpsScheme}
529
for _, scheme := range schemes {
530
if strings.HasPrefix(parsedURL.Scheme, scheme) {
531
return true
532
}
533
}
534
535
return false
536
}
537
538