// Path: blob/main/components/public-api-server/pkg/identityprovider/cache.go
// 2500 views
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.1// Licensed under the GNU Affero General Public License (AGPL).2// See License.AGPL.txt in the project root for license information.34package identityprovider56import (7"context"8"crypto/rsa"9"crypto/sha256"10"encoding/hex"11"encoding/json"12"fmt"13"math/rand"14"strings"15"sync"16"time"1718"github.com/gitpod-io/gitpod/common-go/log"19"github.com/redis/go-redis/v9"20"gopkg.in/square/go-jose.v2"21)2223// KeyCache caches public keys to ensure they're returned with the JWKS as long24// as there are valid tokens out there using those keys.25//26// PoC Note: in production this cache would likely be implemented using Redis or the database.27type KeyCache interface {28// Set rotates the current key29Set(ctx context.Context, current *rsa.PrivateKey) error3031// Signer produces a new key signer or nil if Set() hasn't been called yet32Signer(ctx context.Context) (jose.Signer, error)3334// PublicKeys returns all un-expired public keys as JSON-encoded *jose.JSONWebKeySet.35// This function returns the JSON-encoded form directly instead of the *jose.JSONWebKeySet36// to allow for persisted JSON implementations of this interface.37PublicKeys(ctx context.Context) ([]byte, error)38}3940type inMemoryKey struct {41ID string42Created time.Time43Key *rsa.PublicKey44}4546func NewInMemoryCache() *InMemoryCache {47return &InMemoryCache{48keys: make(map[string]*inMemoryKey),49}50}5152type InMemoryCache struct {53mu sync.RWMutex54current *rsa.PrivateKey55currentID string5657keys map[string]*inMemoryKey58}5960// Set rotates the current key61func (imc *InMemoryCache) Set(ctx context.Context, current *rsa.PrivateKey) error {62imc.mu.Lock()63defer imc.mu.Unlock()6465id := fmt.Sprintf("id%d%d", time.Now().Unix(), rand.Int())66imc.currentID = id67imc.current = current68imc.keys[id] = &inMemoryKey{69ID: id,70Created: time.Now(),71Key: ¤t.PublicKey,72}73return nil74}7576// Signer produces a new key signer or nil if Set() hasn't been called 
yet77func (imc *InMemoryCache) Signer(ctx context.Context) (jose.Signer, error) {78if imc.current == nil {79return nil, nil80}8182return jose.NewSigner(jose.SigningKey{83Algorithm: jose.RS256,84Key: imc.current,85}, nil)86}8788// PublicKeys returns all un-expired public keys89func (imc *InMemoryCache) PublicKeys(ctx context.Context) ([]byte, error) {90imc.mu.RLock()91defer imc.mu.RUnlock()9293var jwks jose.JSONWebKeySet94for _, key := range imc.keys {95jwks.Keys = append(jwks.Keys, jose.JSONWebKey{96Key: key.Key,97KeyID: key.ID,98Algorithm: string(jose.RS256),99Use: "sig",100})101}102103return json.Marshal(jwks)104}105106const (107redisCacheDefaultTTL = 1 * time.Hour108redisIDPKeyPrefix = "idp:keys:"109)110111type RedisCache struct {112Client *redis.Client113114keyID func(current *rsa.PrivateKey) string115mu sync.RWMutex116current *rsa.PrivateKey117currentID string118}119120type redisCacheOpt struct {121refreshPeriod time.Duration122}123124type redisCacheOption func(*redisCacheOpt)125126func WithRefreshPeriod(t time.Duration) redisCacheOption {127return func(opt *redisCacheOpt) {128opt.refreshPeriod = t129}130}131132func NewRedisCache(ctx context.Context, client *redis.Client, opts ...redisCacheOption) *RedisCache {133opt := &redisCacheOpt{134refreshPeriod: 10 * time.Minute,135}136for _, o := range opts {137o(opt)138}139cache := &RedisCache{140Client: client,141keyID: defaultKeyID,142}143go cache.sync(ctx, opt.refreshPeriod)144return cache145}146147func defaultKeyID(current *rsa.PrivateKey) string {148hashed := sha256.Sum256(current.N.Bytes())149return fmt.Sprintf("id-%s", hex.EncodeToString(hashed[:]))150}151152// PublicKeys implements KeyCache153func (rc *RedisCache) PublicKeys(ctx context.Context) ([]byte, error) {154var (155res = []byte("{\"keys\":[")156first = true157hasCurrentKey = false158)159160if rc.current != nil && rc.currentID != "" {161hasCurrentKey = true162fc, err := serializePublicKeyAsJSONWebKey(rc.currentID, &rc.current.PublicKey)163if err != nil 
{164return nil, err165}166res = append(res, fc...)167first = false168}169170iter := rc.Client.Scan(ctx, 0, redisIDPKeyPrefix+"*", 0).Iterator()171for iter.Next(ctx) {172idx := iter.Val()173if hasCurrentKey && strings.HasSuffix(idx, rc.currentID) {174// We've already added the public key we hold in memory175continue176}177key, err := rc.Client.Get(ctx, idx).Result()178if err != nil {179return nil, err180}181182if !first {183res = append(res, []byte(",")...)184}185res = append(res, []byte(key)...)186first = false187}188if err := iter.Err(); err != nil {189return nil, err190}191res = append(res, []byte("]}")...)192return res, nil193}194195func serializePublicKeyAsJSONWebKey(keyID string, key *rsa.PublicKey) ([]byte, error) {196publicKey := jose.JSONWebKey{197Key: key,198KeyID: keyID,199Algorithm: string(jose.RS256),200Use: "sig",201}202return json.Marshal(publicKey)203}204205// Set implements KeyCache206func (rc *RedisCache) Set(ctx context.Context, current *rsa.PrivateKey) error {207rc.mu.Lock()208defer rc.mu.Unlock()209210err := rc.persistPublicKey(ctx, current)211if err != nil {212return err213}214rc.currentID = rc.keyID(current)215rc.current = current216217return nil218}219220func (rc *RedisCache) persistPublicKey(ctx context.Context, current *rsa.PrivateKey) error {221id := rc.keyID(current)222223publicKeyJSON, err := serializePublicKeyAsJSONWebKey(id, ¤t.PublicKey)224if err != nil {225return err226}227228redisKey := fmt.Sprintf("%s%s", redisIDPKeyPrefix, id)229err = rc.Client.Set(ctx, redisKey, string(publicKeyJSON), redisCacheDefaultTTL).Err()230if err != nil {231return err232}233234return nil235}236237// Signer implements KeyCache238func (rc *RedisCache) Signer(ctx context.Context) (jose.Signer, error) {239if rc.current == nil {240return nil, nil241}242243if err := rc.reconcile(ctx); err != nil {244return nil, err245}246247return jose.NewSigner(jose.SigningKey{248Algorithm: jose.RS256,249Key: rc.current,250}, &jose.SignerOptions{251ExtraHeaders: 
map[jose.HeaderKey]interface{}{252jose.HeaderKey("kid"): rc.currentID,253},254})255}256257func (rc *RedisCache) reconcile(ctx context.Context) error {258if rc.current == nil {259return nil260}261262resp := rc.Client.Expire(ctx, redisIDPKeyPrefix+rc.currentID, redisCacheDefaultTTL)263if err := resp.Err(); err != nil {264log.WithField("keyID", rc.currentID).WithError(err).Warn("cannot extend cached IDP public key TTL")265}266if !resp.Val() {267log.WithField("keyID", rc.currentID).Warn("cannot extend cached IDP public key TTL - trying to repersist")268err := rc.persistPublicKey(ctx, rc.current)269if err != nil {270log.WithField("keyID", rc.currentID).WithError(err).Error("cannot repersist public key")271return err272}273}274return nil275}276277func (rc *RedisCache) sync(ctx context.Context, period time.Duration) {278ticker := time.NewTicker(period)279for {280select {281case <-ctx.Done():282return283case <-ticker.C:284_ = rc.reconcile(ctx)285}286}287}288289var _ KeyCache = ((*RedisCache)(nil))290291292