Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
projectdiscovery
GitHub Repository: projectdiscovery/nuclei
Path: blob/dev/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go
2072 views
1
package hosterrorscache
2
3
import (
4
"context"
5
"errors"
6
"sync"
7
"sync/atomic"
8
"testing"
9
10
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
11
"github.com/stretchr/testify/require"
12
)
13
14
// protoType is the protocol name passed to every cache call in these tests.
const protoType = "http"
17
18
func TestCacheCheck(t *testing.T) {
19
cache := New(3, DefaultMaxHostsCount, nil)
20
err := errors.New("net/http: timeout awaiting response headers")
21
22
t.Run("increment host error", func(t *testing.T) {
23
ctx := newCtxArgs(t.Name())
24
for i := 1; i < 3; i++ {
25
cache.MarkFailed(protoType, ctx, err)
26
got := cache.Check(protoType, ctx)
27
require.Falsef(t, got, "got %v in iteration %d", got, i)
28
}
29
})
30
31
t.Run("flagged", func(t *testing.T) {
32
ctx := newCtxArgs(t.Name())
33
for i := 1; i <= 3; i++ {
34
cache.MarkFailed(protoType, ctx, err)
35
}
36
37
got := cache.Check(protoType, ctx)
38
require.True(t, got)
39
})
40
41
t.Run("mark failed or remove", func(t *testing.T) {
42
ctx := newCtxArgs(t.Name())
43
cache.MarkFailedOrRemove(protoType, ctx, nil) // nil error should remove the host from cache
44
got := cache.Check(protoType, ctx)
45
require.False(t, got)
46
})
47
}
48
49
func TestTrackErrors(t *testing.T) {
50
cache := New(3, DefaultMaxHostsCount, []string{"custom error"})
51
52
for i := 0; i < 100; i++ {
53
cache.MarkFailed(protoType, newCtxArgs("custom"), errors.New("got: nested: custom error"))
54
got := cache.Check(protoType, newCtxArgs("custom"))
55
if i < 2 {
56
// till 3 the host is not flagged to skip
57
require.False(t, got)
58
} else {
59
// above 3 it must remain flagged to skip
60
require.True(t, got)
61
}
62
}
63
value := cache.Check(protoType, newCtxArgs("custom"))
64
require.Equal(t, true, value, "could not get checked value")
65
}
66
67
func TestCacheItemDo(t *testing.T) {
68
var (
69
count int
70
item cacheItem
71
)
72
73
wg := sync.WaitGroup{}
74
for i := 0; i < 100; i++ {
75
wg.Add(1)
76
go func() {
77
defer wg.Done()
78
item.Do(func() {
79
count++
80
})
81
}()
82
}
83
wg.Wait()
84
85
// ensures the increment happened only once regardless of the multiple call
86
require.Equal(t, count, 1)
87
}
88
89
func TestRemove(t *testing.T) {
90
cache := New(3, DefaultMaxHostsCount, nil)
91
ctx := newCtxArgs(t.Name())
92
err := errors.New("net/http: timeout awaiting response headers")
93
94
for i := 0; i < 100; i++ {
95
cache.MarkFailed(protoType, ctx, err)
96
}
97
98
require.True(t, cache.Check(protoType, ctx))
99
cache.Remove(ctx)
100
require.False(t, cache.Check(protoType, ctx))
101
}
102
103
func TestCacheMarkFailed(t *testing.T) {
104
cache := New(3, DefaultMaxHostsCount, nil)
105
106
tests := []struct {
107
host string
108
expected int32
109
}{
110
{"http://example.com:80", 1},
111
{"example.com:80", 2},
112
// earlier if port is not provided then port was omitted
113
// but from now it will default to appropriate http scheme based port with 80 as default
114
{"example.com:443", 1},
115
}
116
117
for _, test := range tests {
118
normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil)
119
cache.MarkFailed(protoType, newCtxArgs(test.host), errors.New("no address found for host"))
120
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
121
require.Nil(t, err)
122
require.NotNil(t, failedTarget)
123
124
require.EqualValues(t, test.expected, failedTarget.errors.Load())
125
}
126
}
127
128
func TestCacheMarkFailedConcurrent(t *testing.T) {
129
cache := New(3, DefaultMaxHostsCount, nil)
130
131
tests := []struct {
132
host string
133
expected int32
134
}{
135
{"http://example.com:80", 200},
136
{"example.com:80", 200},
137
{"example.com:443", 100},
138
}
139
140
// the cache is not atomic during items creation, so we pre-create them with counter to zero
141
for _, test := range tests {
142
normalizedValue := cache.NormalizeCacheValue(test.host)
143
newItem := &cacheItem{errors: atomic.Int32{}}
144
newItem.errors.Store(0)
145
_ = cache.failedTargets.Set(normalizedValue, newItem)
146
}
147
148
wg := sync.WaitGroup{}
149
for _, test := range tests {
150
currentTest := test
151
for i := 0; i < 100; i++ {
152
wg.Add(1)
153
go func() {
154
defer wg.Done()
155
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), errors.New("net/http: timeout awaiting response headers"))
156
}()
157
}
158
}
159
wg.Wait()
160
161
for _, test := range tests {
162
require.True(t, cache.Check(protoType, newCtxArgs(test.host)))
163
164
normalizedCacheValue := cache.NormalizeCacheValue(test.host)
165
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
166
require.Nil(t, err)
167
require.NotNil(t, failedTarget)
168
169
require.EqualValues(t, test.expected, failedTarget.errors.Load())
170
}
171
}
172
173
func TestCacheCheckConcurrent(t *testing.T) {
174
cache := New(3, DefaultMaxHostsCount, nil)
175
ctx := newCtxArgs(t.Name())
176
177
wg := sync.WaitGroup{}
178
for i := 1; i <= 100; i++ {
179
wg.Add(1)
180
i := i
181
go func() {
182
defer wg.Done()
183
cache.MarkFailed(protoType, ctx, errors.New("no address found for host"))
184
if i >= 3 {
185
got := cache.Check(protoType, ctx)
186
require.True(t, got)
187
}
188
}()
189
}
190
wg.Wait()
191
}
192
193
// newCtxArgs wraps value as the input of a fresh contextargs.Context (built on
// context.TODO()) so it can be handed to the cache under test.
func newCtxArgs(value string) *contextargs.Context {
	return contextargs.NewWithInput(context.TODO(), value)
}
197
198