Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
gitpod-io
GitHub Repository: gitpod-io/gitpod
Path: blob/main/components/public-api-server/pkg/oidc/service_test.go
2500 views
1
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.
2
// Licensed under the GNU Affero General Public License (AGPL).
3
// See License.AGPL.txt in the project root for license information.
4
5
package oidc
6
7
import (
8
"context"
9
"encoding/json"
10
"fmt"
11
"io"
12
"log"
13
"net/http"
14
"net/http/httptest"
15
"net/url"
16
"testing"
17
"time"
18
19
"github.com/coreos/go-oidc/v3/oidc"
20
goidc "github.com/coreos/go-oidc/v3/oidc"
21
db "github.com/gitpod-io/gitpod/components/gitpod-db/go"
22
"github.com/gitpod-io/gitpod/components/gitpod-db/go/dbtest"
23
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws"
24
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws/jwstest"
25
"github.com/go-chi/chi/v5"
26
"github.com/go-chi/chi/v5/middleware"
27
"github.com/golang-jwt/jwt/v5"
28
"github.com/google/go-cmp/cmp"
29
"github.com/google/uuid"
30
"github.com/stretchr/testify/require"
31
"golang.org/x/oauth2"
32
"gopkg.in/square/go-jose.v2"
33
"gorm.io/gorm"
34
)
35
36
func TestGetStartParams(t *testing.T) {
37
const (
38
issuerG = "https://accounts.google.com"
39
clientID = "client-id-123"
40
redirectURL = "https://test.local/iam/oidc/callback"
41
)
42
service, _ := setupOIDCServiceForTests(t)
43
config := &ClientConfig{
44
Issuer: issuerG,
45
VerifierConfig: &oidc.Config{},
46
OAuth2Config: &oauth2.Config{
47
ClientID: clientID,
48
Endpoint: oauth2.Endpoint{
49
AuthURL: issuerG + "/o/oauth2/v2/auth",
50
},
51
},
52
}
53
54
params, err := service.getStartParams(config, redirectURL, StateParams{
55
ClientConfigID: config.ID,
56
ReturnToURL: "/",
57
Activate: false,
58
})
59
60
require.NoError(t, err)
61
require.NotNil(t, params.Nonce)
62
require.NotNil(t, params.State)
63
64
// AuthCodeURL example:
65
// https://accounts.google.com/o/oauth2/v2/auth
66
// ?client_id=client-id-123
67
// &nonce=UFTMxxUtc5jVZbp2a2R9XEoRwpfzs-04FcmVQ-HdCsw
68
// &response_type=code
69
// &redirect_url=https...
70
// &state=Q4XzRcdo4jtOYeRbF17T9LHHwX-4HacT1_5pZH8mXLI
71
require.NotNil(t, params.AuthCodeURL)
72
require.Contains(t, params.AuthCodeURL, issuerG)
73
require.Contains(t, params.AuthCodeURL, clientID)
74
require.Contains(t, params.AuthCodeURL, url.QueryEscape(redirectURL))
75
require.Contains(t, params.AuthCodeURL, url.QueryEscape(params.Nonce))
76
require.Contains(t, params.AuthCodeURL, url.QueryEscape(params.State))
77
}
78
79
func TestGetClientConfigFromStartRequest(t *testing.T) {
80
issuer := newFakeIdP(t)
81
service, dbConn := setupOIDCServiceForTests(t)
82
config, team := createConfig(t, dbConn, &ClientConfig{
83
Issuer: issuer,
84
Active: true,
85
VerifierConfig: &oidc.Config{},
86
OAuth2Config: &oauth2.Config{},
87
})
88
// create second org to emulate an installation with multiple orgs
89
createConfig(t, dbConn, &ClientConfig{
90
Issuer: issuer,
91
Active: true,
92
VerifierConfig: &oidc.Config{},
93
OAuth2Config: &oauth2.Config{},
94
})
95
configID := config.ID.String()
96
97
testCases := []struct {
98
Location string
99
ExpectedError bool
100
ExpectedId string
101
}{
102
{
103
Location: "/start?word=abc",
104
ExpectedError: true,
105
ExpectedId: "",
106
},
107
{
108
Location: "/start?id=UNKNOWN",
109
ExpectedError: true,
110
ExpectedId: "",
111
},
112
{
113
Location: "/start?id=" + configID,
114
ExpectedError: false,
115
ExpectedId: configID,
116
},
117
{
118
Location: "/start?orgSlug=" + team.Slug,
119
ExpectedError: false,
120
ExpectedId: configID,
121
},
122
}
123
124
for _, tc := range testCases {
125
t.Run(tc.Location, func(te *testing.T) {
126
request := httptest.NewRequest(http.MethodGet, tc.Location, nil)
127
config, err := service.getClientConfigFromStartRequest(request)
128
if tc.ExpectedError == true {
129
require.Error(te, err)
130
}
131
if tc.ExpectedError != true {
132
require.NoError(te, err)
133
require.NotNil(te, config)
134
require.Equal(te, tc.ExpectedId, config.ID)
135
}
136
})
137
}
138
139
t.Cleanup(func() {
140
require.NoError(t, dbConn.Where("slug = ?", team.Slug).Delete(&db.Organization{}).Error)
141
})
142
}
143
144
func TestGetClientConfigFromStartRequestSingleOrg(t *testing.T) {
145
issuer := newFakeIdP(t)
146
service, dbConn := setupOIDCServiceForTests(t)
147
// make sure no other organizations are in the db anymore
148
dbConn.Delete(&db.Organization{}, "1=1")
149
config, team := createConfig(t, dbConn, &ClientConfig{
150
Issuer: issuer,
151
Active: true,
152
VerifierConfig: &oidc.Config{},
153
OAuth2Config: &oauth2.Config{},
154
})
155
configID := config.ID.String()
156
157
testCases := []struct {
158
Location string
159
ExpectedError bool
160
ExpectedId string
161
}{
162
{
163
Location: "/start",
164
ExpectedError: false,
165
ExpectedId: configID,
166
},
167
{
168
Location: "/start?word=abc",
169
ExpectedError: false,
170
ExpectedId: configID,
171
},
172
{
173
Location: "/start?id=UNKNOWN",
174
ExpectedError: true,
175
ExpectedId: "",
176
},
177
{
178
Location: "/start?id=" + configID,
179
ExpectedError: false,
180
ExpectedId: configID,
181
},
182
{
183
Location: "/start?orgSlug=" + team.Slug,
184
ExpectedError: false,
185
ExpectedId: configID,
186
},
187
}
188
189
for _, tc := range testCases {
190
t.Run(tc.Location, func(te *testing.T) {
191
request := httptest.NewRequest(http.MethodGet, tc.Location, nil)
192
config, err := service.getClientConfigFromStartRequest(request)
193
if tc.ExpectedError == true {
194
require.Error(te, err)
195
}
196
if tc.ExpectedError != true {
197
require.NoError(te, err)
198
require.NotNil(te, config)
199
require.Equal(te, tc.ExpectedId, config.ID, "wrong config")
200
}
201
})
202
}
203
204
t.Cleanup(func() {
205
require.NoError(t, dbConn.Where("slug = ?", team.Slug).Delete(&db.Organization{}).Error)
206
})
207
}
208
209
func TestGetClientConfigFromCallbackRequest(t *testing.T) {
210
issuer := newFakeIdP(t)
211
service, dbConn := setupOIDCServiceForTests(t)
212
config, _ := createConfig(t, dbConn, &ClientConfig{
213
Issuer: issuer,
214
VerifierConfig: &oidc.Config{},
215
OAuth2Config: &oauth2.Config{},
216
})
217
configID := config.ID.String()
218
219
state, err := service.encodeStateParam(StateParams{
220
ClientConfigID: configID,
221
ReturnToURL: "",
222
})
223
require.NoError(t, err, "failed encode state param")
224
225
state_unknown, err := service.encodeStateParam(StateParams{
226
ClientConfigID: "UNKNOWN",
227
ReturnToURL: "",
228
})
229
require.NoError(t, err, "failed encode state param")
230
231
testCases := []struct {
232
Location string
233
ExpectedError bool
234
ExpectedId string
235
}{
236
{
237
Location: "/callback?state=BAD",
238
ExpectedError: true,
239
ExpectedId: "",
240
},
241
{
242
Location: "/callback?state=" + state_unknown,
243
ExpectedError: true,
244
ExpectedId: "",
245
},
246
{
247
Location: "/callback?state=" + state,
248
ExpectedError: false,
249
ExpectedId: configID,
250
},
251
}
252
253
for _, tc := range testCases {
254
t.Run(tc.Location, func(t *testing.T) {
255
request := httptest.NewRequest(http.MethodGet, tc.Location, nil)
256
config, _, err := service.getClientConfigFromCallbackRequest(request)
257
if tc.ExpectedError == true {
258
require.Error(t, err)
259
}
260
if tc.ExpectedError != true {
261
require.NoError(t, err)
262
require.NotNil(t, config)
263
require.Equal(t, tc.ExpectedId, config.ID)
264
}
265
})
266
}
267
}
268
269
func TestCreateSession(t *testing.T) {
270
service, _ := setupOIDCServiceForTests(t)
271
272
config := ClientConfig{
273
ID: "foo1",
274
OrganizationID: "org1",
275
}
276
277
_, message, err := service.createSession(context.Background(), &AuthFlowResult{}, &config)
278
require.NoError(t, err, "failed to create session")
279
280
got := map[string]interface{}{}
281
err = json.Unmarshal([]byte(message), &got)
282
require.NoError(t, err, "failed to parse response")
283
284
expected := map[string]interface{}{
285
"claims": nil,
286
"idToken": nil,
287
"oidcClientConfigId": config.ID,
288
"organizationId": config.OrganizationID,
289
}
290
291
if diff := cmp.Diff(expected, got); diff != "" {
292
t.Errorf("Unexpected create session payload (-want +got):\n%s", diff)
293
}
294
}
295
296
func Test_validateRequiredClaims(t *testing.T) {
297
service, _ := setupOIDCServiceForTests(t)
298
299
type data struct {
300
jwt.RegisteredClaims
301
Email string `json:"email,omitempty"`
302
Name string `json:"name,omitempty"`
303
}
304
305
testCases := []struct {
306
Label string
307
ExpectedError string
308
Claims data
309
}{
310
{
311
Label: "Required claims present",
312
ExpectedError: "",
313
Claims: data{
314
RegisteredClaims: jwt.RegisteredClaims{
315
Audience: []string{"audience"},
316
},
317
Email: "me@localhost",
318
Name: "Admin",
319
},
320
},
321
{
322
Label: "Email claim is missing",
323
ExpectedError: "email claim is missing",
324
Claims: data{
325
RegisteredClaims: jwt.RegisteredClaims{
326
Audience: []string{"audience"},
327
},
328
Name: "Admin",
329
},
330
},
331
{
332
Label: "Name claim is missing",
333
ExpectedError: "name claim is missing",
334
Claims: data{
335
RegisteredClaims: jwt.RegisteredClaims{
336
Audience: []string{"audience"},
337
},
338
Email: "admin@localhost",
339
},
340
},
341
}
342
343
for _, tc := range testCases {
344
t.Run(tc.Label, func(t *testing.T) {
345
token := createTestIDToken(t, tc.Claims)
346
347
_, err := service.validateRequiredClaims(context.Background(), nil, token)
348
if tc.ExpectedError == "" {
349
require.NoError(t, err)
350
}
351
if tc.ExpectedError != "" {
352
require.Equal(t, err.Error(), tc.ExpectedError)
353
}
354
})
355
}
356
}
357
358
func Test_verifyCelExpression(t *testing.T) {
359
service, _ := setupOIDCServiceForTests(t)
360
361
testCases := []struct {
362
Label string
363
ExpectedError bool
364
ExpectedErrorMsg string
365
ExpectedErrorCode string
366
ExpectedResult bool
367
Claims jwt.MapClaims
368
CEL string
369
}{
370
{
371
Label: "email verify",
372
ExpectedError: true,
373
ExpectedErrorMsg: "CEL Expression did not evaluate to true [CEL:EVAL_FALSE]",
374
ExpectedErrorCode: "CEL:EVAL_FALSE",
375
ExpectedResult: false,
376
Claims: jwt.MapClaims{
377
"Audience": []string{"audience"},
378
"groups_direct": []string{
379
"gitpod-team",
380
"gitpod-team2/sub_group",
381
},
382
"email": "[email protected]",
383
"email_verified": false,
384
},
385
CEL: "claims.email_verified && claims.email_verified.email.endsWith('@gitpod.io')",
386
},
387
{
388
Label: "GitLab: groups restriction",
389
ExpectedError: false,
390
ExpectedResult: true,
391
Claims: jwt.MapClaims{
392
"Audience": []string{"audience"},
393
"groups_direct": []string{
394
"gitpod-team",
395
"gitpod-team2/sub_group",
396
},
397
"email": "[email protected]",
398
"email_verified": false,
399
},
400
CEL: "(claims.email_verified && claims.email_verified.email.endsWith('@gitpod.io')) || 'gitpod-team' in claims.groups_direct",
401
},
402
{
403
Label: "GitLab: groups restriction (not allowed)",
404
ExpectedError: true,
405
ExpectedErrorMsg: "CEL Expression did not evaluate to true [CEL:EVAL_FALSE]",
406
ExpectedErrorCode: "CEL:EVAL_FALSE",
407
ExpectedResult: false,
408
Claims: jwt.MapClaims{
409
"Audience": []string{"audience"},
410
"groups_direct": []string{
411
"gitpod-team2/sub_group",
412
},
413
"email": "[email protected]",
414
"email_verified": false,
415
},
416
CEL: "(claims.email_verified && claims.email_verified.email.endsWith('@gitpod.io')) || 'gitpod-team2' in claims.groups_direct",
417
},
418
{
419
Label: "invalidate cel",
420
ExpectedError: true,
421
ExpectedErrorCode: "CEL:INVALIDATE",
422
ExpectedResult: false,
423
Claims: jwt.MapClaims{
424
"Audience": []string{"audience"},
425
"groups_direct": []string{
426
"gitpod-team",
427
"gitpod-team2/sub_group",
428
},
429
"email": "[email protected]",
430
"email_verified": false,
431
},
432
CEL: "foo",
433
},
434
}
435
436
for _, tc := range testCases {
437
t.Run(tc.Label, func(t *testing.T) {
438
result, err := service.verifyCelExpression(context.Background(), tc.CEL, tc.Claims)
439
if tc.ExpectedErrorCode != "" {
440
if celExprErr, ok := err.(*CelExprError); ok {
441
require.Equal(t, celExprErr.Code, tc.ExpectedErrorCode, "Unexpected CEL error code")
442
}
443
}
444
if !tc.ExpectedError {
445
require.NoError(t, err)
446
} else {
447
require.True(t, err != nil, "Should return error")
448
if tc.ExpectedErrorMsg != "" {
449
require.Equal(t, err.Error(), tc.ExpectedErrorMsg)
450
}
451
}
452
require.Equal(t, result, tc.ExpectedResult, "Unexpected result")
453
})
454
}
455
}
456
457
func createTestIDToken(t *testing.T, claims jwt.Claims) *goidc.IDToken {
458
t.Helper()
459
460
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
461
rawIDToken, err := token.SignedString([]byte("no-relevant-for-this-test"))
462
require.NoError(t, err)
463
464
verifier := goidc.NewVerifier("http://localhost", nil, &goidc.Config{
465
SkipIssuerCheck: true,
466
SkipClientIDCheck: true,
467
SkipExpiryCheck: true,
468
InsecureSkipSignatureCheck: true,
469
})
470
471
verifiedToken, err := verifier.Verify(context.Background(), rawIDToken)
472
require.NoError(t, err)
473
474
return verifiedToken
475
}
476
477
// setupOIDCServiceForTests builds a Service for tests from these parts:
//   - a test database connection and cipher set from dbtest;
//   - the address of a fake session server;
//   - an HS256 signer/verifier derived from a freshly generated key set.
//
// ID token verification is disabled via skipVerifyIdToken. The DB connection
// is returned as well so tests can seed fixtures.
func setupOIDCServiceForTests(t *testing.T) (*Service, *gorm.DB) {
	t.Helper()

	conn := dbtest.ConnectForTests(t)
	cipherSet := dbtest.CipherSet(t)
	sessionAddr := newFakeSessionServer(t)
	signer := jws.NewHS256FromKeySet(jwstest.GenerateKeySet(t))

	svc := NewService(sessionAddr, conn, cipherSet, signer, 5*time.Minute)
	svc.skipVerifyIdToken = true

	return svc, conn
}
492
493
func createConfig(t *testing.T, dbConn *gorm.DB, config *ClientConfig) (db.OIDCClientConfig, db.Organization) {
494
t.Helper()
495
496
team := dbtest.CreateOrganizations(t, dbConn, db.Organization{})[0]
497
498
data, err := db.EncryptJSON(dbtest.CipherSet(t), db.OIDCSpec{
499
ClientID: config.OAuth2Config.ClientID,
500
ClientSecret: config.OAuth2Config.ClientSecret,
501
})
502
require.NoError(t, err)
503
504
created := dbtest.CreateOIDCClientConfigs(t, dbConn, db.OIDCClientConfig{
505
ID: uuid.New(),
506
OrganizationID: team.ID,
507
Issuer: config.Issuer,
508
Active: false,
509
Data: data,
510
}, db.OIDCClientConfig{
511
ID: uuid.New(),
512
OrganizationID: team.ID,
513
Issuer: config.Issuer,
514
Active: config.Active,
515
Data: data,
516
})[1]
517
518
return created, team
519
}
520
521
func newFakeSessionServer(t *testing.T) string {
522
router := chi.NewRouter()
523
ts := httptest.NewServer(router)
524
url, err := url.Parse(ts.URL)
525
if err != nil {
526
log.Fatal(err)
527
}
528
529
router.Use(middleware.Logger)
530
router.Post("/session", func(w http.ResponseWriter, r *http.Request) {
531
http.SetCookie(w, &http.Cookie{
532
Name: "test-cookie",
533
Value: "chocolate-chips",
534
Path: "/",
535
HttpOnly: true,
536
Expires: time.Now().AddDate(0, 0, 1),
537
})
538
w.WriteHeader(http.StatusOK)
539
540
// mirroring back the request body for testing
541
body, err := io.ReadAll(r.Body)
542
if err != nil {
543
body = []byte(err.Error())
544
}
545
_, err = w.Write(body)
546
if err != nil {
547
log.Fatal(err)
548
}
549
})
550
551
t.Cleanup(ts.Close)
552
return url.Host
553
}
554
555
func newFakeIdP(t *testing.T) string {
556
router := chi.NewRouter()
557
ts := httptest.NewServer(router)
558
url := ts.URL
559
560
keyset := jwstest.GenerateKeySet(t)
561
rsa256, err := jws.NewRSA256(keyset)
562
require.NoError(t, err)
563
564
type IDTokenClaims struct {
565
Nonce string `json:"nonce"`
566
Email string `json:"email"`
567
Name string `json:"name"`
568
jwt.RegisteredClaims
569
}
570
token := jwt.NewWithClaims(jwt.SigningMethodRS256, &IDTokenClaims{
571
Nonce: "111",
572
RegisteredClaims: jwt.RegisteredClaims{
573
Subject: "user-id",
574
Audience: jwt.ClaimStrings{"client-id"},
575
Issuer: url,
576
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
577
IssuedAt: jwt.NewNumericDate(time.Now()),
578
},
579
Email: "[email protected]",
580
Name: "User",
581
})
582
583
idTokenValue, err := rsa256.Sign(token)
584
require.NoError(t, err)
585
586
var jwks jose.JSONWebKeySet
587
jwks.Keys = append(jwks.Keys, jose.JSONWebKey{
588
Key: &keyset.Signing.Private.PublicKey,
589
KeyID: "0001",
590
Algorithm: string(jose.RS256),
591
})
592
keysValue, err := json.Marshal(jwks)
593
require.NoError(t, err)
594
595
router.Use(middleware.Logger)
596
router.Get("/oauth2/v3/certs", func(w http.ResponseWriter, r *http.Request) {
597
_, err := w.Write(keysValue)
598
if err != nil {
599
log.Fatal(err)
600
}
601
})
602
router.Get("/o/oauth2/v2/auth", func(w http.ResponseWriter, r *http.Request) {
603
_, err := w.Write([]byte(r.URL.RawQuery))
604
if err != nil {
605
log.Fatal(err)
606
}
607
})
608
router.Post("/token", func(w http.ResponseWriter, r *http.Request) {
609
w.Header().Add("Content-Type", "application/json")
610
_, err := w.Write([]byte(fmt.Sprintf(`{
611
"access_token": "no-token-set",
612
"id_token": "%[1]s"
613
}`, idTokenValue)))
614
if err != nil {
615
log.Fatal(err)
616
}
617
})
618
router.Get("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
619
w.Header().Add("Content-Type", "application/json")
620
_, err := w.Write([]byte(fmt.Sprintf(`{
621
"issuer": "%[1]s",
622
"authorization_endpoint": "%[1]s/o/oauth2/v2/auth",
623
"device_authorization_endpoint": "%[1]s/device/code",
624
"token_endpoint": "%[1]s/token",
625
"userinfo_endpoint": "%[1]s/v1/userinfo",
626
"revocation_endpoint": "%[1]s/revoke",
627
"jwks_uri": "%[1]s/oauth2/v3/certs",
628
"response_types_supported": [
629
"code",
630
"token",
631
"id_token",
632
"code token",
633
"code id_token",
634
"token id_token",
635
"code token id_token",
636
"none"
637
],
638
"subject_types_supported": [
639
"public"
640
],
641
"id_token_signing_alg_values_supported": [
642
"RS256"
643
],
644
"scopes_supported": [
645
"openid",
646
"email",
647
"profile"
648
],
649
"token_endpoint_auth_methods_supported": [
650
"client_secret_post",
651
"client_secret_basic"
652
],
653
"claims_supported": [
654
"aud",
655
"email",
656
"email_verified",
657
"exp",
658
"family_name",
659
"given_name",
660
"iat",
661
"iss",
662
"locale",
663
"name",
664
"picture",
665
"sub"
666
],
667
"code_challenge_methods_supported": [
668
"plain",
669
"S256"
670
],
671
"grant_types_supported": [
672
"authorization_code",
673
"refresh_token",
674
"urn:ietf:params:oauth:grant-type:device_code",
675
"urn:ietf:params:oauth:grant-type:jwt-bearer"
676
]
677
}`, url)))
678
if err != nil {
679
t.Error((err))
680
t.FailNow()
681
}
682
})
683
684
t.Cleanup(ts.Close)
685
return url
686
}
687
688