Path: blob/main/components/public-api-server/pkg/auth/middleware_test.go
2500 views
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.1// Licensed under the GNU Affero General Public License (AGPL).2// See License.AGPL.txt in the project root for license information.34package auth56import (7"context"8"fmt"9"net/http"10"net/http/httptest"11"testing"12"time"1314"github.com/bufbuild/connect-go"15"github.com/gitpod-io/gitpod/components/public-api/go/config"16"github.com/gitpod-io/gitpod/public-api-server/pkg/jws"17"github.com/gitpod-io/gitpod/public-api-server/pkg/jws/jwstest"18"github.com/google/uuid"19"github.com/stretchr/testify/require"20)2122func TestNewServerInterceptor(t *testing.T) {23requestPayload := "request"24type TokenResponse struct {25Token string `json:"token"`26}2728type Header struct {29Key string30Value string31}3233keyset := jwstest.GenerateKeySet(t)34rsa256, err := jws.NewRSA256(keyset)35require.NoError(t, err)3637sessionCfg := config.SessionConfig{38Issuer: "unittest.com",39Cookie: config.CookieConfig{40Name: "cookie_jwt",41},42}4344handler := connect.UnaryFunc(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {45token, _ := TokenFromContext(ctx)46return connect.NewResponse(&TokenResponse{Token: token.Value}), nil47})4849validJWTToken, err := rsa256.Sign(NewSessionJWT(uuid.New(), sessionCfg.Issuer, time.Now(), time.Now().Add(5*time.Minute)))50require.NoError(t, err)51expiredJWTToken, err := rsa256.Sign(NewSessionJWT(uuid.New(), sessionCfg.Issuer, time.Now(), time.Now().Add(-1*time.Minute)))52require.NoError(t, err)53invalidIssuerJWTToken, err := rsa256.Sign(NewSessionJWT(uuid.New(), "random issuer", time.Now(), time.Now().Add(-1*time.Minute)))54require.NoError(t, err)5556scenarios := []struct {57Name string5859Headers []Header6061ExpectedError error62ExpectedToken string63}{64{65Name: "no headers return Unathenticated",66Headers: nil,67ExpectedError: connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("No access token or cookie credentials available on request.")),68},69{70Name: "authorization header with bearer token returns ok",71Headers: []Header{{Key: "Authorization", Value: "Bearer foo"}},72ExpectedToken: "foo",73},74{75Name: "authorization header with bearer token returns ok",76Headers: []Header{{Key: "Authorization", Value: "Bearer foo"}},77ExpectedToken: "foo",78},79{80Name: "cookie header with invalid JWT token is rejected",81Headers: []Header{{Key: "Cookie", Value: fmt.Sprintf("%s=%s", sessionCfg.Cookie.Name, "invalid_token")}},82ExpectedError: connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("JWT session could not be verified.")),83},84{85Name: "cookie header with expired JWT token is rejected",86Headers: []Header{{87Key: "Cookie",88Value: fmt.Sprintf("%s=%s", sessionCfg.Cookie.Name, expiredJWTToken)},89},90ExpectedError: connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("JWT session could not be verified.")),91},92{93Name: "cookie header with invalid issuer is rejected",94Headers: []Header{{95Key: "Cookie",96Value: fmt.Sprintf("%s=%s", sessionCfg.Cookie.Name, invalidIssuerJWTToken)},97},98ExpectedError: connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("JWT session could not be verified.")),99},100{101Name: "cookie header with valid JWT token is accepted",102Headers: []Header{{103Key: "Cookie",104Value: fmt.Sprintf("%s=%s", sessionCfg.Cookie.Name, validJWTToken),105}},106ExpectedToken: fmt.Sprintf("%s=%s", sessionCfg.Cookie.Name, validJWTToken),107},108}109110for _, s := range scenarios {111t.Run(s.Name, func(t *testing.T) {112ctx := context.Background()113request := connect.NewRequest(&requestPayload)114115for _, header := range s.Headers {116request.Header().Add(header.Key, header.Value)117}118119interceptor := NewServerInterceptor(sessionCfg, rsa256)120resp, err := interceptor.WrapUnary(handler)(ctx, request)121122require.Equal(t, s.ExpectedError, err)123if err == nil {124require.Equal(t, &TokenResponse{125Token: s.ExpectedToken,126}, resp.Any())127}128129})130}131}132133func TestNewClientInterceptor(t *testing.T) {134expectedToken := "my_token"135136tokenOnRequest := ""137// Setup a test server where we capture the token supplied, we don't actually care for the response.138srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {139fmt.Println(r.Header)140token, err := BearerTokenFromHeaders(r.Header)141require.NoError(t, err)142143// Capture the token supplied in the request so we can test for it144tokenOnRequest = token145w.WriteHeader(http.StatusNotFound)146}))147148client := connect.NewClient[any, any](http.DefaultClient, srv.URL, connect.WithInterceptors(149NewClientInterceptor(expectedToken),150))151152_, _ = client.CallUnary(context.Background(), connect.NewRequest[any](nil))153require.Equal(t, expectedToken, tokenOnRequest)154}155156157