Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
gitpod-io
GitHub Repository: gitpod-io/gitpod
Path: blob/main/components/public-api-server/pkg/identityprovider/cache.go
2500 views
1
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2
// Licensed under the GNU Affero General Public License (AGPL).
3
// See License.AGPL.txt in the project root for license information.
4
5
package identityprovider
6
7
import (
8
"context"
9
"crypto/rsa"
10
"crypto/sha256"
11
"encoding/hex"
12
"encoding/json"
13
"fmt"
14
"math/rand"
15
"strings"
16
"sync"
17
"time"
18
19
"github.com/gitpod-io/gitpod/common-go/log"
20
"github.com/redis/go-redis/v9"
21
"gopkg.in/square/go-jose.v2"
22
)
23
24
// KeyCache caches public keys to ensure they're returned with the JWKS as long
25
// as there are valid tokens out there using those keys.
26
//
27
// PoC Note: in production this cache would likely be implemented using Redis or the database.
28
type KeyCache interface {
29
// Set rotates the current key
30
Set(ctx context.Context, current *rsa.PrivateKey) error
31
32
// Signer produces a new key signer or nil if Set() hasn't been called yet
33
Signer(ctx context.Context) (jose.Signer, error)
34
35
// PublicKeys returns all un-expired public keys as JSON-encoded *jose.JSONWebKeySet.
36
// This function returns the JSON-encoded form directly instead of the *jose.JSONWebKeySet
37
// to allow for persisted JSON implementations of this interface.
38
PublicKeys(ctx context.Context) ([]byte, error)
39
}
40
41
// inMemoryKey is one public key recorded by InMemoryCache.Set.
type inMemoryKey struct {
	ID      string         // key ID; also the entry's key in InMemoryCache.keys
	Created time.Time      // when Set recorded the key
	Key     *rsa.PublicKey // public half of the private key passed to Set
}
46
47
func NewInMemoryCache() *InMemoryCache {
48
return &InMemoryCache{
49
keys: make(map[string]*inMemoryKey),
50
}
51
}
52
53
type InMemoryCache struct {
54
mu sync.RWMutex
55
current *rsa.PrivateKey
56
currentID string
57
58
keys map[string]*inMemoryKey
59
}
60
61
// Set rotates the current key
62
func (imc *InMemoryCache) Set(ctx context.Context, current *rsa.PrivateKey) error {
63
imc.mu.Lock()
64
defer imc.mu.Unlock()
65
66
id := fmt.Sprintf("id%d%d", time.Now().Unix(), rand.Int())
67
imc.currentID = id
68
imc.current = current
69
imc.keys[id] = &inMemoryKey{
70
ID: id,
71
Created: time.Now(),
72
Key: &current.PublicKey,
73
}
74
return nil
75
}
76
77
// Signer produces a new key signer or nil if Set() hasn't been called yet
78
func (imc *InMemoryCache) Signer(ctx context.Context) (jose.Signer, error) {
79
if imc.current == nil {
80
return nil, nil
81
}
82
83
return jose.NewSigner(jose.SigningKey{
84
Algorithm: jose.RS256,
85
Key: imc.current,
86
}, nil)
87
}
88
89
// PublicKeys returns all un-expired public keys
90
func (imc *InMemoryCache) PublicKeys(ctx context.Context) ([]byte, error) {
91
imc.mu.RLock()
92
defer imc.mu.RUnlock()
93
94
var jwks jose.JSONWebKeySet
95
for _, key := range imc.keys {
96
jwks.Keys = append(jwks.Keys, jose.JSONWebKey{
97
Key: key.Key,
98
KeyID: key.ID,
99
Algorithm: string(jose.RS256),
100
Use: "sig",
101
})
102
}
103
104
return json.Marshal(jwks)
105
}
106
107
// Redis storage settings used by RedisCache.
const (
	// redisCacheDefaultTTL is how long a persisted public key lives in Redis
	// unless reconcile extends it.
	redisCacheDefaultTTL = time.Hour

	// redisIDPKeyPrefix namespaces the Redis keys holding serialized public
	// keys; each full Redis key is this prefix followed by the key ID.
	redisIDPKeyPrefix = "idp:keys:"
)
111
112
type RedisCache struct {
113
Client *redis.Client
114
115
keyID func(current *rsa.PrivateKey) string
116
mu sync.RWMutex
117
current *rsa.PrivateKey
118
currentID string
119
}
120
121
// redisCacheOpt collects the tunables accepted by NewRedisCache.
type redisCacheOpt struct {
	// refreshPeriod is the interval at which the background loop started by
	// NewRedisCache re-extends the current key's TTL in Redis.
	refreshPeriod time.Duration
}

// redisCacheOption adjusts a redisCacheOpt; pass these to NewRedisCache.
type redisCacheOption func(*redisCacheOpt)

// WithRefreshPeriod overrides the interval at which RedisCache refreshes the
// current public key in Redis.
func WithRefreshPeriod(t time.Duration) redisCacheOption {
	setRefresh := func(o *redisCacheOpt) { o.refreshPeriod = t }
	return setRefresh
}
132
133
func NewRedisCache(ctx context.Context, client *redis.Client, opts ...redisCacheOption) *RedisCache {
134
opt := &redisCacheOpt{
135
refreshPeriod: 10 * time.Minute,
136
}
137
for _, o := range opts {
138
o(opt)
139
}
140
cache := &RedisCache{
141
Client: client,
142
keyID: defaultKeyID,
143
}
144
go cache.sync(ctx, opt.refreshPeriod)
145
return cache
146
}
147
148
// defaultKeyID derives a deterministic key ID from the RSA modulus: "id-"
// followed by the hex-encoded SHA-256 digest of N's bytes. The same key
// therefore always yields the same ID.
func defaultKeyID(current *rsa.PrivateKey) string {
	digest := sha256.Sum256(current.N.Bytes())
	return "id-" + hex.EncodeToString(digest[:])
}
152
153
// PublicKeys implements KeyCache
154
func (rc *RedisCache) PublicKeys(ctx context.Context) ([]byte, error) {
155
var (
156
res = []byte("{\"keys\":[")
157
first = true
158
hasCurrentKey = false
159
)
160
161
if rc.current != nil && rc.currentID != "" {
162
hasCurrentKey = true
163
fc, err := serializePublicKeyAsJSONWebKey(rc.currentID, &rc.current.PublicKey)
164
if err != nil {
165
return nil, err
166
}
167
res = append(res, fc...)
168
first = false
169
}
170
171
iter := rc.Client.Scan(ctx, 0, redisIDPKeyPrefix+"*", 0).Iterator()
172
for iter.Next(ctx) {
173
idx := iter.Val()
174
if hasCurrentKey && strings.HasSuffix(idx, rc.currentID) {
175
// We've already added the public key we hold in memory
176
continue
177
}
178
key, err := rc.Client.Get(ctx, idx).Result()
179
if err != nil {
180
return nil, err
181
}
182
183
if !first {
184
res = append(res, []byte(",")...)
185
}
186
res = append(res, []byte(key)...)
187
first = false
188
}
189
if err := iter.Err(); err != nil {
190
return nil, err
191
}
192
res = append(res, []byte("]}")...)
193
return res, nil
194
}
195
196
func serializePublicKeyAsJSONWebKey(keyID string, key *rsa.PublicKey) ([]byte, error) {
197
publicKey := jose.JSONWebKey{
198
Key: key,
199
KeyID: keyID,
200
Algorithm: string(jose.RS256),
201
Use: "sig",
202
}
203
return json.Marshal(publicKey)
204
}
205
206
// Set implements KeyCache
207
func (rc *RedisCache) Set(ctx context.Context, current *rsa.PrivateKey) error {
208
rc.mu.Lock()
209
defer rc.mu.Unlock()
210
211
err := rc.persistPublicKey(ctx, current)
212
if err != nil {
213
return err
214
}
215
rc.currentID = rc.keyID(current)
216
rc.current = current
217
218
return nil
219
}
220
221
// persistPublicKey writes the JSON Web Key form of current's public key to
// Redis under redisIDPKeyPrefix + keyID(current), expiring after
// redisCacheDefaultTTL. It does not touch rc.current or rc.currentID.
func (rc *RedisCache) persistPublicKey(ctx context.Context, current *rsa.PrivateKey) error {
	keyID := rc.keyID(current)

	jwk, err := serializePublicKeyAsJSONWebKey(keyID, &current.PublicKey)
	if err != nil {
		return err
	}

	return rc.Client.Set(ctx, redisIDPKeyPrefix+keyID, string(jwk), redisCacheDefaultTTL).Err()
}
237
238
// Signer implements KeyCache
239
func (rc *RedisCache) Signer(ctx context.Context) (jose.Signer, error) {
240
if rc.current == nil {
241
return nil, nil
242
}
243
244
if err := rc.reconcile(ctx); err != nil {
245
return nil, err
246
}
247
248
return jose.NewSigner(jose.SigningKey{
249
Algorithm: jose.RS256,
250
Key: rc.current,
251
}, &jose.SignerOptions{
252
ExtraHeaders: map[jose.HeaderKey]interface{}{
253
jose.HeaderKey("kid"): rc.currentID,
254
},
255
})
256
}
257
258
// reconcile keeps the current public key published in Redis. It first tries
// to extend the entry's TTL to redisCacheDefaultTTL. If Redis did not
// confirm the extension (the entry is missing, or the EXPIRE call failed), it
// writes the key again and returns any error from that write. It is a no-op
// until Set has been called. It runs both from Signer and from the background
// sync goroutine.
func (rc *RedisCache) reconcile(ctx context.Context) error {
	// Snapshot under the read lock; Set may rotate the key concurrently.
	rc.mu.RLock()
	current, currentID := rc.current, rc.currentID
	rc.mu.RUnlock()

	if current == nil {
		return nil
	}

	resp := rc.Client.Expire(ctx, redisIDPKeyPrefix+currentID, redisCacheDefaultTTL)
	if err := resp.Err(); err != nil {
		log.WithField("keyID", currentID).WithError(err).Warn("cannot extend cached IDP public key TTL")
	}
	if resp.Val() {
		return nil
	}

	log.WithField("keyID", currentID).Warn("cannot extend cached IDP public key TTL - trying to repersist")
	if err := rc.persistPublicKey(ctx, current); err != nil {
		// Logged here as well because the sync loop discards the returned error.
		log.WithField("keyID", currentID).WithError(err).Error("cannot repersist public key")
		return err
	}
	return nil
}
277
278
func (rc *RedisCache) sync(ctx context.Context, period time.Duration) {
279
ticker := time.NewTicker(period)
280
for {
281
select {
282
case <-ctx.Done():
283
return
284
case <-ticker.C:
285
_ = rc.reconcile(ctx)
286
}
287
}
288
}
289
290
// Compile-time checks that both cache implementations satisfy KeyCache.
var (
	_ KeyCache = (*InMemoryCache)(nil)
	_ KeyCache = (*RedisCache)(nil)
)
291
292