Path: blob/main/components/public-api-server/pkg/oidc/service.go
2500 views
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.1// Licensed under the GNU Affero General Public License (AGPL).2// See License.AGPL.txt in the project root for license information.34package oidc56import (7"bytes"8"context"9"crypto/rand"10"encoding/base64"11"encoding/json"12"fmt"13"io"14"net/http"15"time"1617"github.com/coreos/go-oidc/v3/oidc"18goidc "github.com/coreos/go-oidc/v3/oidc"19"github.com/gitpod-io/gitpod/common-go/log"20db "github.com/gitpod-io/gitpod/components/gitpod-db/go"21"github.com/gitpod-io/gitpod/public-api-server/pkg/jws"22"github.com/golang-jwt/jwt/v5"23"github.com/google/cel-go/cel"24"github.com/google/cel-go/checker/decls"25"github.com/google/uuid"26"golang.org/x/oauth2"27"google.golang.org/grpc/codes"28"google.golang.org/grpc/status"29"gorm.io/gorm"30)3132type Service struct {33dbConn *gorm.DB34cipher db.Cipher3536// jwts37stateExpiry time.Duration38signerVerifier jws.SignerVerifier3940sessionServiceAddress string4142// TODO(at) remove by enhancing test setups43skipVerifyIdToken bool44}4546type ClientConfig struct {47ID string48OrganizationID string49Issuer string50Active bool51OAuth2Config *oauth2.Config52VerifierConfig *goidc.Config53CelExpression string54UsePKCE bool55}5657type StartParams struct {58State string59Nonce string60CodeVerifier string61AuthCodeURL string62}6364type AuthFlowResult struct {65IDToken *goidc.IDToken `json:"idToken"`66Claims map[string]interface{} `json:"claims"`67}6869func NewService(sessionServiceAddress string, dbConn *gorm.DB, cipher db.Cipher, signerVerifier jws.SignerVerifier, stateExpiry time.Duration) *Service {70return &Service{71sessionServiceAddress: sessionServiceAddress,7273dbConn: dbConn,74cipher: cipher,7576signerVerifier: signerVerifier,77stateExpiry: stateExpiry,78}79}8081func (s *Service) getStartParams(config *ClientConfig, redirectURL string, stateParams StateParams) (*StartParams, error) {82// the `state` is supposed to be passed through unmodified by the IdP.83state, err := s.encodeStateParam(stateParams)84if err != nil {85return nil, fmt.Errorf("failed to encode state")86}8788// number used once89nonce, err := randString(32)90if err != nil {91return nil, fmt.Errorf("failed to create nonce")92}9394// Configuring `AuthCodeOption`s, e.g. nonce95config.OAuth2Config.RedirectURL = redirectURL9697opts := []oauth2.AuthCodeOption{goidc.Nonce(nonce)}98var verifier string99if config.UsePKCE {100verifier = oauth2.GenerateVerifier()101opts = append(opts, oauth2.S256ChallengeOption(verifier))102}103104authCodeURL := config.OAuth2Config.AuthCodeURL(state, opts...)105106return &StartParams{107AuthCodeURL: authCodeURL,108State: state,109Nonce: nonce,110CodeVerifier: verifier,111}, nil112}113114func (s *Service) encodeStateParam(state StateParams) (string, error) {115now := time.Now().UTC()116expiry := now.Add(s.stateExpiry)117token := NewStateJWT(state, now, expiry)118119signed, err := s.signerVerifier.Sign(token)120if err != nil {121return "", fmt.Errorf("failed to sign jwt: %w", err)122}123return signed, nil124}125126func (s *Service) decodeStateParam(encodedToken string) (StateParams, error) {127claims := &StateClaims{}128_, err := s.signerVerifier.Verify(encodedToken, claims)129if err != nil {130return StateParams{}, fmt.Errorf("failed to verify state token: %w", err)131}132133return claims.StateParams, nil134}135136func randString(size int) (string, error) {137b := make([]byte, size)138if _, err := io.ReadFull(rand.Reader, b); err != nil {139return "", err140}141return base64.RawURLEncoding.EncodeToString(b), nil142}143144func (s *Service) getClientConfigFromStartRequest(r *http.Request) (*ClientConfig, error) {145orgSlug := r.URL.Query().Get("orgSlug")146idParam := r.URL.Query().Get("id")147148// if no org slug is given, we assume the request is for the default org149if orgSlug == "" && idParam == "" {150org, err := db.GetSingleOrganizationWithActiveSSO(r.Context(), s.dbConn)151if err != nil {152return nil, fmt.Errorf("Failed to find team: %w", err)153}154orgSlug = org.Slug155}156if orgSlug != "" {157dbEntry, err := db.GetActiveOIDCClientConfigByOrgSlug(r.Context(), s.dbConn, orgSlug)158if err != nil {159return nil, fmt.Errorf("Failed to find OIDC clients: %w", err)160}161162config, err := s.convertClientConfig(r.Context(), dbEntry)163if err != nil {164return nil, fmt.Errorf("Failed to find OIDC clients: %w", err)165}166167return &config, nil168}169170if idParam == "" {171return nil, fmt.Errorf("missing id parameter")172}173174if idParam != "" {175config, err := s.getConfigById(r.Context(), idParam)176if err != nil {177return nil, err178}179return config, nil180}181182return nil, fmt.Errorf("failed to find OIDC config")183}184185func (s *Service) getClientConfigFromCallbackRequest(r *http.Request) (*ClientConfig, *StateParams, error) {186stateParam := r.URL.Query().Get("state")187if stateParam == "" {188return nil, nil, fmt.Errorf("missing state parameter")189}190191state, err := s.decodeStateParam(stateParam)192if err != nil {193return nil, nil, fmt.Errorf("bad state param")194}195config, _ := s.getConfigById(r.Context(), state.ClientConfigID)196if config != nil {197return config, &state, nil198}199200return nil, nil, fmt.Errorf("failed to find OIDC config on callback")201}202203func (s *Service) activateAndVerifyClientConfig(ctx context.Context, config *ClientConfig) error {204uuid, err := uuid.Parse(config.ID)205if err != nil {206return err207}208err = db.VerifyClientConfig(ctx, s.dbConn, uuid)209if err != nil {210return err211}212return db.SetClientConfigActiviation(ctx, s.dbConn, uuid, true)213}214215func (s *Service) markClientConfigAsVerified(ctx context.Context, config *ClientConfig) error {216uuid, err := uuid.Parse(config.ID)217if err != nil {218return err219}220return db.VerifyClientConfig(ctx, s.dbConn, uuid)221}222223func (s *Service) getConfigById(ctx context.Context, id string) (*ClientConfig, error) {224uuid, err := uuid.Parse(id)225if err != nil {226return nil, err227}228dbEntry, err := db.GetOIDCClientConfig(ctx, s.dbConn, uuid)229if err != nil {230return nil, err231}232config, err := s.convertClientConfig(ctx, dbEntry)233if err != nil {234log.Log.WithError(err).Error("Failed to decrypt oidc client config.")235return nil, status.Errorf(codes.Internal, "Failed to decrypt OIDC client config.")236}237238return &config, nil239}240241func (s *Service) convertClientConfig(ctx context.Context, dbEntry db.OIDCClientConfig) (ClientConfig, error) {242spec, err := dbEntry.Data.Decrypt(s.cipher)243if err != nil {244log.Log.WithError(err).Error("Failed to decrypt oidc client config.")245return ClientConfig{}, status.Errorf(codes.Internal, "Failed to decrypt OIDC client config.")246}247248provider, err := oidc.NewProvider(ctx, dbEntry.Issuer)249if err != nil {250return ClientConfig{}, err251}252253return ClientConfig{254ID: dbEntry.ID.String(),255OrganizationID: dbEntry.OrganizationID.String(),256Issuer: dbEntry.Issuer,257Active: dbEntry.Active,258OAuth2Config: &oauth2.Config{259ClientID: spec.ClientID,260ClientSecret: spec.ClientSecret,261Endpoint: provider.Endpoint(),262Scopes: spec.Scopes,263},264CelExpression: spec.CelExpression,265UsePKCE: spec.UsePKCE,266VerifierConfig: &goidc.Config{267ClientID: spec.ClientID,268},269}, nil270}271272type authenticateParams struct {273Config *ClientConfig274OAuth2Result *OAuth2Result275NonceCookieValue string276}277278type CelExprError struct {279Msg string280Code string281}282283func (e *CelExprError) Error() string {284return fmt.Sprintf("%s [%s]", e.Msg, e.Code)285}286287func (s *Service) authenticate(ctx context.Context, params authenticateParams) (*AuthFlowResult, error) {288rawIDToken, ok := params.OAuth2Result.OAuth2Token.Extra("id_token").(string)289if !ok {290return nil, fmt.Errorf("id_token not found")291}292293provider, err := oidc.NewProvider(ctx, params.Config.Issuer)294if err != nil {295return nil, fmt.Errorf("Failed to initialize provider.")296}297verifier := provider.Verifier(&goidc.Config{298ClientID: params.Config.OAuth2Config.ClientID,299})300301idToken, err := verifier.Verify(ctx, rawIDToken)302if err != nil {303return nil, fmt.Errorf("failed to verify id_token: %w", err)304}305if idToken.Nonce != params.NonceCookieValue {306return nil, fmt.Errorf("nonce mismatch")307}308validatedClaims, err := s.validateRequiredClaims(ctx, provider, idToken)309if err != nil {310return nil, fmt.Errorf("failed to validate required claims: %w", err)311}312validatedCelExpression, err := s.verifyCelExpression(ctx, params.Config.CelExpression, validatedClaims)313if err != nil {314return nil, err315}316if !validatedCelExpression {317return nil, &CelExprError{Msg: "CEL expression did not evaluate to true", Code: "CEL:EVAL_FALSE"}318}319return &AuthFlowResult{320IDToken: idToken,321Claims: validatedClaims,322}, nil323}324325func (s *Service) createSession(ctx context.Context, flowResult *AuthFlowResult, clientConfig *ClientConfig) ([]*http.Cookie, string, error) {326type CreateSessionPayload struct {327AuthFlowResult328OrganizationID string `json:"organizationId"`329ClientConfigID string `json:"oidcClientConfigId"`330}331sessionPayload := CreateSessionPayload{332AuthFlowResult: *flowResult,333OrganizationID: clientConfig.OrganizationID,334ClientConfigID: clientConfig.ID,335}336payload, err := json.Marshal(sessionPayload)337if err != nil {338return nil, "", err339}340341url := fmt.Sprintf("http://%s/session", s.sessionServiceAddress)342req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))343if err != nil {344return nil, "", fmt.Errorf("failed to construct session request: %w", err)345}346req.Header.Set("Content-Type", "application/json")347348res, err := http.DefaultClient.Do(req)349if err != nil {350return nil, "", fmt.Errorf("failed to make request to /session endpoint: %w", err)351}352353body, err := io.ReadAll(res.Body)354if err != nil {355return nil, "", err356}357message := string(body)358359if res.StatusCode == http.StatusOK {360return res.Cookies(), message, nil361}362363log.WithField("create-session-error", message).Error("Failed to create session (via server)")364return nil, message, fmt.Errorf("unexpected status code: %v", res.StatusCode)365}366367func (s *Service) validateRequiredClaims(ctx context.Context, provider *oidc.Provider, token *goidc.IDToken) (jwt.MapClaims, error) {368if len(token.Audience) < 1 {369return nil, fmt.Errorf("audience claim is missing")370}371var claims jwt.MapClaims372err := token.Claims(&claims)373if err != nil {374return nil, fmt.Errorf("failed to unmarshal claims of ID token: %w", err)375}376requiredClaims := []string{"email", "name"}377missingClaims := []string{}378for _, claim := range requiredClaims {379if _, ok := claims[claim]; !ok {380missingClaims = append(missingClaims, claim)381}382}383if len(missingClaims) > 0 {384err = s.fillClaims(ctx, provider, claims, missingClaims)385if err != nil {386log.WithError(err).Error("failed to fill claims")387}388// continue389}390for _, claim := range requiredClaims {391if _, ok := claims[claim]; !ok {392return nil, fmt.Errorf("%s claim is missing", claim)393}394}395return claims, nil396}397398func (s *Service) verifyCelExpression(ctx context.Context, celExpression string, claims jwt.MapClaims) (bool, error) {399if celExpression == "" {400return true, nil401}402env, err := cel.NewEnv(cel.Declarations(decls.NewVar("claims", decls.NewMapType(decls.String, decls.Dyn))))403if err != nil {404return false, &CelExprError{Msg: fmt.Errorf("failed to create claims env: %w", err).Error(), Code: "CEL:INVALIDATE"}405}406ast, issues := env.Compile(celExpression)407if issues != nil {408if issues.Err() != nil {409return false, &CelExprError{Msg: fmt.Errorf("failed to compile CEL Expression: %w", issues.Err()).Error(), Code: "CEL:INVALIDATE"}410}411// should not happen412log.WithField("issues", issues).Error("failed to compile CEL Expression")413return false, &CelExprError{Msg: fmt.Errorf("failed to compile CEL Expression").Error(), Code: "CEL:INVALIDATE"}414}415prg, err := env.Program(ast)416if err != nil {417log.WithError(err).Error("failed to create CEL program")418return false, &CelExprError{Msg: fmt.Errorf("failed to create CEL program").Error(), Code: "CEL:INVALIDATE"}419}420input := map[string]interface{}{421"claims": claims,422}423val, _, err := prg.ContextEval(ctx, input)424if err != nil {425return false, &CelExprError{Msg: fmt.Errorf("failed to evaluate CEL program: %w", err).Error(), Code: "CEL:EVAL_ERR"}426}427result, ok := val.Value().(bool)428if !ok {429return false, &CelExprError{Msg: fmt.Errorf("CEL Expression did not evaluate to a boolean").Error(), Code: "CEL:EVAL_NOT_BOOL"}430}431if !result {432return false, &CelExprError{Msg: fmt.Errorf("CEL Expression did not evaluate to true").Error(), Code: "CEL:EVAL_FALSE"}433}434return result, nil435}436437func (s *Service) fillClaims(ctx context.Context, provider *oidc.Provider, claims jwt.MapClaims, missingClaims []string) error {438oauth2Info := GetOAuth2ResultFromContext(ctx)439if oauth2Info == nil {440return fmt.Errorf("oauth2 info not found")441}442userinfo, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Info.OAuth2Token))443if err != nil {444return fmt.Errorf("failed to get userinfo: %w", err)445}446var userinfoClaims map[string]interface{}447if err := userinfo.Claims(&userinfoClaims); err != nil {448return fmt.Errorf("failed to unmarshal userinfo claims: %w", err)449}450for _, key := range missingClaims {451switch key {452case "email":453// check userinfo definition to get more info454claims["email"] = userinfo.Email455default:456if value, ok := userinfoClaims[key]; ok {457claims[key] = value458}459}460}461return nil462}463464465