Path: blob/main/components/public-api-server/pkg/proxy/conn_test.go
2500 views
// Copyright (c) 2022 Gitpod GmbH. All rights reserved.1// Licensed under the GNU Affero General Public License (AGPL).2// See License.AGPL.txt in the project root for license information.34package proxy56import (7"context"8"net/url"9"testing"1011gitpod "github.com/gitpod-io/gitpod/gitpod-protocol"12"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"13"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"14"github.com/golang/mock/gomock"15lru "github.com/hashicorp/golang-lru"16"github.com/stretchr/testify/require"17)1819func TestConnectionPool(t *testing.T) {20ctrl := gomock.NewController(t)21defer ctrl.Finish()22srv := gitpod.NewMockAPIInterface(ctrl)2324cache, err := lru.New(2)25require.NoError(t, err)26pool := &ConnectionPool{27cache: cache,28connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {29return srv, nil30},31}3233fooToken := auth.NewAccessToken("foo")34barToken := auth.NewAccessToken("bar")35bazToken := auth.NewAccessToken("baz")3637_, err = pool.Get(context.Background(), fooToken)38require.NoError(t, err)39require.Equal(t, 1, pool.cache.Len())4041_, err = pool.Get(context.Background(), barToken)42require.NoError(t, err)43require.Equal(t, 2, pool.cache.Len())4445_, err = pool.Get(context.Background(), bazToken)46require.NoError(t, err)47require.Equal(t, 2, pool.cache.Len(), "must keep only last two connectons")48require.True(t, pool.cache.Contains(pool.cacheKey(barToken, "")))49require.True(t, pool.cache.Contains(pool.cacheKey(bazToken, "")))50}5152func TestConnectionPool_ByDistinctOrigins(t *testing.T) {53ctrl := gomock.NewController(t)54defer ctrl.Finish()55srv := gitpod.NewMockAPIInterface(ctrl)5657cache, err := lru.New(2)58require.NoError(t, err)59pool := &ConnectionPool{60cache: cache,61connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {62return srv, nil63},64}6566token := auth.NewAccessToken("foo")6768ctxWithOriginA := origin.ToContext(context.Background(), "originA")69ctxWithOriginB := origin.ToContext(context.Background(), "originB")7071_, err = pool.Get(ctxWithOriginA, token)72require.NoError(t, err)73require.Equal(t, 1, pool.cache.Len())7475_, err = pool.Get(ctxWithOriginB, token)76require.NoError(t, err)77require.Equal(t, 2, pool.cache.Len())78}7980func TestEndpointBasedOnToken(t *testing.T) {81u, err := url.Parse("wss://server:3000")82require.NoError(t, err)8384endpointForAccessToken, err := getEndpointBasedOnToken(auth.NewAccessToken("foo"), u)85require.NoError(t, err)86require.Equal(t, "wss://server:3000/v1", endpointForAccessToken)8788endpointForCookie, err := getEndpointBasedOnToken(auth.NewCookieToken("foo"), u)89require.NoError(t, err)90require.Equal(t, "wss://server:3000/gitpod", endpointForCookie)91}929394