Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
projectdiscovery
GitHub Repository: projectdiscovery/nuclei
Path: blob/dev/pkg/protocols/http/httpclientpool/clientpool.go
2846 views
1
package httpclientpool
2
3
import (
4
"context"
5
"crypto/tls"
6
"fmt"
7
"net"
8
"net/http"
9
"net/http/cookiejar"
10
"net/url"
11
"strconv"
12
"strings"
13
"sync"
14
"time"
15
16
"github.com/pkg/errors"
17
"golang.org/x/net/proxy"
18
"golang.org/x/net/publicsuffix"
19
20
"github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate"
21
"github.com/projectdiscovery/nuclei/v3/pkg/protocols"
22
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
23
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils"
24
"github.com/projectdiscovery/nuclei/v3/pkg/types"
25
"github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy"
26
"github.com/projectdiscovery/rawhttp"
27
"github.com/projectdiscovery/retryablehttp-go"
28
urlutil "github.com/projectdiscovery/utils/url"
29
)
30
31
var (
32
forceMaxRedirects int
33
)
34
35
// Init initializes the clientpool implementation
36
func Init(options *types.Options) error {
37
if options.ShouldFollowHTTPRedirects() {
38
forceMaxRedirects = options.MaxRedirects
39
}
40
41
return nil
42
}
43
44
// ConnectionConfiguration contains the custom configuration options for a connection
45
type ConnectionConfiguration struct {
46
// DisableKeepAlive of the connection
47
DisableKeepAlive bool
48
// CustomMaxTimeout is the custom timeout for the connection
49
// This overrides all other timeouts and is used for accurate time based fuzzing.
50
CustomMaxTimeout time.Duration
51
cookiejar *cookiejar.Jar
52
mu sync.RWMutex
53
}
54
55
func (cc *ConnectionConfiguration) SetCookieJar(cookiejar *cookiejar.Jar) {
56
cc.mu.Lock()
57
defer cc.mu.Unlock()
58
59
cc.cookiejar = cookiejar
60
}
61
62
func (cc *ConnectionConfiguration) GetCookieJar() *cookiejar.Jar {
63
cc.mu.RLock()
64
defer cc.mu.RUnlock()
65
66
return cc.cookiejar
67
}
68
69
func (cc *ConnectionConfiguration) HasCookieJar() bool {
70
cc.mu.RLock()
71
defer cc.mu.RUnlock()
72
73
return cc.cookiejar != nil
74
}
75
76
// Configuration contains the custom configuration options for a client
77
type Configuration struct {
78
// Threads contains the threads for the client
79
Threads int
80
// MaxRedirects is the maximum number of redirects to follow
81
MaxRedirects int
82
// NoTimeout disables http request timeout for context based usage
83
NoTimeout bool
84
// DisableCookie disables cookie reuse for the http client (cookiejar impl)
85
DisableCookie bool
86
// FollowRedirects specifies the redirects flow
87
RedirectFlow RedirectFlow
88
// Connection defines custom connection configuration
89
Connection *ConnectionConfiguration
90
// ResponseHeaderTimeout is the timeout for response body to be read from the server
91
ResponseHeaderTimeout time.Duration
92
}
93
94
func (c *Configuration) Clone() *Configuration {
95
clone := *c
96
if c.Connection != nil {
97
cloneConnection := &ConnectionConfiguration{
98
DisableKeepAlive: c.Connection.DisableKeepAlive,
99
CustomMaxTimeout: c.Connection.CustomMaxTimeout,
100
}
101
if c.Connection.HasCookieJar() {
102
cookiejar := *c.Connection.GetCookieJar()
103
cloneConnection.SetCookieJar(&cookiejar)
104
}
105
clone.Connection = cloneConnection
106
}
107
108
return &clone
109
}
110
111
// Hash returns the hash of the configuration to allow client pooling.
//
// The result is a compact string of delimiter-prefixed field values. It
// covers every field wrappedGet uses when building a pooled client, so two
// configurations map to the same pool slot only if they yield equivalent
// clients.
func (c *Configuration) Hash() string {
	builder := &strings.Builder{}
	builder.Grow(64)
	builder.WriteString("t")
	builder.WriteString(strconv.Itoa(c.Threads))
	builder.WriteString("m")
	builder.WriteString(strconv.Itoa(c.MaxRedirects))
	builder.WriteString("n")
	builder.WriteString(strconv.FormatBool(c.NoTimeout))
	builder.WriteString("f")
	builder.WriteString(strconv.Itoa(int(c.RedirectFlow)))
	builder.WriteString("r")
	builder.WriteString(strconv.FormatBool(c.DisableCookie))
	builder.WriteString("c")
	builder.WriteString(strconv.FormatBool(c.Connection != nil))
	if c.Connection != nil {
		// wrappedGet applies DisableKeepAlive to the transport, so configs
		// that differ only in this flag must not share a pooled client.
		builder.WriteString("d")
		builder.WriteString(strconv.FormatBool(c.Connection.DisableKeepAlive))
		if c.Connection.CustomMaxTimeout > 0 {
			builder.WriteString("k")
			builder.WriteString(c.Connection.CustomMaxTimeout.String())
		}
	}
	// Use the full duration: truncating to whole seconds made values such as
	// 1s and 1.5s collide on the same pooled client.
	builder.WriteString("h")
	builder.WriteString(c.ResponseHeaderTimeout.String())
	return builder.String()
}
136
137
// HasStandardOptions checks whether the configuration requires custom settings
138
func (c *Configuration) HasStandardOptions() bool {
139
return c.Threads == 0 && c.MaxRedirects == 0 && c.RedirectFlow == DontFollowRedirect && c.DisableCookie && c.Connection == nil && !c.NoTimeout && c.ResponseHeaderTimeout == 0
140
}
141
142
// GetRawHTTP returns the rawhttp request client
143
func GetRawHTTP(options *protocols.ExecutorOptions) *rawhttp.Client {
144
dialers := protocolstate.GetDialersWithId(options.Options.ExecutionId)
145
if dialers == nil {
146
panic("dialers not initialized for execution id: " + options.Options.ExecutionId)
147
}
148
149
// Lock the dialers to avoid a race when setting RawHTTPClient
150
dialers.Lock()
151
defer dialers.Unlock()
152
153
if dialers.RawHTTPClient != nil {
154
return dialers.RawHTTPClient
155
}
156
157
rawHttpOptionsCopy := *rawhttp.DefaultOptions
158
if options.Options.AliveHttpProxy != "" {
159
rawHttpOptionsCopy.Proxy = options.Options.AliveHttpProxy
160
} else if options.Options.AliveSocksProxy != "" {
161
rawHttpOptionsCopy.Proxy = options.Options.AliveSocksProxy
162
} else if dialers.Fastdialer != nil {
163
rawHttpOptionsCopy.FastDialer = dialers.Fastdialer
164
}
165
rawHttpOptionsCopy.Timeout = options.Options.GetTimeouts().HttpTimeout
166
dialers.RawHTTPClient = rawhttp.NewClient(&rawHttpOptionsCopy)
167
return dialers.RawHTTPClient
168
}
169
170
// Get creates or gets a client for the protocol based on custom configuration.
//
// Configurations with only standard options reuse the execution's default
// client; anything else is delegated to wrappedGet (pooled or freshly built).
func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
	if !configuration.HasStandardOptions() {
		return wrappedGet(options, configuration)
	}

	dialers := protocolstate.GetDialersWithId(options.ExecutionId)
	if dialers == nil {
		return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId)
	}
	return dialers.DefaultHTTPClient, nil
}
182
183
// wrappedGet builds a retryable HTTP client for a non-standard configuration,
// or returns a previously built one from the execution's client pool. It
// skips the default-client check performed by Get.
//
// Clients are cached in dialers.HTTPClientPool keyed by configuration.Hash(),
// except when a cookie jar is attached. Those clients carry session state and
// are never pooled.
func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
	var err error

	dialers := protocolstate.GetDialersWithId(options.ExecutionId)
	if dialers == nil {
		return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId)
	}

	hash := configuration.Hash()
	if client, ok := dialers.HTTPClientPool.Get(hash); ok {
		return client, nil
	}

	// Multiple Host
	retryableHttpOptions := retryablehttp.DefaultOptionsSpraying
	disableKeepAlives := true
	maxIdleConns := 0
	maxConnsPerHost := 0
	maxIdleConnsPerHost := -1

	if configuration.Threads > 0 || options.ScanStrategy == scanstrategy.HostSpray.String() {
		// Single host
		retryableHttpOptions = retryablehttp.DefaultOptionsSingle
		disableKeepAlives = false
		maxIdleConnsPerHost = 500
		maxConnsPerHost = 500
	}

	// do not split given timeout into chunks for retry
	// because this won't work on slow hosts.
	// Set after the single-host branch so it holds for both option sets.
	retryableHttpOptions.NoAdjustTimeout = true

	retryableHttpOptions.RetryWaitMax = 10 * time.Second
	retryableHttpOptions.RetryMax = options.Retries
	retryableHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second
	if configuration.ResponseHeaderTimeout > 0 && configuration.ResponseHeaderTimeout > retryableHttpOptions.Timeout {
		retryableHttpOptions.Timeout = configuration.ResponseHeaderTimeout
	}
	redirectFlow := configuration.RedirectFlow
	maxRedirects := configuration.MaxRedirects

	if forceMaxRedirects > 0 {
		// by default we enable general redirects following
		switch {
		case options.FollowHostRedirects:
			redirectFlow = FollowSameHostRedirect
		default:
			redirectFlow = FollowAllRedirect
		}
		maxRedirects = forceMaxRedirects
	}
	if options.DisableRedirects {
		// NOTE(review): this mutates the caller-owned (likely shared) options;
		// kept as-is because other code may rely on the updated flags.
		options.FollowRedirects = false
		options.FollowHostRedirects = false
		redirectFlow = DontFollowRedirect
		maxRedirects = 0
	}

	// override connection's settings if required
	if configuration.Connection != nil {
		disableKeepAlives = configuration.Connection.DisableKeepAlive
	}

	// Set the base TLS configuration definition
	tlsConfig := &tls.Config{
		Renegotiation:      tls.RenegotiateOnceAsClient,
		InsecureSkipVerify: true,
		MinVersion:         tls.VersionTLS10,
		ClientSessionCache: tls.NewLRUClientSessionCache(1024),
	}

	if options.SNI != "" {
		tlsConfig.ServerName = options.SNI
	}

	// Add the client certificate authentication to the request if it's configured
	tlsConfig, err = utils.AddConfiguredClientCertToRequest(tlsConfig, options)
	if err != nil {
		return nil, errors.Wrap(err, "could not create client certificate")
	}

	// responseHeaderTimeout is max timeout for response headers to be read.
	// It never drops below the overall retryable timeout, and a positive
	// CustomMaxTimeout overrides it entirely.
	responseHeaderTimeout := options.GetTimeouts().HttpResponseHeaderTimeout
	if configuration.ResponseHeaderTimeout != 0 {
		responseHeaderTimeout = configuration.ResponseHeaderTimeout
	}

	if responseHeaderTimeout < retryableHttpOptions.Timeout {
		responseHeaderTimeout = retryableHttpOptions.Timeout
	}

	if configuration.Connection != nil && configuration.Connection.CustomMaxTimeout > 0 {
		responseHeaderTimeout = configuration.Connection.CustomMaxTimeout
	}

	transport := &http.Transport{
		ForceAttemptHTTP2: options.ForceAttemptHTTP2,
		DialContext:       dialers.Fastdialer.Dial,
		DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
			if options.TlsImpersonate {
				return dialers.Fastdialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)
			}
			if options.HasClientCertificates() || options.ForceAttemptHTTP2 {
				return dialers.Fastdialer.DialTLSWithConfig(ctx, network, addr, tlsConfig)
			}
			return dialers.Fastdialer.DialTLS(ctx, network, addr)
		},
		MaxIdleConns:          maxIdleConns,
		MaxIdleConnsPerHost:   maxIdleConnsPerHost,
		MaxConnsPerHost:       maxConnsPerHost,
		TLSClientConfig:       tlsConfig,
		DisableKeepAlives:     disableKeepAlives,
		ResponseHeaderTimeout: responseHeaderTimeout,
	}

	if options.AliveHttpProxy != "" {
		// Fail loudly: silently skipping a malformed proxy would send traffic
		// directly to the targets, bypassing the configured proxy.
		proxyURL, proxyErr := url.Parse(options.AliveHttpProxy)
		if proxyErr != nil {
			return nil, errors.Wrap(proxyErr, "could not parse http proxy url")
		}
		transport.Proxy = http.ProxyURL(proxyURL)
	} else if options.AliveSocksProxy != "" {
		socksURL, proxyErr := url.Parse(options.AliveSocksProxy)
		if proxyErr != nil {
			return nil, proxyErr
		}

		dialer, err := proxy.FromURL(socksURL, proxy.Direct)
		if err != nil {
			return nil, err
		}

		dc, ok := dialer.(interface {
			DialContext(ctx context.Context, network, addr string) (net.Conn, error)
		})
		if !ok {
			return nil, fmt.Errorf("proxy dialer for scheme %q does not support DialContext", socksURL.Scheme)
		}

		transport.DialContext = dc.DialContext
		transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			// upgrade proxy connection to tls
			conn, err := dc.DialContext(ctx, network, addr)
			if err != nil {
				return nil, err
			}
			// Work on a per-connection copy: writing ServerName into the shared
			// tlsConfig would pin the first host's SNI for every later host and
			// race between concurrent dials.
			connTLSConfig := tlsConfig.Clone()
			if connTLSConfig.ServerName == "" {
				// addr should be in form of host:port already set from canonicalAddr
				host, _, err := net.SplitHostPort(addr)
				if err != nil {
					_ = conn.Close()
					return nil, err
				}
				connTLSConfig.ServerName = host
			}
			tlsConn := tls.Client(conn, connTLSConfig)
			// http.Transport assumes connections from DialTLSContext have
			// already completed the TLS handshake.
			if err := tlsConn.HandshakeContext(ctx); err != nil {
				_ = conn.Close()
				return nil, err
			}
			return tlsConn, nil
		}
	}

	var jar *cookiejar.Jar
	if configuration.Connection != nil && configuration.Connection.HasCookieJar() {
		jar = configuration.Connection.GetCookieJar()
	} else if !configuration.DisableCookie {
		if jar, err = cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}); err != nil {
			return nil, errors.Wrap(err, "could not create cookiejar")
		}
	}

	httpclient := &http.Client{
		Transport:     transport,
		CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects),
	}
	if !configuration.NoTimeout {
		httpclient.Timeout = options.GetTimeouts().HttpTimeout
		if configuration.Connection != nil && configuration.Connection.CustomMaxTimeout > 0 {
			httpclient.Timeout = configuration.Connection.CustomMaxTimeout
		}
	}
	client := retryablehttp.NewWithHTTPClient(httpclient, retryableHttpOptions)
	// Assign only when non-nil: a nil *cookiejar.Jar stored in the
	// http.CookieJar interface would be a non-nil interface and panic on use.
	if jar != nil {
		client.HTTPClient.Jar = jar
	}
	client.CheckRetry = retryablehttp.HostSprayRetryPolicy()

	// Only add to client pool if we don't have a cookie jar in place.
	if jar == nil {
		if err := dialers.HTTPClientPool.Set(hash, client); err != nil {
			return nil, err
		}
	}
	return client, nil
}
369
370
type RedirectFlow uint8
371
372
const (
373
DontFollowRedirect RedirectFlow = iota
374
FollowSameHostRedirect
375
FollowAllRedirect
376
)
377
378
const defaultMaxRedirects = 10
379
380
type checkRedirectFunc func(req *http.Request, via []*http.Request) error
381
382
func makeCheckRedirectFunc(redirectType RedirectFlow, maxRedirects int) checkRedirectFunc {
383
return func(req *http.Request, via []*http.Request) error {
384
switch redirectType {
385
case DontFollowRedirect:
386
return http.ErrUseLastResponse
387
case FollowSameHostRedirect:
388
var newHost = req.URL.Host
389
var oldHost = via[0].Host
390
if oldHost == "" {
391
oldHost = via[0].URL.Host
392
}
393
if newHost != oldHost {
394
// Tell the http client to not follow redirect
395
return http.ErrUseLastResponse
396
}
397
return checkMaxRedirects(req, via, maxRedirects)
398
case FollowAllRedirect:
399
return checkMaxRedirects(req, via, maxRedirects)
400
}
401
return nil
402
}
403
}
404
405
// checkMaxRedirects stops redirect following once more than maxRedirects
// requests have already been made (len(via)), returning
// http.ErrUseLastResponse so the last redirect response is returned.
// A maxRedirects of 0 means "not configured" and falls back to
// defaultMaxRedirects.
//
// For redirects that are followed, the target URL is rebuilt with urlutil
// when it is not already percent-encoded (see #5900). Previously the default
// limit path returned early and skipped this rebuild. A rebuild failure is
// reported wrapped in ErrRebuildURL.
func checkMaxRedirects(req *http.Request, via []*http.Request, maxRedirects int) error {
	if maxRedirects == 0 {
		maxRedirects = defaultMaxRedirects
	}

	if len(via) > maxRedirects {
		return http.ErrUseLastResponse
	}

	// NOTE(dwisiswant0): rebuild request URL. See #5900.
	if u := req.URL.String(); !isURLEncoded(u) {
		parsed, err := urlutil.Parse(u)
		if err != nil {
			return fmt.Errorf("%w: %w", ErrRebuildURL, err)
		}

		req.URL = parsed.URL
	}

	return nil
}
429
430
// isURLEncoded reports whether s appears to already contain URL encoding,
// i.e. query-unescaping it yields a different string.
//
// A string that fails to unescape (malformed or invalid percent-encoding)
// is reported as not encoded. Because query unescaping is used, a literal
// '+' also counts as encoding. NOTE(review): confirm that is intended for
// path components.
//
// NOTE(dwisiswant0): shall we move this under `projectdiscovery/utils/urlutil`?
func isURLEncoded(s string) bool {
	if unescaped, err := url.QueryUnescape(s); err == nil {
		return unescaped != s
	}
	return false
}
442
443