Path: blob/dev/pkg/protocols/http/httpclientpool/clientpool.go
2846 views
package httpclientpool

import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"net/http"
	"net/http/cookiejar"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/pkg/errors"
	"golang.org/x/net/proxy"
	"golang.org/x/net/publicsuffix"

	"github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate"
	"github.com/projectdiscovery/nuclei/v3/pkg/protocols"
	"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
	"github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils"
	"github.com/projectdiscovery/nuclei/v3/pkg/types"
	"github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy"
	"github.com/projectdiscovery/rawhttp"
	"github.com/projectdiscovery/retryablehttp-go"
	urlutil "github.com/projectdiscovery/utils/url"
)

var (
	// forceMaxRedirects, when greater than zero, overrides the per-configuration
	// redirect settings in wrappedGet. It is set once by Init.
	forceMaxRedirects int
)

// Init initializes the clientpool implementation
func Init(options *types.Options) error {
	if options.ShouldFollowHTTPRedirects() {
		forceMaxRedirects = options.MaxRedirects
	}

	return nil
}

// ConnectionConfiguration contains the custom configuration options for a connection
type ConnectionConfiguration struct {
	// DisableKeepAlive of the connection
	DisableKeepAlive bool
	// CustomMaxTimeout is the custom timeout for the connection
	// This overrides all other timeouts and is used for accurate time based fuzzing.
	CustomMaxTimeout time.Duration
	// cookiejar is the optional jar attached to the connection; guarded by mu.
	cookiejar *cookiejar.Jar
	mu        sync.RWMutex
}

// SetCookieJar stores the given cookie jar on the connection configuration.
func (cc *ConnectionConfiguration) SetCookieJar(cookiejar *cookiejar.Jar) {
	cc.mu.Lock()
	defer cc.mu.Unlock()

	cc.cookiejar = cookiejar
}

// GetCookieJar returns the cookie jar of the connection configuration (nil if unset).
func (cc *ConnectionConfiguration) GetCookieJar() *cookiejar.Jar {
	cc.mu.RLock()
	defer cc.mu.RUnlock()

	return cc.cookiejar
}

// HasCookieJar reports whether a cookie jar has been set.
func (cc *ConnectionConfiguration) HasCookieJar() bool {
	cc.mu.RLock()
	defer cc.mu.RUnlock()

	return cc.cookiejar != nil
}

// Configuration contains the custom configuration options for a client
type Configuration struct {
	// Threads contains the threads for the client
	Threads int
	// MaxRedirects is the maximum number of redirects to follow
	MaxRedirects int
	// NoTimeout disables http request timeout for context based usage
	NoTimeout bool
	// DisableCookie disables cookie reuse for the http client (cookiejar impl)
	DisableCookie bool
	// RedirectFlow specifies the redirects flow
	RedirectFlow RedirectFlow
	// Connection defines custom connection configuration
	Connection *ConnectionConfiguration
	// ResponseHeaderTimeout is the timeout for response headers to be read from the server
	ResponseHeaderTimeout time.Duration
}

// Clone returns a copy of the configuration. The Connection, when present, is
// rebuilt as a new ConnectionConfiguration (with a fresh mutex) carrying the
// same DisableKeepAlive and CustomMaxTimeout values.
func (c *Configuration) Clone() *Configuration {
	clone := *c
	if c.Connection != nil {
		cloneConnection := &ConnectionConfiguration{
			DisableKeepAlive: c.Connection.DisableKeepAlive,
			CustomMaxTimeout: c.Connection.CustomMaxTimeout,
		}
		if c.Connection.HasCookieJar() {
			// NOTE(review): this copies the Jar struct by value (a shallow copy);
			// confirm that is intended — `go vet` copylocks may flag it.
			jarCopy := *c.Connection.GetCookieJar()
			cloneConnection.SetCookieJar(&jarCopy)
		}
		clone.Connection = cloneConnection
	}

	return &clone
}

// Hash returns the hash of the configuration to allow client pooling
//
// The result is used by wrappedGet as the key of the HTTP client pool, so two
// configurations that would produce differently configured clients must never
// share a hash. For that reason the connection's DisableKeepAlive flag and the
// full (not second-truncated) ResponseHeaderTimeout are part of the key.
func (c *Configuration) Hash() string {
	var builder strings.Builder
	builder.Grow(64)
	builder.WriteString("t")
	builder.WriteString(strconv.Itoa(c.Threads))
	builder.WriteString("m")
	builder.WriteString(strconv.Itoa(c.MaxRedirects))
	builder.WriteString("n")
	builder.WriteString(strconv.FormatBool(c.NoTimeout))
	builder.WriteString("f")
	builder.WriteString(strconv.Itoa(int(c.RedirectFlow)))
	builder.WriteString("r")
	builder.WriteString(strconv.FormatBool(c.DisableCookie))
	builder.WriteString("c")
	builder.WriteString(strconv.FormatBool(c.Connection != nil))
	if c.Connection != nil {
		// wrappedGet copies DisableKeepAlive into the transport, so it must
		// distinguish pooled clients.
		builder.WriteString("a")
		builder.WriteString(strconv.FormatBool(c.Connection.DisableKeepAlive))
		if c.Connection.CustomMaxTimeout > 0 {
			builder.WriteString("k")
			builder.WriteString(c.Connection.CustomMaxTimeout.String())
		}
	}
	builder.WriteString("r")
	// Use the exact duration: truncating to whole seconds made distinct
	// sub-second timeouts collide on the same pooled client.
	builder.WriteString(c.ResponseHeaderTimeout.String())
	return builder.String()
}

// HasStandardOptions checks whether the configuration requires custom settings
//
// It returns true only when every field has the value listed below, in which
// case Get hands out the shared default client instead of building a new one.
func (c *Configuration) HasStandardOptions() bool {
	return c.Threads == 0 &&
		c.MaxRedirects == 0 &&
		c.RedirectFlow == DontFollowRedirect &&
		c.DisableCookie &&
		c.Connection == nil &&
		!c.NoTimeout &&
		c.ResponseHeaderTimeout == 0
}

// GetRawHTTP returns the rawhttp request client
//
// The client is created on first use and cached on the dialers registered for
// the execution id. It panics if no dialers are registered for that id.
func GetRawHTTP(options *protocols.ExecutorOptions) *rawhttp.Client {
	dialers := protocolstate.GetDialersWithId(options.Options.ExecutionId)
	if dialers == nil {
		panic("dialers not initialized for execution id: " + options.Options.ExecutionId)
	}

	// Lock the dialers to avoid a race when setting RawHTTPClient
	dialers.Lock()
	defer dialers.Unlock()

	if dialers.RawHTTPClient != nil {
		return dialers.RawHTTPClient
	}

	// Work on a copy so the package-level defaults are not modified.
	rawHttpOptionsCopy := *rawhttp.DefaultOptions
	switch {
	case options.Options.AliveHttpProxy != "":
		rawHttpOptionsCopy.Proxy = options.Options.AliveHttpProxy
	case options.Options.AliveSocksProxy != "":
		rawHttpOptionsCopy.Proxy = options.Options.AliveSocksProxy
	case dialers.Fastdialer != nil:
		rawHttpOptionsCopy.FastDialer = dialers.Fastdialer
	}
	rawHttpOptionsCopy.Timeout = options.Options.GetTimeouts().HttpTimeout
	dialers.RawHTTPClient = rawhttp.NewClient(&rawHttpOptionsCopy)
	return dialers.RawHTTPClient
}

// Get creates or gets a client for the protocol based on custom configuration
//
// Configurations with only standard options receive the default HTTP client of
// the execution's dialers; everything else goes through wrappedGet.
func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
	if configuration.HasStandardOptions() {
		dialers := protocolstate.GetDialersWithId(options.ExecutionId)
		if dialers == nil {
			return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId)
		}
		return dialers.DefaultHTTPClient, nil
	}

	return wrappedGet(options, configuration)
}

// wrappedGet wraps a get operation without normal client check
//
// It returns the pooled client for the configuration hash when one exists;
// otherwise it builds a new retryablehttp client (transport, TLS config, proxy,
// redirect policy, cookie jar) and stores it in the pool unless the client has
// a cookie jar attached.
func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
	var err error

	dialers := protocolstate.GetDialersWithId(options.ExecutionId)
	if dialers == nil {
		return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId)
	}

	hash := configuration.Hash()
	if client, ok := dialers.HTTPClientPool.Get(hash); ok {
		return client, nil
	}

	// Multiple Host
	retryableHttpOptions := retryablehttp.DefaultOptionsSpraying
	disableKeepAlives := true
	maxIdleConns := 0
	maxConnsPerHost := 0
	maxIdleConnsPerHost := -1
	// do not split given timeout into chunks for retry
	// because this won't work on slow hosts
	retryableHttpOptions.NoAdjustTimeout = true

	if configuration.Threads > 0 || options.ScanStrategy == scanstrategy.HostSpray.String() {
		// Single host
		// NOTE(review): this assignment replaces the options on which
		// NoAdjustTimeout was just set — confirm that losing it in single-host
		// mode is intended.
		retryableHttpOptions = retryablehttp.DefaultOptionsSingle
		disableKeepAlives = false
		maxIdleConnsPerHost = 500
		maxConnsPerHost = 500
	}

	retryableHttpOptions.RetryWaitMax = 10 * time.Second
	retryableHttpOptions.RetryMax = options.Retries
	retryableHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second
	if configuration.ResponseHeaderTimeout > 0 && configuration.ResponseHeaderTimeout > retryableHttpOptions.Timeout {
		retryableHttpOptions.Timeout = configuration.ResponseHeaderTimeout
	}
	redirectFlow := configuration.RedirectFlow
	maxRedirects := configuration.MaxRedirects

	if forceMaxRedirects > 0 {
		// by default we enable general redirects following
		switch {
		case options.FollowHostRedirects:
			redirectFlow = FollowSameHostRedirect
		default:
			redirectFlow = FollowAllRedirect
		}
		maxRedirects = forceMaxRedirects
	}
	if options.DisableRedirects {
		// NOTE(review): this writes to the caller-supplied options — confirm
		// callers expect the mutation.
		options.FollowRedirects = false
		options.FollowHostRedirects = false
		redirectFlow = DontFollowRedirect
		maxRedirects = 0
	}

	// override connection's settings if required
	if configuration.Connection != nil {
		disableKeepAlives = configuration.Connection.DisableKeepAlive
	}

	// Set the base TLS configuration definition
	tlsConfig := &tls.Config{
		Renegotiation:      tls.RenegotiateOnceAsClient,
		InsecureSkipVerify: true,
		MinVersion:         tls.VersionTLS10,
		ClientSessionCache: tls.NewLRUClientSessionCache(1024),
	}

	if options.SNI != "" {
		tlsConfig.ServerName = options.SNI
	}

	// Add the client certificate authentication to the request if it's configured
	tlsConfig, err = utils.AddConfiguredClientCertToRequest(tlsConfig, options)
	if err != nil {
		return nil, errors.Wrap(err, "could not create client certificate")
	}

	// responseHeaderTimeout is max timeout for response headers to be read
	responseHeaderTimeout := options.GetTimeouts().HttpResponseHeaderTimeout
	if configuration.ResponseHeaderTimeout != 0 {
		responseHeaderTimeout = configuration.ResponseHeaderTimeout
	}

	// never wait less for the headers than the overall request timeout
	if responseHeaderTimeout < retryableHttpOptions.Timeout {
		responseHeaderTimeout = retryableHttpOptions.Timeout
	}

	// an explicit per-connection timeout takes precedence over everything above
	if configuration.Connection != nil && configuration.Connection.CustomMaxTimeout > 0 {
		responseHeaderTimeout = configuration.Connection.CustomMaxTimeout
	}

	transport := &http.Transport{
		ForceAttemptHTTP2: options.ForceAttemptHTTP2,
		DialContext:       dialers.Fastdialer.Dial,
		DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
			if options.TlsImpersonate {
				return dialers.Fastdialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)
			}
			if options.HasClientCertificates() || options.ForceAttemptHTTP2 {
				return dialers.Fastdialer.DialTLSWithConfig(ctx, network, addr, tlsConfig)
			}
			return dialers.Fastdialer.DialTLS(ctx, network, addr)
		},
		MaxIdleConns:          maxIdleConns,
		MaxIdleConnsPerHost:   maxIdleConnsPerHost,
		MaxConnsPerHost:       maxConnsPerHost,
		TLSClientConfig:       tlsConfig,
		DisableKeepAlives:     disableKeepAlives,
		ResponseHeaderTimeout: responseHeaderTimeout,
	}

	if options.AliveHttpProxy != "" {
		// NOTE(review): a proxy URL that fails to parse is silently ignored
		// here, unlike the SOCKS branch below which returns the error —
		// confirm this best-effort behavior is intended.
		if proxyURL, err := url.Parse(options.AliveHttpProxy); err == nil {
			transport.Proxy = http.ProxyURL(proxyURL)
		}
	} else if options.AliveSocksProxy != "" {
		socksURL, proxyErr := url.Parse(options.AliveSocksProxy)
		if proxyErr != nil {
			return nil, proxyErr
		}

		dialer, err := proxy.FromURL(socksURL, proxy.Direct)
		if err != nil {
			return nil, err
		}

		// Checked assertion: a dialer without DialContext yields an error
		// instead of a panic.
		dc, ok := dialer.(interface {
			DialContext(ctx context.Context, network, addr string) (net.Conn, error)
		})
		if !ok {
			return nil, fmt.Errorf("socks proxy dialer %T does not support DialContext", dialer)
		}

		transport.DialContext = dc.DialContext
		transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			// tlsConfig is shared by every connection of this client, so the
			// per-host ServerName must go on a clone; writing it to the shared
			// config would pin the first host's name for all later hosts and
			// race between concurrent dials.
			connTLSConfig := tlsConfig
			if tlsConfig.ServerName == "" {
				// addr should be in form of host:port already set from canonicalAddr
				host, _, err := net.SplitHostPort(addr)
				if err != nil {
					return nil, err
				}
				connTLSConfig = tlsConfig.Clone()
				connTLSConfig.ServerName = host
			}

			// upgrade proxy connection to tls
			conn, err := dc.DialContext(ctx, network, addr)
			if err != nil {
				return nil, err
			}
			return tls.Client(conn, connTLSConfig), nil
		}
	}

	var jar *cookiejar.Jar
	if configuration.Connection != nil && configuration.Connection.HasCookieJar() {
		jar = configuration.Connection.GetCookieJar()
	} else if !configuration.DisableCookie {
		if jar, err = cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}); err != nil {
			return nil, errors.Wrap(err, "could not create cookiejar")
		}
	}

	httpclient := &http.Client{
		Transport:     transport,
		CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects),
	}
	if !configuration.NoTimeout {
		httpclient.Timeout = options.GetTimeouts().HttpTimeout
		if configuration.Connection != nil && configuration.Connection.CustomMaxTimeout > 0 {
			httpclient.Timeout = configuration.Connection.CustomMaxTimeout
		}
	}
	client := retryablehttp.NewWithHTTPClient(httpclient, retryableHttpOptions)
	if jar != nil {
		client.HTTPClient.Jar = jar
	}
	client.CheckRetry = retryablehttp.HostSprayRetryPolicy()

	// Only add to client pool if we don't have a cookie jar in place.
	if jar == nil {
		if err := dialers.HTTPClientPool.Set(hash, client); err != nil {
			return nil, err
		}
	}
	return client, nil
}

// RedirectFlow selects how the client reacts to HTTP redirects.
type RedirectFlow uint8

const (
	// DontFollowRedirect returns the redirect response as-is.
	DontFollowRedirect RedirectFlow = iota
	// FollowSameHostRedirect follows redirects only when the target host
	// equals the host of the first request.
	FollowSameHostRedirect
	// FollowAllRedirect follows redirects to any host.
	FollowAllRedirect
)

// defaultMaxRedirects is the redirect limit applied when maxRedirects is 0.
const defaultMaxRedirects = 10

// checkRedirectFunc matches the signature of http.Client.CheckRedirect.
type checkRedirectFunc func(req *http.Request, via []*http.Request) error

// makeCheckRedirectFunc builds the CheckRedirect callback for the given
// redirect flow and redirect limit.
func makeCheckRedirectFunc(redirectType RedirectFlow, maxRedirects int) checkRedirectFunc {
	return func(req *http.Request, via []*http.Request) error {
		switch redirectType {
		case DontFollowRedirect:
			return http.ErrUseLastResponse
		case FollowSameHostRedirect:
			var newHost = req.URL.Host
			// prefer the explicit Host of the first request, falling back to
			// the host of its URL
			var oldHost = via[0].Host
			if oldHost == "" {
				oldHost = via[0].URL.Host
			}
			if newHost != oldHost {
				// Tell the http client to not follow redirect
				return http.ErrUseLastResponse
			}
			return checkMaxRedirects(req, via, maxRedirects)
		case FollowAllRedirect:
			return checkMaxRedirects(req, via, maxRedirects)
		}
		return nil
	}
}

// checkMaxRedirects stops following redirects (http.ErrUseLastResponse) once
// more than maxRedirects requests have been made; a maxRedirects of 0 applies
// defaultMaxRedirects instead. When the redirect is allowed under an explicit
// limit, a request URL that is not URL-encoded is re-parsed and replaced.
func checkMaxRedirects(req *http.Request, via []*http.Request, maxRedirects int) error {
	if maxRedirects == 0 {
		if len(via) > defaultMaxRedirects {
			return http.ErrUseLastResponse
		}
		// NOTE(review): this path returns before the URL rebuild below —
		// confirm skipping it for the default limit is intended.
		return nil
	}

	if len(via) > maxRedirects {
		return http.ErrUseLastResponse
	}

	// NOTE(dwisiswant0): rebuild request URL. See #5900.
	if u := req.URL.String(); !isURLEncoded(u) {
		parsed, err := urlutil.Parse(u)
		if err != nil {
			return fmt.Errorf("%w: %w", ErrRebuildURL, err)
		}

		req.URL = parsed.URL
	}

	return nil
}

// isURLEncoded is a helper function to check if the URL is already encoded
//
// It reports true when query-unescaping s succeeds and yields a different
// string.
//
// NOTE(dwisiswant0): shall we move this under `projectdiscovery/utils/urlutil`?
func isURLEncoded(s string) bool {
	decoded, err := url.QueryUnescape(s)
	if err != nil {
		// If decoding fails, it may indicate a malformed URL/invalid encoding.
		return false
	}

	return decoded != s
}