Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
gitpod-io
GitHub Repository: gitpod-io/gitpod
Path: blob/main/components/public-api-server/pkg/identityprovider/cache_test.go
2500 views
1
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2
// Licensed under the GNU Affero General Public License (AGPL).
3
// See License.AGPL.txt in the project root for license information.
4
5
package identityprovider
6
7
import (
8
"context"
9
"crypto/rand"
10
"crypto/rsa"
11
"encoding/base64"
12
"encoding/json"
13
"sort"
14
"testing"
15
"time"
16
17
"github.com/alicebob/miniredis/v2"
18
"github.com/google/go-cmp/cmp"
19
"github.com/redis/go-redis/v9"
20
"gopkg.in/square/go-jose.v2"
21
)
22
23
// testKeyID derives a deterministic 12-character key ID for k: the first 12
// characters of the unpadded, URL-safe base64 encoding of its public modulus.
// It panics if that encoding is shorter than 12 characters, which cannot happen
// for the 2048-bit keys generated in this file.
func testKeyID(k *rsa.PrivateKey) string {
	const idLen = 12
	modulus := k.PublicKey.N.Bytes()
	encoded := base64.RawURLEncoding.EncodeToString(modulus)
	return encoded[:idLen]
}
26
27
func sortKeys(jwks *jose.JSONWebKeySet) {
28
sort.Slice(jwks.Keys, func(i, j int) bool {
29
var (
30
ki = jwks.Keys[i]
31
kj = jwks.Keys[j]
32
)
33
return ki.KeyID < kj.KeyID
34
})
35
}
36
37
func TestRedisCachePublicKeys(t *testing.T) {
38
var (
39
jwks jose.JSONWebKeySet
40
threeKeys []*rsa.PrivateKey
41
)
42
for i := 0; i < 3; i++ {
43
key, err := rsa.GenerateKey(rand.Reader, 2048)
44
if err != nil {
45
panic(err)
46
}
47
threeKeys = append(threeKeys, key)
48
jwks.Keys = append(jwks.Keys, jose.JSONWebKey{
49
Key: &key.PublicKey,
50
Algorithm: string(jose.RS256),
51
KeyID: testKeyID(key),
52
Use: "sig",
53
})
54
}
55
sortKeys(&jwks)
56
threeKeysExpectation, err := json.Marshal(jwks)
57
if err != nil {
58
panic(err)
59
}
60
61
type Expectation struct {
62
Error string
63
Response []byte
64
}
65
type Test struct {
66
Name string
67
Keys []*rsa.PrivateKey
68
StateMod func(*redis.Client) error
69
Expectation Expectation
70
}
71
tests := []Test{
72
{
73
Name: "redis down",
74
Keys: threeKeys,
75
StateMod: func(c *redis.Client) error {
76
return c.FlushAll(context.Background()).Err()
77
},
78
Expectation: Expectation{
79
Response: func() []byte {
80
fc, err := serializePublicKeyAsJSONWebKey(testKeyID(threeKeys[2]), &threeKeys[2].PublicKey)
81
if err != nil {
82
panic(err)
83
}
84
return []byte(`{"keys":[` + string(fc) + `]}`)
85
}(),
86
},
87
},
88
{
89
Name: "no keys",
90
Expectation: Expectation{
91
Response: []byte(`{"keys":[]}`),
92
},
93
},
94
{
95
Name: "no key in memory",
96
StateMod: func(c *redis.Client) error {
97
return c.Set(context.Background(), redisIDPKeyPrefix+"foo", `{"use":"sig","kty":"RSA","kid":"fpp","alg":"RS256","n":"VGVsbCBDaHJpcyB5b3UgZm91bmQgdGhpcyAtIGRyaW5rJ3Mgb24gbWU","e":"AQAB"}`, 0).Err()
98
},
99
Expectation: Expectation{
100
Response: []byte(`{"keys":[{"use":"sig","kty":"RSA","kid":"fpp","alg":"RS256","n":"VGVsbCBDaHJpcyB5b3UgZm91bmQgdGhpcyAtIGRyaW5rJ3Mgb24gbWU","e":"AQAB"}]}`),
101
},
102
},
103
{
104
Name: "multiple keys",
105
Keys: threeKeys,
106
Expectation: Expectation{
107
Response: threeKeysExpectation,
108
},
109
},
110
}
111
112
for _, test := range tests {
113
t.Run(test.Name, func(t *testing.T) {
114
s := miniredis.RunT(t)
115
ctx, cancel := context.WithCancel(context.Background())
116
t.Cleanup(func() {
117
cancel()
118
})
119
client := redis.NewClient(&redis.Options{Addr: s.Addr()})
120
cache := NewRedisCache(ctx, client)
121
cache.keyID = testKeyID
122
for _, key := range test.Keys {
123
err := cache.Set(context.Background(), key)
124
if err != nil {
125
t.Fatal(err)
126
}
127
}
128
if test.StateMod != nil {
129
err := test.StateMod(client)
130
if err != nil {
131
t.Fatal(err)
132
}
133
}
134
135
var (
136
act Expectation
137
err error
138
)
139
fc, err := cache.PublicKeys(context.Background())
140
if err != nil {
141
act.Error = err.Error()
142
}
143
if len(fc) > 0 {
144
var res jose.JSONWebKeySet
145
err = json.Unmarshal(fc, &res)
146
if err != nil {
147
t.Fatal(err)
148
}
149
sortKeys(&res)
150
act.Response, err = json.Marshal(&res)
151
if err != nil {
152
t.Fatal(err)
153
}
154
}
155
156
if diff := cmp.Diff(test.Expectation, act); diff != "" {
157
t.Errorf("PublicKeys() mismatch (-want +got):\n%s", diff)
158
}
159
})
160
}
161
}
162
163
func TestRedisCacheSigner(t *testing.T) {
164
s := miniredis.RunT(t)
165
client := redis.NewClient(&redis.Options{Addr: s.Addr()})
166
cache := NewRedisCache(context.Background(), client)
167
168
sig, err := cache.Signer(context.Background())
169
if sig != nil {
170
t.Error("Signer() returned a signer despite having no key set")
171
}
172
if err != nil {
173
t.Errorf("Signer() returned an despite having no key set: %v", err)
174
}
175
176
key, err := rsa.GenerateKey(rand.Reader, 2048)
177
if err != nil {
178
t.Fatal(err)
179
}
180
err = cache.Set(context.Background(), key)
181
if err != nil {
182
t.Fatalf("RedisCache failed to Set current key but shouldn't have: %v", err)
183
}
184
185
sig, err = cache.Signer(context.Background())
186
if sig == nil {
187
t.Error("Signer() returned nil even though a key was set")
188
}
189
if err != nil {
190
t.Error("Signer() returned an error even though a key was set")
191
}
192
193
signature, err := sig.Sign([]byte("foo"))
194
if err != nil {
195
t.Fatal(err)
196
}
197
_, err = signature.Verify(&key.PublicKey)
198
if err != nil {
199
t.Errorf("Returned signer does not sign with currently set key")
200
}
201
202
err = client.FlushAll(context.Background()).Err()
203
if err != nil {
204
t.Fatal(err)
205
}
206
_, err = cache.Signer(context.Background())
207
if err != nil {
208
t.Fatal(err)
209
}
210
keys := client.Keys(context.Background(), redisIDPKeyPrefix+"*").Val()
211
if len(keys) == 0 {
212
t.Error("getting a new signer did not repersist the key")
213
}
214
}
215
216
func TestRedisPeriodicallySync(t *testing.T) {
217
s := miniredis.RunT(t)
218
client := redis.NewClient(&redis.Options{Addr: s.Addr()})
219
cache := NewRedisCache(context.Background(), client, WithRefreshPeriod(1*time.Second))
220
221
key, err := rsa.GenerateKey(rand.Reader, 2048)
222
if err != nil {
223
t.Fatal(err)
224
}
225
err = cache.Set(context.Background(), key)
226
if err != nil {
227
t.Fatalf("RedisCache failed to Set current key but shouldn't have: %v", err)
228
}
229
err = client.FlushAll(context.Background()).Err()
230
if err != nil {
231
t.Fatal(err)
232
}
233
time.Sleep(3 * time.Second)
234
keys := client.Keys(context.Background(), redisIDPKeyPrefix+"*").Val()
235
if len(keys) == 0 {
236
t.Error("redis periodically sync won't work")
237
}
238
}
239
240