Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/test/integration/helpers_test.go
2649 views
1
package integration_test
2
3
import (
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"sync"

	"github.com/kardolus/chatgpt-cli/config"
	"github.com/kardolus/chatgpt-cli/test"
	"github.com/onsi/gomega/gexec"
)
14
15
// expectedToken is the only API key the mock server's bearer-token check
// accepts.
const expectedToken = "valid-api-key"

// Package-level state shared by the integration helpers.
var (
	// onceBuild and onceServe guard the one-time binary build and the
	// one-time mock server start-up, respectively.
	onceBuild, onceServe sync.Once

	// binaryPath receives the path of the compiled CLI from buildBinary.
	binaryPath string

	// serverReady is closed during runMockServer's one-time start-up;
	// runMockServer waits on it before returning.
	serverReady = make(chan struct{})
)
23
24
func buildBinary() error {
25
var err error
26
onceBuild.Do(func() {
27
binaryPath, err = gexec.Build(
28
"github.com/kardolus/chatgpt-cli/cmd/chatgpt",
29
"-ldflags",
30
fmt.Sprintf("-X main.GitCommit=%s -X main.GitVersion=%s -X main.ServiceURL=%s", gitCommit, gitVersion, serviceURL))
31
})
32
return err
33
}
34
35
// curl issues an HTTP GET to url and returns the entire response body as a
// string. The status code is not inspected, so non-2xx responses come back
// as ordinary bodies with a nil error.
func curl(url string) (string, error) {
	resp, getErr := http.Get(url)
	if getErr != nil {
		return "", getErr
	}
	defer resp.Body.Close()

	var body strings.Builder
	if _, copyErr := io.Copy(&body, resp.Body); copyErr != nil {
		return "", copyErr
	}
	return body.String(), nil
}
49
50
func runMockServer() error {
51
var (
52
defaults config.Config
53
err error
54
)
55
56
onceServe.Do(func() {
57
go func() {
58
defaults = config.NewStore().ReadDefaults()
59
60
http.HandleFunc("/ping", getPing)
61
http.HandleFunc(defaults.CompletionsPath, postCompletions)
62
http.HandleFunc(defaults.ModelsPath, getModels)
63
close(serverReady)
64
err = http.ListenAndServe(servicePort, nil)
65
}()
66
})
67
<-serverReady
68
return err
69
}
70
71
// getPing is a liveness endpoint: GET replies with the body "pong", and any
// other method gets 405 Method Not Allowed with an empty body.
func getPing(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		_, _ = w.Write([]byte("pong"))
	default:
		w.WriteHeader(http.StatusMethodNotAllowed)
	}
}
78
79
func getModels(w http.ResponseWriter, r *http.Request) {
80
if err := validateRequest(w, r, http.MethodGet); err != nil {
81
fmt.Printf("invalid request: %s\n", err.Error())
82
return
83
}
84
85
if err := checkBearerToken(r, expectedToken); err != nil {
86
http.Error(w, creatAuthError(), http.StatusUnauthorized)
87
return
88
}
89
90
const modelFile = "models.json"
91
response, err := test.FileToBytes(modelFile)
92
if err != nil {
93
fmt.Printf("error reading %s: %s\n", modelFile, err.Error())
94
return
95
}
96
_, _ = w.Write(response)
97
}
98
99
func postCompletions(w http.ResponseWriter, r *http.Request) {
100
if err := validateRequest(w, r, http.MethodPost); err != nil {
101
fmt.Printf("invalid request: %s\n", err.Error())
102
return
103
}
104
105
if err := checkBearerToken(r, expectedToken); err != nil {
106
http.Error(w, creatAuthError(), http.StatusUnauthorized)
107
return
108
}
109
110
const completionsFile = "completions.json"
111
response, err := test.FileToBytes(completionsFile)
112
if err != nil {
113
fmt.Printf("error reading %s: %s\n", completionsFile, err.Error())
114
return
115
}
116
_, _ = w.Write(response)
117
}
118
119
// checkBearerToken reports whether r carries the header
// "Authorization: Bearer <token>" with token exactly equal to expectedToken.
//
// It returns nil when authorized. Otherwise it returns an error stating
// whether the header is missing, does not start with the Bearer scheme, or
// carries a different token.
func checkBearerToken(r *http.Request, expectedToken string) error {
	const scheme = "Bearer "

	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		return errors.New("missing Authorization header")
	}

	// The scheme must be a prefix. Splitting on it anywhere in the header
	// would accept values such as "xBearer <token>".
	if !strings.HasPrefix(authHeader, scheme) {
		return errors.New("malformed Authorization header")
	}

	if strings.TrimPrefix(authHeader, scheme) != expectedToken {
		return errors.New("invalid token")
	}

	return nil
}
137
138
func creatAuthError() string {
139
const errorFile = "error.json"
140
141
response, err := test.FileToBytes(errorFile)
142
if err != nil {
143
fmt.Printf("error reading %s: %s\n", errorFile, err.Error())
144
return ""
145
}
146
147
return string(response)
148
}
149
150
// validateRequest performs coarse pre-checks shared by the mock handlers.
//
// A method other than allowedMethod gets 405 and the error
// "method not allowed". An Authorization header not containing "Bearer" gets
// 400 and the error "bad request". In both cases the status is written to w
// before returning. Passing requests return nil with nothing written.
func validateRequest(w http.ResponseWriter, r *http.Request, allowedMethod string) error {
	var (
		status int
		reason string
	)

	switch {
	case r.Method != allowedMethod:
		status, reason = http.StatusMethodNotAllowed, "method not allowed"
	case !strings.Contains(r.Header.Get("Authorization"), "Bearer"):
		status, reason = http.StatusBadRequest, "bad request"
	default:
		return nil
	}

	w.WriteHeader(status)
	return errors.New(reason)
}
163
164