github.com/projectdiscovery/nuclei/v2@v2.9.15/pkg/protocols/http/httpclientpool/clientpool.go (about)

     1  package httpclientpool
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"net"
     7  	"net/http"
     8  	"net/http/cookiejar"
     9  	"net/url"
    10  	"strconv"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/pkg/errors"
    16  	"golang.org/x/net/proxy"
    17  	"golang.org/x/net/publicsuffix"
    18  
    19  	"github.com/projectdiscovery/fastdialer/fastdialer"
    20  	"github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate"
    21  	"github.com/projectdiscovery/nuclei/v2/pkg/protocols/common/protocolstate"
    22  	"github.com/projectdiscovery/nuclei/v2/pkg/protocols/utils"
    23  	"github.com/projectdiscovery/nuclei/v2/pkg/types"
    24  	"github.com/projectdiscovery/nuclei/v2/pkg/types/scanstrategy"
    25  	"github.com/projectdiscovery/rawhttp"
    26  	"github.com/projectdiscovery/retryablehttp-go"
    27  	mapsutil "github.com/projectdiscovery/utils/maps"
    28  )
    29  
    30  var (
    31  	// Dialer is a copy of the fastdialer from protocolstate
    32  	Dialer *fastdialer.Dialer
    33  
    34  	rawHttpClient     *rawhttp.Client
    35  	forceMaxRedirects int
    36  	normalClient      *retryablehttp.Client
    37  	clientPool        *mapsutil.SyncLockMap[string, *retryablehttp.Client]
    38  )
    39  
    40  // Init initializes the clientpool implementation
    41  func Init(options *types.Options) error {
    42  	// Don't create clients if already created in the past.
    43  	if normalClient != nil {
    44  		return nil
    45  	}
    46  	if options.ShouldFollowHTTPRedirects() {
    47  		forceMaxRedirects = options.MaxRedirects
    48  	}
    49  	clientPool = &mapsutil.SyncLockMap[string, *retryablehttp.Client]{
    50  		Map: make(mapsutil.Map[string, *retryablehttp.Client]),
    51  	}
    52  
    53  	client, err := wrappedGet(options, &Configuration{})
    54  	if err != nil {
    55  		return err
    56  	}
    57  	normalClient = client
    58  	return nil
    59  }
    60  
    61  // ConnectionConfiguration contains the custom configuration options for a connection
    62  type ConnectionConfiguration struct {
    63  	// DisableKeepAlive of the connection
    64  	DisableKeepAlive bool
    65  	cookiejar        *cookiejar.Jar
    66  	mu               sync.RWMutex
    67  }
    68  
    69  func (cc *ConnectionConfiguration) SetCookieJar(cookiejar *cookiejar.Jar) {
    70  	cc.mu.Lock()
    71  	defer cc.mu.Unlock()
    72  
    73  	cc.cookiejar = cookiejar
    74  }
    75  
    76  func (cc *ConnectionConfiguration) GetCookieJar() *cookiejar.Jar {
    77  	cc.mu.RLock()
    78  	defer cc.mu.RUnlock()
    79  
    80  	return cc.cookiejar
    81  }
    82  
    83  func (cc *ConnectionConfiguration) HasCookieJar() bool {
    84  	cc.mu.RLock()
    85  	defer cc.mu.RUnlock()
    86  
    87  	return cc.cookiejar != nil
    88  }
    89  
// Configuration contains the custom configuration options for a client.
// Configurations with only zero values are served by the shared default
// client; any other combination is built (and, without a cookie jar, pooled)
// by wrappedGet.
type Configuration struct {
	// Threads contains the threads for the client. A value > 0 selects the
	// single-host retryablehttp defaults with keep-alives and per-host
	// connection limits (see wrappedGet).
	Threads int
	// MaxRedirects is the maximum number of redirects to follow; 0 means the
	// package default (defaultMaxRedirects).
	MaxRedirects int
	// NoTimeout disables http request timeout for context based usage
	// (http.Client.Timeout is left unset).
	NoTimeout bool
	// CookieReuse enables cookie reuse for the http client (cookiejar impl).
	// Clients that end up with a cookie jar are not stored in the client pool.
	CookieReuse bool
	// RedirectFlow specifies the redirects flow
	RedirectFlow RedirectFlow
	// Connection defines custom connection configuration. When non-nil its
	// DisableKeepAlive overrides the default keep-alive choice, and its
	// cookie jar (if set) is used instead of creating one.
	Connection *ConnectionConfiguration
}
   105  
   106  // Hash returns the hash of the configuration to allow client pooling
   107  func (c *Configuration) Hash() string {
   108  	builder := &strings.Builder{}
   109  	builder.Grow(16)
   110  	builder.WriteString("t")
   111  	builder.WriteString(strconv.Itoa(c.Threads))
   112  	builder.WriteString("m")
   113  	builder.WriteString(strconv.Itoa(c.MaxRedirects))
   114  	builder.WriteString("n")
   115  	builder.WriteString(strconv.FormatBool(c.NoTimeout))
   116  	builder.WriteString("f")
   117  	builder.WriteString(strconv.Itoa(int(c.RedirectFlow)))
   118  	builder.WriteString("r")
   119  	builder.WriteString(strconv.FormatBool(c.CookieReuse))
   120  	builder.WriteString("c")
   121  	builder.WriteString(strconv.FormatBool(c.Connection != nil))
   122  	hash := builder.String()
   123  	return hash
   124  }
   125  
   126  // HasStandardOptions checks whether the configuration requires custom settings
   127  func (c *Configuration) HasStandardOptions() bool {
   128  	return c.Threads == 0 && c.MaxRedirects == 0 && c.RedirectFlow == DontFollowRedirect && !c.CookieReuse && c.Connection == nil && !c.NoTimeout
   129  }
   130  
   131  // GetRawHTTP returns the rawhttp request client
   132  func GetRawHTTP(options *types.Options) *rawhttp.Client {
   133  	if rawHttpClient == nil {
   134  		rawHttpOptions := rawhttp.DefaultOptions
   135  		if types.ProxyURL != "" {
   136  			rawHttpOptions.Proxy = types.ProxyURL
   137  		} else if types.ProxySocksURL != "" {
   138  			rawHttpOptions.Proxy = types.ProxySocksURL
   139  		} else if Dialer != nil {
   140  			rawHttpOptions.FastDialer = Dialer
   141  		}
   142  		rawHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second
   143  		rawHttpClient = rawhttp.NewClient(rawHttpOptions)
   144  	}
   145  	return rawHttpClient
   146  }
   147  
   148  // Get creates or gets a client for the protocol based on custom configuration
   149  func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
   150  	if configuration.HasStandardOptions() {
   151  		return normalClient, nil
   152  	}
   153  	return wrappedGet(options, configuration)
   154  }
   155  
   156  // wrappedGet wraps a get operation without normal client check
   157  func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
   158  	var err error
   159  
   160  	if Dialer == nil {
   161  		Dialer = protocolstate.Dialer
   162  	}
   163  
   164  	hash := configuration.Hash()
   165  	if client, ok := clientPool.Get(hash); ok {
   166  		return client, nil
   167  	}
   168  
   169  	// Multiple Host
   170  	retryableHttpOptions := retryablehttp.DefaultOptionsSpraying
   171  	disableKeepAlives := true
   172  	maxIdleConns := 0
   173  	maxConnsPerHost := 0
   174  	maxIdleConnsPerHost := -1
   175  
   176  	if configuration.Threads > 0 || options.ScanStrategy == scanstrategy.HostSpray.String() {
   177  		// Single host
   178  		retryableHttpOptions = retryablehttp.DefaultOptionsSingle
   179  		disableKeepAlives = false
   180  		maxIdleConnsPerHost = 500
   181  		maxConnsPerHost = 500
   182  	}
   183  
   184  	retryableHttpOptions.RetryWaitMax = 10 * time.Second
   185  	retryableHttpOptions.RetryMax = options.Retries
   186  	redirectFlow := configuration.RedirectFlow
   187  	maxRedirects := configuration.MaxRedirects
   188  
   189  	if forceMaxRedirects > 0 {
   190  		// by default we enable general redirects following
   191  		switch {
   192  		case options.FollowHostRedirects:
   193  			redirectFlow = FollowSameHostRedirect
   194  		default:
   195  			redirectFlow = FollowAllRedirect
   196  		}
   197  		maxRedirects = forceMaxRedirects
   198  	}
   199  	if options.DisableRedirects {
   200  		options.FollowRedirects = false
   201  		options.FollowHostRedirects = false
   202  		redirectFlow = DontFollowRedirect
   203  		maxRedirects = 0
   204  	}
   205  
   206  	// override connection's settings if required
   207  	if configuration.Connection != nil {
   208  		disableKeepAlives = configuration.Connection.DisableKeepAlive
   209  	}
   210  
   211  	// Set the base TLS configuration definition
   212  	tlsConfig := &tls.Config{
   213  		Renegotiation:      tls.RenegotiateOnceAsClient,
   214  		InsecureSkipVerify: true,
   215  		MinVersion:         tls.VersionTLS10,
   216  	}
   217  
   218  	if options.SNI != "" {
   219  		tlsConfig.ServerName = options.SNI
   220  	}
   221  
   222  	// Add the client certificate authentication to the request if it's configured
   223  	tlsConfig, err = utils.AddConfiguredClientCertToRequest(tlsConfig, options)
   224  	if err != nil {
   225  		return nil, errors.Wrap(err, "could not create client certificate")
   226  	}
   227  
   228  	transport := &http.Transport{
   229  		ForceAttemptHTTP2: options.ForceAttemptHTTP2,
   230  		DialContext:       Dialer.Dial,
   231  		DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
   232  			if options.HasClientCertificates() {
   233  				return Dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig)
   234  			}
   235  			if options.TlsImpersonate {
   236  				return Dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)
   237  			}
   238  			return Dialer.DialTLS(ctx, network, addr)
   239  		},
   240  		MaxIdleConns:        maxIdleConns,
   241  		MaxIdleConnsPerHost: maxIdleConnsPerHost,
   242  		MaxConnsPerHost:     maxConnsPerHost,
   243  		TLSClientConfig:     tlsConfig,
   244  		DisableKeepAlives:   disableKeepAlives,
   245  	}
   246  
   247  	if types.ProxyURL != "" {
   248  		if proxyURL, err := url.Parse(types.ProxyURL); err == nil {
   249  			transport.Proxy = http.ProxyURL(proxyURL)
   250  		}
   251  	} else if types.ProxySocksURL != "" {
   252  		socksURL, proxyErr := url.Parse(types.ProxySocksURL)
   253  		if proxyErr != nil {
   254  			return nil, proxyErr
   255  		}
   256  
   257  		dialer, err := proxy.FromURL(socksURL, proxy.Direct)
   258  		if err != nil {
   259  			return nil, err
   260  		}
   261  
   262  		dc := dialer.(interface {
   263  			DialContext(ctx context.Context, network, addr string) (net.Conn, error)
   264  		})
   265  
   266  		transport.DialContext = dc.DialContext
   267  		transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
   268  			// upgrade proxy connection to tls
   269  			conn, err := dc.DialContext(ctx, network, addr)
   270  			if err != nil {
   271  				return nil, err
   272  			}
   273  			return tls.Client(conn, tlsConfig), nil
   274  		}
   275  	}
   276  
   277  	var jar *cookiejar.Jar
   278  	if configuration.Connection != nil && configuration.Connection.HasCookieJar() {
   279  		jar = configuration.Connection.GetCookieJar()
   280  	} else if configuration.CookieReuse {
   281  		if jar, err = cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}); err != nil {
   282  			return nil, errors.Wrap(err, "could not create cookiejar")
   283  		}
   284  	}
   285  
   286  	httpclient := &http.Client{
   287  		Transport:     transport,
   288  		CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects),
   289  	}
   290  	if !configuration.NoTimeout {
   291  		httpclient.Timeout = time.Duration(options.Timeout) * time.Second
   292  	}
   293  	client := retryablehttp.NewWithHTTPClient(httpclient, retryableHttpOptions)
   294  	if jar != nil {
   295  		client.HTTPClient.Jar = jar
   296  	}
   297  	client.CheckRetry = retryablehttp.HostSprayRetryPolicy()
   298  
   299  	// Only add to client pool if we don't have a cookie jar in place.
   300  	if jar == nil {
   301  		if err := clientPool.Set(hash, client); err != nil {
   302  			return nil, err
   303  		}
   304  	}
   305  	return client, nil
   306  }
   307  
   308  type RedirectFlow uint8
   309  
   310  const (
   311  	DontFollowRedirect RedirectFlow = iota
   312  	FollowSameHostRedirect
   313  	FollowAllRedirect
   314  )
   315  
   316  const defaultMaxRedirects = 10
   317  
   318  type checkRedirectFunc func(req *http.Request, via []*http.Request) error
   319  
   320  func makeCheckRedirectFunc(redirectType RedirectFlow, maxRedirects int) checkRedirectFunc {
   321  	return func(req *http.Request, via []*http.Request) error {
   322  		switch redirectType {
   323  		case DontFollowRedirect:
   324  			return http.ErrUseLastResponse
   325  		case FollowSameHostRedirect:
   326  			var newHost = req.URL.Host
   327  			var oldHost = via[0].Host
   328  			if oldHost == "" {
   329  				oldHost = via[0].URL.Host
   330  			}
   331  			if newHost != oldHost {
   332  				// Tell the http client to not follow redirect
   333  				return http.ErrUseLastResponse
   334  			}
   335  			return checkMaxRedirects(req, via, maxRedirects)
   336  		case FollowAllRedirect:
   337  			return checkMaxRedirects(req, via, maxRedirects)
   338  		}
   339  		return nil
   340  	}
   341  }
   342  
   343  func checkMaxRedirects(req *http.Request, via []*http.Request, maxRedirects int) error {
   344  	if maxRedirects == 0 {
   345  		if len(via) > defaultMaxRedirects {
   346  			return http.ErrUseLastResponse
   347  		}
   348  		return nil
   349  	}
   350  
   351  	if len(via) > maxRedirects {
   352  		return http.ErrUseLastResponse
   353  	}
   354  	return nil
   355  }