Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/api/client/media_test.go
3431 views
1
// media_test.go
2
package client_test
3
4
import (
5
"bytes"
6
"encoding/base64"
7
"encoding/json"
8
"errors"
9
"fmt"
10
"io"
11
"os"
12
"testing"
13
"time"
14
15
"github.com/golang/mock/gomock"
16
"github.com/kardolus/chatgpt-cli/api"
17
"github.com/kardolus/chatgpt-cli/api/client"
18
"github.com/kardolus/chatgpt-cli/history"
19
20
. "github.com/onsi/gomega"
21
"github.com/sclevine/spec"
22
)
23
24
func testMedia(t *testing.T, when spec.G, it spec.S) {
25
when("Media()", func() {
26
when("SynthesizeSpeech()", func() {
27
const (
28
inputText = "mock-input"
29
outputFile = "mock-output"
30
outputFileType = "mp3"
31
errorText = "mock error occurred"
32
)
33
34
var (
35
subject *client.Client
36
fileName = outputFile + "." + outputFileType
37
body []byte
38
response []byte
39
)
40
41
it.Before(func() {
42
subject = factory.buildClientWithoutConfig()
43
request := api.Speech{
44
Model: subject.Config.Model,
45
Voice: subject.Config.Voice,
46
Input: inputText,
47
ResponseFormat: outputFileType,
48
}
49
var err error
50
body, err = json.Marshal(request)
51
Expect(err).NotTo(HaveOccurred())
52
53
response = []byte("mock response")
54
})
55
56
it("throws an error when the http call fails", func() {
57
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.SpeechPath, body, false).
58
Return(nil, errors.New(errorText))
59
60
err := subject.SynthesizeSpeech(inputText, fileName)
61
Expect(err).To(HaveOccurred())
62
Expect(err.Error()).To(ContainSubstring(errorText))
63
})
64
65
it("throws an error when a file cannot be created", func() {
66
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.SpeechPath, body, false).
67
Return(response, nil)
68
mockWriter.EXPECT().Create(fileName).Return(nil, errors.New(errorText))
69
70
err := subject.SynthesizeSpeech(inputText, fileName)
71
Expect(err).To(HaveOccurred())
72
Expect(err.Error()).To(ContainSubstring(errorText))
73
})
74
75
it("throws an error when bytes cannot be written to the output file", func() {
76
file, err := os.Open(os.DevNull)
77
Expect(err).NotTo(HaveOccurred())
78
defer file.Close()
79
80
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.SpeechPath, body, false).
81
Return(response, nil)
82
mockWriter.EXPECT().Create(fileName).Return(file, nil)
83
mockWriter.EXPECT().Write(file, response).Return(errors.New(errorText))
84
85
err = subject.SynthesizeSpeech(inputText, fileName)
86
Expect(err).To(HaveOccurred())
87
Expect(err.Error()).To(ContainSubstring(errorText))
88
})
89
90
it("succeeds when no errors occurred", func() {
91
file, err := os.Open(os.DevNull)
92
Expect(err).NotTo(HaveOccurred())
93
defer file.Close()
94
95
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.SpeechPath, body, false).
96
Return(response, nil)
97
mockWriter.EXPECT().Create(fileName).Return(file, nil)
98
mockWriter.EXPECT().Write(file, response).Return(nil)
99
100
err = subject.SynthesizeSpeech(inputText, fileName)
101
Expect(err).NotTo(HaveOccurred())
102
})
103
})
104
105
when("GenerateImage()", func() {
106
const (
107
inputText = "draw a happy dog"
108
outputFile = "dog.png"
109
errorText = "mock error occurred"
110
)
111
112
var (
113
subject *client.Client
114
body []byte
115
)
116
117
it.Before(func() {
118
subject = factory.buildClientWithoutConfig()
119
request := api.Draw{
120
Model: subject.Config.Model,
121
Prompt: inputText,
122
}
123
var err error
124
body, err = json.Marshal(request)
125
Expect(err).NotTo(HaveOccurred())
126
})
127
128
it("throws an error when the http call fails", func() {
129
mockCaller.EXPECT().
130
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
131
Return(nil, errors.New(errorText))
132
133
err := subject.GenerateImage(inputText, outputFile)
134
Expect(err).To(HaveOccurred())
135
Expect(err.Error()).To(ContainSubstring(errorText))
136
})
137
138
it("throws an error when no image data is returned", func() {
139
mockCaller.EXPECT().
140
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
141
Return([]byte(`{"data":[]}`), nil)
142
143
err := subject.GenerateImage(inputText, outputFile)
144
Expect(err).To(HaveOccurred())
145
Expect(err.Error()).To(ContainSubstring("no image data returned"))
146
})
147
148
it("throws an error when base64 is invalid", func() {
149
mockCaller.EXPECT().
150
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
151
Return([]byte(`{"data":[{"b64_json":"!!notbase64!!"}]}`), nil)
152
153
err := subject.GenerateImage(inputText, outputFile)
154
Expect(err).To(HaveOccurred())
155
Expect(err.Error()).To(ContainSubstring("failed to decode base64 image"))
156
})
157
158
it("throws an error when a file cannot be created", func() {
159
valid := base64.StdEncoding.EncodeToString([]byte("image-bytes"))
160
161
mockCaller.EXPECT().
162
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
163
Return([]byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, valid)), nil)
164
165
mockWriter.EXPECT().Create(outputFile).Return(nil, errors.New(errorText))
166
167
err := subject.GenerateImage(inputText, outputFile)
168
Expect(err).To(HaveOccurred())
169
Expect(err.Error()).To(ContainSubstring(errorText))
170
})
171
172
it("throws an error when bytes cannot be written to the file", func() {
173
valid := base64.StdEncoding.EncodeToString([]byte("image-bytes"))
174
file, err := os.Open(os.DevNull)
175
Expect(err).NotTo(HaveOccurred())
176
defer file.Close()
177
178
mockCaller.EXPECT().
179
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
180
Return([]byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, valid)), nil)
181
182
mockWriter.EXPECT().Create(outputFile).Return(file, nil)
183
mockWriter.EXPECT().Write(file, []byte("image-bytes")).Return(errors.New(errorText))
184
185
err = subject.GenerateImage(inputText, outputFile)
186
Expect(err).To(HaveOccurred())
187
Expect(err.Error()).To(ContainSubstring(errorText))
188
})
189
190
it("succeeds when all steps complete", func() {
191
valid := base64.StdEncoding.EncodeToString([]byte("image-bytes"))
192
file, err := os.Open(os.DevNull)
193
Expect(err).NotTo(HaveOccurred())
194
defer file.Close()
195
196
mockCaller.EXPECT().
197
Post(subject.Config.URL+subject.Config.ImageGenerationsPath, body, false).
198
Return([]byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, valid)), nil)
199
200
mockWriter.EXPECT().Create(outputFile).Return(file, nil)
201
mockWriter.EXPECT().Write(file, []byte("image-bytes")).Return(nil)
202
203
err = subject.GenerateImage(inputText, outputFile)
204
Expect(err).NotTo(HaveOccurred())
205
})
206
})
207
208
when("EditImage()", func() {
209
const (
210
inputText = "give the dog sunglasses"
211
inputFile = "dog.png"
212
outputFile = "dog_cool.png"
213
errorText = "mock error occurred"
214
)
215
216
var (
217
subject *client.Client
218
validB64 string
219
imageBytes = []byte("image-bytes")
220
respBytes []byte
221
)
222
223
it.Before(func() {
224
subject = factory.buildClientWithoutConfig()
225
validB64 = base64.StdEncoding.EncodeToString(imageBytes)
226
respBytes = []byte(fmt.Sprintf(`{"data":[{"b64_json":"%s"}]}`, validB64))
227
})
228
229
it("returns error when input file can't be opened", func() {
230
mockReader.EXPECT().Open(inputFile).Return(nil, errors.New(errorText))
231
232
err := subject.EditImage(inputText, inputFile, outputFile)
233
Expect(err).To(HaveOccurred())
234
Expect(err.Error()).To(ContainSubstring("failed to open input image"))
235
})
236
237
it("returns error on invalid mime type", func() {
238
file := openDummy()
239
mockReader.EXPECT().Open(inputFile).Return(file, nil).Times(2)
240
mockReader.EXPECT().ReadBufferFromFile(file).Return([]byte("not an image"), nil)
241
242
err := subject.EditImage(inputText, inputFile, outputFile)
243
Expect(err).To(HaveOccurred())
244
Expect(err.Error()).To(ContainSubstring("unsupported MIME type"))
245
})
246
247
it("returns error when HTTP call fails", func() {
248
mockReader.EXPECT().Open(inputFile).DoAndReturn(func(string) (*os.File, error) {
249
return openDummy(), nil
250
}).Times(2)
251
252
mockReader.EXPECT().
253
ReadBufferFromFile(gomock.AssignableToTypeOf(&os.File{})).
254
Return([]byte("\x89PNG\r\n\x1a\n"), nil)
255
256
mockCaller.EXPECT().
257
PostWithHeaders(gomock.Any(), gomock.Any(), gomock.Any()).
258
Return(nil, errors.New(errorText))
259
260
err := subject.EditImage(inputText, inputFile, outputFile)
261
Expect(err).To(HaveOccurred())
262
Expect(err.Error()).To(ContainSubstring("failed to edit image"))
263
})
264
265
it("returns error when base64 is invalid", func() {
266
invalidResp := []byte(`{"data":[{"b64_json":"!notbase64"}]}`)
267
268
mockReader.EXPECT().Open(inputFile).DoAndReturn(func(string) (*os.File, error) {
269
return openDummy(), nil
270
}).Times(2)
271
272
mockReader.EXPECT().
273
ReadBufferFromFile(gomock.AssignableToTypeOf(&os.File{})).
274
Return([]byte("\x89PNG\r\n\x1a\n"), nil)
275
276
mockCaller.EXPECT().
277
PostWithHeaders(gomock.Any(), gomock.Any(), gomock.Any()).
278
Return(invalidResp, nil)
279
280
err := subject.EditImage(inputText, inputFile, outputFile)
281
Expect(err).To(HaveOccurred())
282
Expect(err.Error()).To(ContainSubstring("failed to decode base64 image"))
283
})
284
285
it("writes image when all steps succeed", func() {
286
file := openDummy()
287
mockReader.EXPECT().Open(inputFile).DoAndReturn(func(string) (*os.File, error) {
288
return openDummy(), nil
289
}).Times(2)
290
291
mockReader.EXPECT().
292
ReadBufferFromFile(gomock.AssignableToTypeOf(&os.File{})).
293
Return([]byte("\x89PNG\r\n\x1a\n"), nil)
294
295
mockCaller.EXPECT().
296
PostWithHeaders(gomock.Any(), gomock.Any(), gomock.Any()).
297
Return(respBytes, nil)
298
299
mockWriter.EXPECT().Create(outputFile).Return(file, nil)
300
mockWriter.EXPECT().Write(file, imageBytes).Return(nil)
301
302
err := subject.EditImage(inputText, inputFile, outputFile)
303
Expect(err).NotTo(HaveOccurred())
304
})
305
})
306
307
when("Transcribe()", func() {
308
const audioPath = "path/to/audio.wav"
309
const transcribedText = "Hello, this is a test."
310
311
it("returns an error if the audio file cannot be opened", func() {
312
subject := factory.buildClientWithoutConfig()
313
314
mockHistoryStore.EXPECT().Read().Return(nil, nil)
315
mockTimer.EXPECT().Now().Times(1)
316
317
mockReader.EXPECT().Open(audioPath).Return(nil, errors.New("cannot open"))
318
319
_, err := subject.Transcribe(audioPath)
320
Expect(err).To(HaveOccurred())
321
Expect(err.Error()).To(ContainSubstring("cannot open"))
322
})
323
324
it("returns an error if copying audio content fails", func() {
325
subject := factory.buildClientWithoutConfig()
326
327
mockHistoryStore.EXPECT().Read().Return(nil, nil)
328
mockTimer.EXPECT().Now().Times(1)
329
330
reader, writer, err := os.Pipe()
331
Expect(err).NotTo(HaveOccurred())
332
_ = writer.Close() // force EOF/copy failure behavior
333
334
mockReader.EXPECT().Open(audioPath).Return(reader, nil)
335
336
mockCaller.EXPECT().
337
PostWithHeaders(subject.Config.URL+subject.Config.TranscriptionsPath, gomock.Any(), gomock.Any())
338
339
_, err = subject.Transcribe(audioPath)
340
Expect(err).To(HaveOccurred())
341
Expect(err.Error()).To(ContainSubstring("failed"))
342
})
343
344
it("returns an error if the API call fails", func() {
345
subject := factory.buildClientWithoutConfig()
346
347
mockHistoryStore.EXPECT().Read().Return(nil, nil)
348
mockTimer.EXPECT().Now().Times(1)
349
350
file, err := os.Open(os.DevNull)
351
Expect(err).NotTo(HaveOccurred())
352
defer file.Close()
353
354
mockReader.EXPECT().Open(audioPath).Return(file, nil)
355
356
mockCaller.EXPECT().
357
PostWithHeaders(subject.Config.URL+subject.Config.TranscriptionsPath, gomock.Any(), gomock.Any()).
358
Return(nil, errors.New("network error"))
359
360
_, err = subject.Transcribe(audioPath)
361
Expect(err).To(HaveOccurred())
362
Expect(err.Error()).To(ContainSubstring("network error"))
363
})
364
365
it("returns the transcribed text when successful", func() {
366
subject := factory.buildClientWithoutConfig()
367
368
mockHistoryStore.EXPECT().Read().Return(nil, nil)
369
370
now := time.Now()
371
mockTimer.EXPECT().Now().Return(now).Times(3)
372
373
file, err := os.Open(os.DevNull)
374
Expect(err).NotTo(HaveOccurred())
375
defer file.Close()
376
377
mockReader.EXPECT().Open(audioPath).Return(file, nil)
378
379
resp := []byte(`{"text": "Hello, this is a test."}`)
380
mockCaller.EXPECT().
381
PostWithHeaders(subject.Config.URL+subject.Config.TranscriptionsPath, gomock.Any(), gomock.Any()).
382
Return(resp, nil)
383
384
expectedHistory := []history.History{
385
{
386
Message: api.Message{
387
Role: client.SystemRole,
388
Content: subject.Config.Role,
389
},
390
Timestamp: now,
391
},
392
{
393
Message: api.Message{
394
Role: client.UserRole,
395
Content: "[transcribe] audio.wav",
396
},
397
Timestamp: now,
398
},
399
{
400
Message: api.Message{
401
Role: client.AssistantRole,
402
Content: transcribedText,
403
},
404
Timestamp: now,
405
},
406
}
407
408
mockHistoryStore.EXPECT().Write(expectedHistory)
409
410
text, err := subject.Transcribe(audioPath)
411
Expect(err).NotTo(HaveOccurred())
412
Expect(text).To(Equal(transcribedText))
413
})
414
})
415
})
416
}
417
418
// openDummy returns the read end of an in-memory pipe that yields the 8-byte
// PNG signature followed by EOF, giving tests a real *os.File without touching
// the disk. The caller owns the returned file and should close it. A background
// goroutine writes the signature and then closes the write end itself.
// It panics if the pipe cannot be created: returning a nil *os.File would only
// surface later as a confusing failure inside whichever test used it.
func openDummy() *os.File {
	const pngSignature = "\x89PNG\r\n\x1a\n"

	r, w, err := os.Pipe()
	if err != nil {
		panic("openDummy: os.Pipe failed: " + err.Error())
	}

	// Write from a goroutine so this helper never blocks on the pipe, whether
	// or not anyone ever reads from r.
	go func() {
		// Best-effort: if the read end is closed early the write fails, which
		// is harmless here — the helper only needs to provide a handle.
		_, _ = io.Copy(w, bytes.NewReader([]byte(pngSignature)))
		_ = w.Close()
	}()
	return r
}
427
428