Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
gitpod-io
GitHub Repository: gitpod-io/gitpod
Path: blob/main/components/public-api-server/pkg/oidc/service.go
2500 views
1
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
2
// Licensed under the GNU Affero General Public License (AGPL).
3
// See License.AGPL.txt in the project root for license information.
4
5
package oidc
6
7
import (
8
"bytes"
9
"context"
10
"crypto/rand"
11
"encoding/base64"
12
"encoding/json"
13
"fmt"
14
"io"
15
"net/http"
16
"time"
17
18
"github.com/coreos/go-oidc/v3/oidc"
19
goidc "github.com/coreos/go-oidc/v3/oidc"
20
"github.com/gitpod-io/gitpod/common-go/log"
21
db "github.com/gitpod-io/gitpod/components/gitpod-db/go"
22
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws"
23
"github.com/golang-jwt/jwt/v5"
24
"github.com/google/cel-go/cel"
25
"github.com/google/cel-go/checker/decls"
26
"github.com/google/uuid"
27
"golang.org/x/oauth2"
28
"google.golang.org/grpc/codes"
29
"google.golang.org/grpc/status"
30
"gorm.io/gorm"
31
)
32
33
type Service struct {
34
dbConn *gorm.DB
35
cipher db.Cipher
36
37
// jwts
38
stateExpiry time.Duration
39
signerVerifier jws.SignerVerifier
40
41
sessionServiceAddress string
42
43
// TODO(at) remove by enhancing test setups
44
skipVerifyIdToken bool
45
}
46
47
type ClientConfig struct {
48
ID string
49
OrganizationID string
50
Issuer string
51
Active bool
52
OAuth2Config *oauth2.Config
53
VerifierConfig *goidc.Config
54
CelExpression string
55
UsePKCE bool
56
}
57
58
// StartParams is the result of starting an OIDC login: the signed state, the
// nonce, the optional PKCE verifier and the IdP URL to redirect the user to.
// NOTE(review): State/Nonce/CodeVerifier are presumably persisted by the
// caller (e.g. cookies) for the callback — confirm in the handler.
type StartParams struct {
	State        string
	Nonce        string
	CodeVerifier string // empty unless the client config enables PKCE
	AuthCodeURL  string
}
64
65
type AuthFlowResult struct {
66
IDToken *goidc.IDToken `json:"idToken"`
67
Claims map[string]interface{} `json:"claims"`
68
}
69
70
// NewService wires a Service from its dependencies. stateExpiry is the
// lifetime of the signed state tokens issued by the service.
func NewService(sessionServiceAddress string, dbConn *gorm.DB, cipher db.Cipher, signerVerifier jws.SignerVerifier, stateExpiry time.Duration) *Service {
	svc := new(Service)
	svc.sessionServiceAddress = sessionServiceAddress
	svc.dbConn = dbConn
	svc.cipher = cipher
	svc.signerVerifier = signerVerifier
	svc.stateExpiry = stateExpiry
	return svc
}
81
82
// getStartParams prepares a new authorization-code request for config.
//
// It signs stateParams into the opaque `state` value, draws a random nonce,
// optionally creates a PKCE verifier (config.UsePKCE) and renders the IdP
// authorization URL.
//
// Side effect: config.OAuth2Config.RedirectURL is overwritten with
// redirectURL.
func (s *Service) getStartParams(config *ClientConfig, redirectURL string, stateParams StateParams) (*StartParams, error) {
	// The `state` is supposed to be passed through unmodified by the IdP.
	state, err := s.encodeStateParam(stateParams)
	if err != nil {
		return nil, fmt.Errorf("failed to encode state: %w", err)
	}

	// Number used once; authenticate compares it with the ID token's nonce.
	nonce, err := randString(32)
	if err != nil {
		return nil, fmt.Errorf("failed to create nonce: %w", err)
	}

	config.OAuth2Config.RedirectURL = redirectURL

	opts := []oauth2.AuthCodeOption{goidc.Nonce(nonce)}
	var verifier string
	if config.UsePKCE {
		verifier = oauth2.GenerateVerifier()
		opts = append(opts, oauth2.S256ChallengeOption(verifier))
	}

	return &StartParams{
		AuthCodeURL:  config.OAuth2Config.AuthCodeURL(state, opts...),
		State:        state,
		Nonce:        nonce,
		CodeVerifier: verifier,
	}, nil
}
114
115
// encodeStateParam turns state into a signed JWT string. It passes the
// current UTC time and that time plus s.stateExpiry to NewStateJWT
// (presumably issued-at and expiry claims — see NewStateJWT).
func (s *Service) encodeStateParam(state StateParams) (string, error) {
	issuedAt := time.Now().UTC()
	signed, err := s.signerVerifier.Sign(NewStateJWT(state, issuedAt, issuedAt.Add(s.stateExpiry)))
	if err != nil {
		return "", fmt.Errorf("failed to sign jwt: %w", err)
	}
	return signed, nil
}
126
127
func (s *Service) decodeStateParam(encodedToken string) (StateParams, error) {
128
claims := &StateClaims{}
129
_, err := s.signerVerifier.Verify(encodedToken, claims)
130
if err != nil {
131
return StateParams{}, fmt.Errorf("failed to verify state token: %w", err)
132
}
133
134
return claims.StateParams, nil
135
}
136
137
// randString returns size cryptographically random bytes encoded as
// unpadded URL-safe base64 (so the result is safe in URLs and cookies).
func randString(size int) (string, error) {
	buf := make([]byte, size)
	_, err := io.ReadFull(rand.Reader, buf)
	if err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(buf)
	return encoded, nil
}
144
145
// getClientConfigFromStartRequest resolves the client configuration for a
// login start request from its query parameters:
//
//   - orgSlug: the active OIDC client config of that organization;
//   - id: the client config with that ID, looked up via
//     db.GetOIDCClientConfig (presumably without the active-only filter used
//     for orgSlug — confirm in the db package);
//   - neither: the organization returned by
//     db.GetSingleOrganizationWithActiveSSO, then resolved as for orgSlug.
//
// orgSlug takes precedence when both parameters are present.
func (s *Service) getClientConfigFromStartRequest(r *http.Request) (*ClientConfig, error) {
	query := r.URL.Query()
	orgSlug := query.Get("orgSlug")
	idParam := query.Get("id")

	// If no org slug is given, we assume the request is for the default org.
	if orgSlug == "" && idParam == "" {
		org, err := db.GetSingleOrganizationWithActiveSSO(r.Context(), s.dbConn)
		if err != nil {
			return nil, fmt.Errorf("Failed to find team: %w", err)
		}
		orgSlug = org.Slug
	}

	if orgSlug != "" {
		dbEntry, err := db.GetActiveOIDCClientConfigByOrgSlug(r.Context(), s.dbConn, orgSlug)
		if err != nil {
			return nil, fmt.Errorf("Failed to find OIDC clients: %w", err)
		}
		config, err := s.convertClientConfig(r.Context(), dbEntry)
		if err != nil {
			return nil, fmt.Errorf("Failed to find OIDC clients: %w", err)
		}
		return &config, nil
	}

	// Only reachable with an empty orgSlug, e.g. when the default org lookup
	// yielded an org without a slug and no id was given.
	if idParam == "" {
		return nil, fmt.Errorf("missing id parameter")
	}
	// getConfigById returns a nil config whenever it returns an error.
	return s.getConfigById(r.Context(), idParam)
}
185
186
// getClientConfigFromCallbackRequest resolves the client configuration for an
// IdP callback. The `state` query parameter must verify via decodeStateParam;
// the ClientConfigID carried in it selects the configuration. It returns the
// configuration together with the decoded state.
func (s *Service) getClientConfigFromCallbackRequest(r *http.Request) (*ClientConfig, *StateParams, error) {
	stateParam := r.URL.Query().Get("state")
	if stateParam == "" {
		return nil, nil, fmt.Errorf("missing state parameter")
	}

	state, err := s.decodeStateParam(stateParam)
	if err != nil {
		return nil, nil, fmt.Errorf("bad state param: %w", err)
	}

	// Previously the lookup error was discarded; keep it for diagnosis.
	config, err := s.getConfigById(r.Context(), state.ClientConfigID)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to find OIDC config on callback: %w", err)
	}
	if config == nil {
		return nil, nil, fmt.Errorf("failed to find OIDC config on callback")
	}
	return config, &state, nil
}
203
204
// activateAndVerifyClientConfig marks config as verified and then sets it
// active. The two database writes are issued one after the other (no
// transaction in this function), so a failure while activating leaves the
// config verified but not activated.
func (s *Service) activateAndVerifyClientConfig(ctx context.Context, config *ClientConfig) error {
	configID, err := uuid.Parse(config.ID)
	if err != nil {
		return err
	}
	if err := db.VerifyClientConfig(ctx, s.dbConn, configID); err != nil {
		return err
	}
	return db.SetClientConfigActiviation(ctx, s.dbConn, configID, true)
}
215
216
// markClientConfigAsVerified records config as verified via
// db.VerifyClientConfig. Unlike activateAndVerifyClientConfig it does not
// call SetClientConfigActiviation.
func (s *Service) markClientConfigAsVerified(ctx context.Context, config *ClientConfig) error {
	configID, err := uuid.Parse(config.ID)
	if err == nil {
		err = db.VerifyClientConfig(ctx, s.dbConn, configID)
	}
	return err
}
223
224
// getConfigById loads the OIDC client configuration with the given UUID and
// converts it into a ClientConfig via convertClientConfig.
//
// Errors from parsing the ID and from the database lookup are returned as-is.
// Conversion failures are logged and reported as codes.Internal. The config
// is nil whenever the error is non-nil.
func (s *Service) getConfigById(ctx context.Context, id string) (*ClientConfig, error) {
	configID, err := uuid.Parse(id)
	if err != nil {
		return nil, err
	}
	dbEntry, err := db.GetOIDCClientConfig(ctx, s.dbConn, configID)
	if err != nil {
		return nil, err
	}
	config, err := s.convertClientConfig(ctx, dbEntry)
	if err != nil {
		// convertClientConfig fails not only on decryption but also when the
		// OIDC provider for the issuer cannot be created, so the message must
		// not claim a decryption failure.
		log.Log.WithError(err).Error("Failed to convert oidc client config.")
		return nil, status.Errorf(codes.Internal, "Failed to load OIDC client config.")
	}
	return &config, nil
}
241
242
// convertClientConfig turns a stored OIDC client configuration into a
// ClientConfig. It decrypts the secret spec with s.cipher and creates an
// oidc.Provider for dbEntry.Issuer to obtain the OAuth2 endpoints.
//
// Decryption failures are logged and reported as codes.Internal; provider
// creation failures are wrapped with the issuer for context.
func (s *Service) convertClientConfig(ctx context.Context, dbEntry db.OIDCClientConfig) (ClientConfig, error) {
	spec, err := dbEntry.Data.Decrypt(s.cipher)
	if err != nil {
		log.Log.WithError(err).Error("Failed to decrypt oidc client config.")
		return ClientConfig{}, status.Errorf(codes.Internal, "Failed to decrypt OIDC client config.")
	}

	// go-oidc performs provider discovery against the issuer here.
	provider, err := oidc.NewProvider(ctx, dbEntry.Issuer)
	if err != nil {
		return ClientConfig{}, fmt.Errorf("failed to create OIDC provider for issuer %q: %w", dbEntry.Issuer, err)
	}

	return ClientConfig{
		ID:             dbEntry.ID.String(),
		OrganizationID: dbEntry.OrganizationID.String(),
		Issuer:         dbEntry.Issuer,
		Active:         dbEntry.Active,
		OAuth2Config: &oauth2.Config{
			ClientID:     spec.ClientID,
			ClientSecret: spec.ClientSecret,
			Endpoint:     provider.Endpoint(),
			Scopes:       spec.Scopes,
		},
		CelExpression: spec.CelExpression,
		UsePKCE:       spec.UsePKCE,
		VerifierConfig: &goidc.Config{
			ClientID: spec.ClientID,
		},
	}, nil
}
272
273
type authenticateParams struct {
274
Config *ClientConfig
275
OAuth2Result *OAuth2Result
276
NonceCookieValue string
277
}
278
279
type CelExprError struct {
280
Msg string
281
Code string
282
}
283
284
func (e *CelExprError) Error() string {
285
return fmt.Sprintf("%s [%s]", e.Msg, e.Code)
286
}
287
288
// authenticate verifies the ID token obtained from the authorization-code
// exchange and decides whether the user may log in.
//
// Steps:
//  1. extract the raw `id_token` from params.OAuth2Result.OAuth2Token;
//  2. verify it with go-oidc's verifier for the issuer and OAuth2 client ID;
//  3. require a non-empty nonce cookie equal to the token's nonce;
//  4. ensure the required claims are present (validateRequiredClaims may
//     consult the userinfo endpoint, reading the OAuth2 result from ctx);
//  5. evaluate the optional CEL expression of the client config.
//
// CEL failures (invalid expression, evaluation error, non-true result) are
// returned as *CelExprError.
func (s *Service) authenticate(ctx context.Context, params authenticateParams) (*AuthFlowResult, error) {
	rawIDToken, ok := params.OAuth2Result.OAuth2Token.Extra("id_token").(string)
	if !ok {
		return nil, fmt.Errorf("id_token not found")
	}

	provider, err := oidc.NewProvider(ctx, params.Config.Issuer)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize provider: %w", err)
	}
	verifier := provider.Verifier(&goidc.Config{
		ClientID: params.Config.OAuth2Config.ClientID,
	})

	idToken, err := verifier.Verify(ctx, rawIDToken)
	if err != nil {
		return nil, fmt.Errorf("failed to verify id_token: %w", err)
	}
	// An empty cookie must never match: otherwise an ID token without a
	// nonce claim would be accepted whenever the nonce cookie is missing.
	if params.NonceCookieValue == "" || idToken.Nonce != params.NonceCookieValue {
		return nil, fmt.Errorf("nonce mismatch")
	}

	validatedClaims, err := s.validateRequiredClaims(ctx, provider, idToken)
	if err != nil {
		return nil, fmt.Errorf("failed to validate required claims: %w", err)
	}
	validatedCelExpression, err := s.verifyCelExpression(ctx, params.Config.CelExpression, validatedClaims)
	if err != nil {
		return nil, err
	}
	if !validatedCelExpression {
		// Defensive: verifyCelExpression already reports a false result as
		// an error, but never let a false result through.
		return nil, &CelExprError{Msg: "CEL expression did not evaluate to true", Code: "CEL:EVAL_FALSE"}
	}
	return &AuthFlowResult{
		IDToken: idToken,
		Claims:  validatedClaims,
	}, nil
}
325
326
// createSession asks the session service to create a session for a
// successful login. It POSTs flowResult, together with the organization and
// client config IDs, as JSON to http://<sessionServiceAddress>/session.
//
// On HTTP 200 it returns the cookies set by the session service and the
// response body. On any other status it logs the body and returns it along
// with an error.
// NOTE(review): http.DefaultClient has no timeout; only ctx bounds the call.
func (s *Service) createSession(ctx context.Context, flowResult *AuthFlowResult, clientConfig *ClientConfig) ([]*http.Cookie, string, error) {
	type CreateSessionPayload struct {
		AuthFlowResult
		OrganizationID string `json:"organizationId"`
		ClientConfigID string `json:"oidcClientConfigId"`
	}
	sessionPayload := CreateSessionPayload{
		AuthFlowResult: *flowResult,
		OrganizationID: clientConfig.OrganizationID,
		ClientConfigID: clientConfig.ID,
	}
	payload, err := json.Marshal(sessionPayload)
	if err != nil {
		return nil, "", fmt.Errorf("failed to marshal session payload: %w", err)
	}

	url := fmt.Sprintf("http://%s/session", s.sessionServiceAddress)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
	if err != nil {
		return nil, "", fmt.Errorf("failed to construct session request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("failed to make request to /session endpoint: %w", err)
	}
	// The body was previously never closed, leaking the connection.
	defer res.Body.Close()

	body, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, "", fmt.Errorf("failed to read response of /session endpoint: %w", err)
	}
	message := string(body)

	if res.StatusCode == http.StatusOK {
		return res.Cookies(), message, nil
	}

	log.WithField("create-session-error", message).Error("Failed to create session (via server)")
	return nil, message, fmt.Errorf("unexpected status code: %v", res.StatusCode)
}
367
368
// validateRequiredClaims decodes the claims of token and makes sure the
// claims "email" and "name" are present.
//
// The token must have at least one audience. Missing required claims are
// first looked up via fillClaims (userinfo endpoint; needs the OAuth2 result
// in ctx). That lookup is best effort: its error is only logged. Any claim
// still absent afterwards fails validation.
func (s *Service) validateRequiredClaims(ctx context.Context, provider *oidc.Provider, token *goidc.IDToken) (jwt.MapClaims, error) {
	if len(token.Audience) == 0 {
		return nil, fmt.Errorf("audience claim is missing")
	}

	var claims jwt.MapClaims
	if err := token.Claims(&claims); err != nil {
		return nil, fmt.Errorf("failed to unmarshal claims of ID token: %w", err)
	}

	required := []string{"email", "name"}

	var absent []string
	for _, name := range required {
		if _, present := claims[name]; !present {
			absent = append(absent, name)
		}
	}

	if len(absent) != 0 {
		if err := s.fillClaims(ctx, provider, claims, absent); err != nil {
			// Deliberately non-fatal: the final check below decides.
			log.WithError(err).Error("failed to fill claims")
		}
	}

	for _, name := range required {
		if _, present := claims[name]; !present {
			return nil, fmt.Errorf("%s claim is missing", name)
		}
	}
	return claims, nil
}
398
399
// verifyCelExpression evaluates celExpression with the ID token claims bound
// to the CEL variable `claims` (declared as map<string, dyn>).
//
// An empty expression accepts unconditionally. Otherwise the result is true
// only if the expression evaluates to the boolean true. Every failure,
// including a false result, is returned as a *CelExprError with code
// CEL:INVALIDATE, CEL:EVAL_ERR, CEL:EVAL_NOT_BOOL or CEL:EVAL_FALSE.
//
// NOTE(review): newer cel-go releases deprecate cel.Declarations/decls in
// favour of cel.Variable — confirm against the vendored version.
func (s *Service) verifyCelExpression(ctx context.Context, celExpression string, claims jwt.MapClaims) (bool, error) {
	if celExpression == "" {
		return true, nil
	}

	fail := func(code, msg string) (bool, error) {
		return false, &CelExprError{Msg: msg, Code: code}
	}

	env, err := cel.NewEnv(cel.Declarations(decls.NewVar("claims", decls.NewMapType(decls.String, decls.Dyn))))
	if err != nil {
		return fail("CEL:INVALIDATE", fmt.Sprintf("failed to create claims env: %v", err))
	}

	ast, issues := env.Compile(celExpression)
	if issues != nil {
		if compileErr := issues.Err(); compileErr != nil {
			return fail("CEL:INVALIDATE", fmt.Sprintf("failed to compile CEL Expression: %v", compileErr))
		}
		// should not happen
		log.WithField("issues", issues).Error("failed to compile CEL Expression")
		return fail("CEL:INVALIDATE", "failed to compile CEL Expression")
	}

	prg, err := env.Program(ast)
	if err != nil {
		log.WithError(err).Error("failed to create CEL program")
		return fail("CEL:INVALIDATE", "failed to create CEL program")
	}

	val, _, err := prg.ContextEval(ctx, map[string]interface{}{"claims": claims})
	if err != nil {
		return fail("CEL:EVAL_ERR", fmt.Sprintf("failed to evaluate CEL program: %v", err))
	}

	result, isBool := val.Value().(bool)
	switch {
	case !isBool:
		return fail("CEL:EVAL_NOT_BOOL", "CEL Expression did not evaluate to a boolean")
	case !result:
		return fail("CEL:EVAL_FALSE", "CEL Expression did not evaluate to true")
	}
	return true, nil
}
437
438
// fillClaims tries to supply the claims named in missingClaims from the
// provider's userinfo endpoint and writes them into claims in place.
//
// The token used for the userinfo request is taken from the OAuth2 result
// stored in ctx (GetOAuth2ResultFromContext). Values the userinfo response
// does not provide (absent, JSON null, or an empty email) are left unset, so
// the caller's required-claims check still reports them as missing. Before,
// an empty email was written and satisfied that check.
func (s *Service) fillClaims(ctx context.Context, provider *oidc.Provider, claims jwt.MapClaims, missingClaims []string) error {
	oauth2Info := GetOAuth2ResultFromContext(ctx)
	if oauth2Info == nil {
		return fmt.Errorf("oauth2 info not found")
	}
	userinfo, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Info.OAuth2Token))
	if err != nil {
		return fmt.Errorf("failed to get userinfo: %w", err)
	}
	var userinfoClaims map[string]interface{}
	if err := userinfo.Claims(&userinfoClaims); err != nil {
		return fmt.Errorf("failed to unmarshal userinfo claims: %w", err)
	}
	for _, key := range missingClaims {
		switch key {
		case "email":
			// The typed UserInfo.Email field is "" when no email was returned.
			if userinfo.Email != "" {
				claims["email"] = userinfo.Email
			}
		default:
			if value, ok := userinfoClaims[key]; ok && value != nil {
				claims[key] = value
			}
		}
	}
	return nil
}
464
465