// Path: blob/dev/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go
// 2072 views

package hosterrorscache12import (3"context"4"errors"5"sync"6"sync/atomic"7"testing"89"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"10"github.com/stretchr/testify/require"11)1213const (14protoType = "http"15)1617func TestCacheCheck(t *testing.T) {18cache := New(3, DefaultMaxHostsCount, nil)19err := errors.New("net/http: timeout awaiting response headers")2021t.Run("increment host error", func(t *testing.T) {22ctx := newCtxArgs(t.Name())23for i := 1; i < 3; i++ {24cache.MarkFailed(protoType, ctx, err)25got := cache.Check(protoType, ctx)26require.Falsef(t, got, "got %v in iteration %d", got, i)27}28})2930t.Run("flagged", func(t *testing.T) {31ctx := newCtxArgs(t.Name())32for i := 1; i <= 3; i++ {33cache.MarkFailed(protoType, ctx, err)34}3536got := cache.Check(protoType, ctx)37require.True(t, got)38})3940t.Run("mark failed or remove", func(t *testing.T) {41ctx := newCtxArgs(t.Name())42cache.MarkFailedOrRemove(protoType, ctx, nil) // nil error should remove the host from cache43got := cache.Check(protoType, ctx)44require.False(t, got)45})46}4748func TestTrackErrors(t *testing.T) {49cache := New(3, DefaultMaxHostsCount, []string{"custom error"})5051for i := 0; i < 100; i++ {52cache.MarkFailed(protoType, newCtxArgs("custom"), errors.New("got: nested: custom error"))53got := cache.Check(protoType, newCtxArgs("custom"))54if i < 2 {55// till 3 the host is not flagged to skip56require.False(t, got)57} else {58// above 3 it must remain flagged to skip59require.True(t, got)60}61}62value := cache.Check(protoType, newCtxArgs("custom"))63require.Equal(t, true, value, "could not get checked value")64}6566func TestCacheItemDo(t *testing.T) {67var (68count int69item cacheItem70)7172wg := sync.WaitGroup{}73for i := 0; i < 100; i++ {74wg.Add(1)75go func() {76defer wg.Done()77item.Do(func() {78count++79})80}()81}82wg.Wait()8384// ensures the increment happened only once regardless of the multiple call85require.Equal(t, count, 1)86}8788func TestRemove(t 
*testing.T) {89cache := New(3, DefaultMaxHostsCount, nil)90ctx := newCtxArgs(t.Name())91err := errors.New("net/http: timeout awaiting response headers")9293for i := 0; i < 100; i++ {94cache.MarkFailed(protoType, ctx, err)95}9697require.True(t, cache.Check(protoType, ctx))98cache.Remove(ctx)99require.False(t, cache.Check(protoType, ctx))100}101102func TestCacheMarkFailed(t *testing.T) {103cache := New(3, DefaultMaxHostsCount, nil)104105tests := []struct {106host string107expected int32108}{109{"http://example.com:80", 1},110{"example.com:80", 2},111// earlier if port is not provided then port was omitted112// but from now it will default to appropriate http scheme based port with 80 as default113{"example.com:443", 1},114}115116for _, test := range tests {117normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil)118cache.MarkFailed(protoType, newCtxArgs(test.host), errors.New("no address found for host"))119failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)120require.Nil(t, err)121require.NotNil(t, failedTarget)122123require.EqualValues(t, test.expected, failedTarget.errors.Load())124}125}126127func TestCacheMarkFailedConcurrent(t *testing.T) {128cache := New(3, DefaultMaxHostsCount, nil)129130tests := []struct {131host string132expected int32133}{134{"http://example.com:80", 200},135{"example.com:80", 200},136{"example.com:443", 100},137}138139// the cache is not atomic during items creation, so we pre-create them with counter to zero140for _, test := range tests {141normalizedValue := cache.NormalizeCacheValue(test.host)142newItem := &cacheItem{errors: atomic.Int32{}}143newItem.errors.Store(0)144_ = cache.failedTargets.Set(normalizedValue, newItem)145}146147wg := sync.WaitGroup{}148for _, test := range tests {149currentTest := test150for i := 0; i < 100; i++ {151wg.Add(1)152go func() {153defer wg.Done()154cache.MarkFailed(protoType, newCtxArgs(currentTest.host), errors.New("net/http: timeout awaiting response 
headers"))155}()156}157}158wg.Wait()159160for _, test := range tests {161require.True(t, cache.Check(protoType, newCtxArgs(test.host)))162163normalizedCacheValue := cache.NormalizeCacheValue(test.host)164failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)165require.Nil(t, err)166require.NotNil(t, failedTarget)167168require.EqualValues(t, test.expected, failedTarget.errors.Load())169}170}171172func TestCacheCheckConcurrent(t *testing.T) {173cache := New(3, DefaultMaxHostsCount, nil)174ctx := newCtxArgs(t.Name())175176wg := sync.WaitGroup{}177for i := 1; i <= 100; i++ {178wg.Add(1)179i := i180go func() {181defer wg.Done()182cache.MarkFailed(protoType, ctx, errors.New("no address found for host"))183if i >= 3 {184got := cache.Check(protoType, ctx)185require.True(t, got)186}187}()188}189wg.Wait()190}191192func newCtxArgs(value string) *contextargs.Context {193ctx := contextargs.NewWithInput(context.TODO(), value)194return ctx195}196197198