Path: blob/dev/pkg/protocols/http/httpclientpool/clientpool.go
2073 views
package httpclientpool12import (3"context"4"crypto/tls"5"fmt"6"net"7"net/http"8"net/http/cookiejar"9"net/url"10"strconv"11"strings"12"sync"13"time"1415"github.com/pkg/errors"16"golang.org/x/net/proxy"17"golang.org/x/net/publicsuffix"1819"github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate"20"github.com/projectdiscovery/nuclei/v3/pkg/protocols"21"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"22"github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils"23"github.com/projectdiscovery/nuclei/v3/pkg/types"24"github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy"25"github.com/projectdiscovery/rawhttp"26"github.com/projectdiscovery/retryablehttp-go"27urlutil "github.com/projectdiscovery/utils/url"28)2930var (31forceMaxRedirects int32)3334// Init initializes the clientpool implementation35func Init(options *types.Options) error {36if options.ShouldFollowHTTPRedirects() {37forceMaxRedirects = options.MaxRedirects38}3940return nil41}4243// ConnectionConfiguration contains the custom configuration options for a connection44type ConnectionConfiguration struct {45// DisableKeepAlive of the connection46DisableKeepAlive bool47// CustomMaxTimeout is the custom timeout for the connection48// This overrides all other timeouts and is used for accurate time based fuzzing.49CustomMaxTimeout time.Duration50cookiejar *cookiejar.Jar51mu sync.RWMutex52}5354func (cc *ConnectionConfiguration) SetCookieJar(cookiejar *cookiejar.Jar) {55cc.mu.Lock()56defer cc.mu.Unlock()5758cc.cookiejar = cookiejar59}6061func (cc *ConnectionConfiguration) GetCookieJar() *cookiejar.Jar {62cc.mu.RLock()63defer cc.mu.RUnlock()6465return cc.cookiejar66}6768func (cc *ConnectionConfiguration) HasCookieJar() bool {69cc.mu.RLock()70defer cc.mu.RUnlock()7172return cc.cookiejar != nil73}7475// Configuration contains the custom configuration options for a client76type Configuration struct {77// Threads contains the threads for the client78Threads int79// MaxRedirects is the maximum number of redirects to follow80MaxRedirects int81// NoTimeout disables http request timeout for context based usage82NoTimeout bool83// DisableCookie disables cookie reuse for the http client (cookiejar impl)84DisableCookie bool85// FollowRedirects specifies the redirects flow86RedirectFlow RedirectFlow87// Connection defines custom connection configuration88Connection *ConnectionConfiguration89// ResponseHeaderTimeout is the timeout for response body to be read from the server90ResponseHeaderTimeout time.Duration91}9293func (c *Configuration) Clone() *Configuration {94clone := *c95if c.Connection != nil {96cloneConnection := &ConnectionConfiguration{97DisableKeepAlive: c.Connection.DisableKeepAlive,98CustomMaxTimeout: c.Connection.CustomMaxTimeout,99}100if c.Connection.HasCookieJar() {101cookiejar := *c.Connection.GetCookieJar()102cloneConnection.SetCookieJar(&cookiejar)103}104clone.Connection = cloneConnection105}106107return &clone108}109110// Hash returns the hash of the configuration to allow client pooling111func (c *Configuration) Hash() string {112builder := &strings.Builder{}113builder.Grow(16)114builder.WriteString("t")115builder.WriteString(strconv.Itoa(c.Threads))116builder.WriteString("m")117builder.WriteString(strconv.Itoa(c.MaxRedirects))118builder.WriteString("n")119builder.WriteString(strconv.FormatBool(c.NoTimeout))120builder.WriteString("f")121builder.WriteString(strconv.Itoa(int(c.RedirectFlow)))122builder.WriteString("r")123builder.WriteString(strconv.FormatBool(c.DisableCookie))124builder.WriteString("c")125builder.WriteString(strconv.FormatBool(c.Connection != nil))126if c.Connection != nil && c.Connection.CustomMaxTimeout > 0 {127builder.WriteString("k")128builder.WriteString(c.Connection.CustomMaxTimeout.String())129}130builder.WriteString("r")131builder.WriteString(strconv.FormatInt(int64(c.ResponseHeaderTimeout.Seconds()), 10))132hash := builder.String()133return hash134}135136// HasStandardOptions checks whether the configuration requires custom settings137func (c *Configuration) HasStandardOptions() bool {138return c.Threads == 0 && c.MaxRedirects == 0 && c.RedirectFlow == DontFollowRedirect && c.DisableCookie && c.Connection == nil && !c.NoTimeout && c.ResponseHeaderTimeout == 0139}140141// GetRawHTTP returns the rawhttp request client142func GetRawHTTP(options *protocols.ExecutorOptions) *rawhttp.Client {143dialers := protocolstate.GetDialersWithId(options.Options.ExecutionId)144if dialers == nil {145panic("dialers not initialized for execution id: " + options.Options.ExecutionId)146}147148// Lock the dialers to avoid a race when setting RawHTTPClient149dialers.Lock()150defer dialers.Unlock()151152if dialers.RawHTTPClient != nil {153return dialers.RawHTTPClient154}155156rawHttpOptionsCopy := *rawhttp.DefaultOptions157if options.Options.AliveHttpProxy != "" {158rawHttpOptionsCopy.Proxy = options.Options.AliveHttpProxy159} else if options.Options.AliveSocksProxy != "" {160rawHttpOptionsCopy.Proxy = options.Options.AliveSocksProxy161} else if dialers.Fastdialer != nil {162rawHttpOptionsCopy.FastDialer = dialers.Fastdialer163}164rawHttpOptionsCopy.Timeout = options.Options.GetTimeouts().HttpTimeout165dialers.RawHTTPClient = rawhttp.NewClient(&rawHttpOptionsCopy)166return dialers.RawHTTPClient167}168169// Get creates or gets a client for the protocol based on custom configuration170func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {171if configuration.HasStandardOptions() {172dialers := protocolstate.GetDialersWithId(options.ExecutionId)173if dialers == nil {174return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId)175}176return dialers.DefaultHTTPClient, nil177}178179return wrappedGet(options, configuration)180}181182// wrappedGet wraps a get operation without normal client check183func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {184var err error185186dialers := protocolstate.GetDialersWithId(options.ExecutionId)187if dialers == nil {188return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId)189}190191hash := configuration.Hash()192if client, ok := dialers.HTTPClientPool.Get(hash); ok {193return client, nil194}195196// Multiple Host197retryableHttpOptions := retryablehttp.DefaultOptionsSpraying198disableKeepAlives := true199maxIdleConns := 0200maxConnsPerHost := 0201maxIdleConnsPerHost := -1202// do not split given timeout into chunks for retry203// because this won't work on slow hosts204retryableHttpOptions.NoAdjustTimeout = true205206if configuration.Threads > 0 || options.ScanStrategy == scanstrategy.HostSpray.String() {207// Single host208retryableHttpOptions = retryablehttp.DefaultOptionsSingle209disableKeepAlives = false210maxIdleConnsPerHost = 500211maxConnsPerHost = 500212}213214retryableHttpOptions.RetryWaitMax = 10 * time.Second215retryableHttpOptions.RetryMax = options.Retries216redirectFlow := configuration.RedirectFlow217maxRedirects := configuration.MaxRedirects218219if forceMaxRedirects > 0 {220// by default we enable general redirects following221switch {222case options.FollowHostRedirects:223redirectFlow = FollowSameHostRedirect224default:225redirectFlow = FollowAllRedirect226}227maxRedirects = forceMaxRedirects228}229if options.DisableRedirects {230options.FollowRedirects = false231options.FollowHostRedirects = false232redirectFlow = DontFollowRedirect233maxRedirects = 0234}235236// override connection's settings if required237if configuration.Connection != nil {238disableKeepAlives = configuration.Connection.DisableKeepAlive239}240241// Set the base TLS configuration definition242tlsConfig := &tls.Config{243Renegotiation: tls.RenegotiateOnceAsClient,244InsecureSkipVerify: true,245MinVersion: tls.VersionTLS10,246}247248if options.SNI != "" {249tlsConfig.ServerName = options.SNI250}251252// Add the client certificate authentication to the request if it's configured253tlsConfig, err = utils.AddConfiguredClientCertToRequest(tlsConfig, options)254if err != nil {255return nil, errors.Wrap(err, "could not create client certificate")256}257258// responseHeaderTimeout is max timeout for response headers to be read259responseHeaderTimeout := options.GetTimeouts().HttpResponseHeaderTimeout260if configuration.ResponseHeaderTimeout != 0 {261responseHeaderTimeout = configuration.ResponseHeaderTimeout262}263if configuration.Connection != nil && configuration.Connection.CustomMaxTimeout > 0 {264responseHeaderTimeout = configuration.Connection.CustomMaxTimeout265}266267transport := &http.Transport{268ForceAttemptHTTP2: options.ForceAttemptHTTP2,269DialContext: dialers.Fastdialer.Dial,270DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {271if options.TlsImpersonate {272return dialers.Fastdialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)273}274if options.HasClientCertificates() || options.ForceAttemptHTTP2 {275return dialers.Fastdialer.DialTLSWithConfig(ctx, network, addr, tlsConfig)276}277return dialers.Fastdialer.DialTLS(ctx, network, addr)278},279MaxIdleConns: maxIdleConns,280MaxIdleConnsPerHost: maxIdleConnsPerHost,281MaxConnsPerHost: maxConnsPerHost,282TLSClientConfig: tlsConfig,283DisableKeepAlives: disableKeepAlives,284ResponseHeaderTimeout: responseHeaderTimeout,285}286287if options.AliveHttpProxy != "" {288if proxyURL, err := url.Parse(options.AliveHttpProxy); err == nil {289transport.Proxy = http.ProxyURL(proxyURL)290}291} else if options.AliveSocksProxy != "" {292socksURL, proxyErr := url.Parse(options.AliveSocksProxy)293if proxyErr != nil {294return nil, proxyErr295}296297dialer, err := proxy.FromURL(socksURL, proxy.Direct)298if err != nil {299return nil, err300}301302dc := dialer.(interface {303DialContext(ctx context.Context, network, addr string) (net.Conn, error)304})305306transport.DialContext = dc.DialContext307transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {308// upgrade proxy connection to tls309conn, err := dc.DialContext(ctx, network, addr)310if err != nil {311return nil, err312}313if tlsConfig.ServerName == "" {314// addr should be in form of host:port already set from canonicalAddr315host, _, err := net.SplitHostPort(addr)316if err != nil {317return nil, err318}319tlsConfig.ServerName = host320}321return tls.Client(conn, tlsConfig), nil322}323}324325var jar *cookiejar.Jar326if configuration.Connection != nil && configuration.Connection.HasCookieJar() {327jar = configuration.Connection.GetCookieJar()328} else if !configuration.DisableCookie {329if jar, err = cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}); err != nil {330return nil, errors.Wrap(err, "could not create cookiejar")331}332}333334httpclient := &http.Client{335Transport: transport,336CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects),337}338if !configuration.NoTimeout {339httpclient.Timeout = options.GetTimeouts().HttpTimeout340if configuration.Connection != nil && configuration.Connection.CustomMaxTimeout > 0 {341httpclient.Timeout = configuration.Connection.CustomMaxTimeout342}343}344client := retryablehttp.NewWithHTTPClient(httpclient, retryableHttpOptions)345if jar != nil {346client.HTTPClient.Jar = jar347}348client.CheckRetry = retryablehttp.HostSprayRetryPolicy()349350// Only add to client pool if we don't have a cookie jar in place.351if jar == nil {352if err := dialers.HTTPClientPool.Set(hash, client); err != nil {353return nil, err354}355}356return client, nil357}358359type RedirectFlow uint8360361const (362DontFollowRedirect RedirectFlow = iota363FollowSameHostRedirect364FollowAllRedirect365)366367const defaultMaxRedirects = 10368369type checkRedirectFunc func(req *http.Request, via []*http.Request) error370371func makeCheckRedirectFunc(redirectType RedirectFlow, maxRedirects int) checkRedirectFunc {372return func(req *http.Request, via []*http.Request) error {373switch redirectType {374case DontFollowRedirect:375return http.ErrUseLastResponse376case FollowSameHostRedirect:377var newHost = req.URL.Host378var oldHost = via[0].Host379if oldHost == "" {380oldHost = via[0].URL.Host381}382if newHost != oldHost {383// Tell the http client to not follow redirect384return http.ErrUseLastResponse385}386return checkMaxRedirects(req, via, maxRedirects)387case FollowAllRedirect:388return checkMaxRedirects(req, via, maxRedirects)389}390return nil391}392}393394func checkMaxRedirects(req *http.Request, via []*http.Request, maxRedirects int) error {395if maxRedirects == 0 {396if len(via) > defaultMaxRedirects {397return http.ErrUseLastResponse398}399return nil400}401402if len(via) > maxRedirects {403return http.ErrUseLastResponse404}405406// NOTE(dwisiswant0): rebuild request URL. See #5900.407if u := req.URL.String(); !isURLEncoded(u) {408parsed, err := urlutil.Parse(u)409if err != nil {410return fmt.Errorf("%w: %w", ErrRebuildURL, err)411}412413req.URL = parsed.URL414}415416return nil417}418419// isURLEncoded is an helper function to check if the URL is already encoded420//421// NOTE(dwisiswant0): shall we move this under `projectdiscovery/utils/urlutil`?422func isURLEncoded(s string) bool {423decoded, err := url.QueryUnescape(s)424if err != nil {425// If decoding fails, it may indicate a malformed URL/invalid encoding.426return false427}428429return decoded != s430}431432433