github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/resty/redirect.go (about)

     1  // Copyright (c) 2015-2021 Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
     2  // resty source code and usage is governed by a MIT style
     3  // license that can be found in the LICENSE file.
     4  
     5  package resty
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"net"
    11  	"net/http"
    12  	"strings"
    13  )
    14  
    15  type (
    16  	// RedirectPolicy to regulate the redirects in the resty client.
    17  	// Objects implementing the RedirectPolicy interface can be registered as
    18  	//
    19  	// Apply function should return nil to continue the redirect jounery, otherwise
    20  	// return error to stop the redirect.
    21  	RedirectPolicy interface {
    22  		Apply(req *http.Request, via []*http.Request) error
    23  	}
    24  
    25  	// The RedirectPolicyFunc type is an adapter to allow the use of ordinary functions as RedirectPolicy.
    26  	// If f is a function with the appropriate signature, RedirectPolicyFunc(f) is a RedirectPolicy object that calls f.
    27  	RedirectPolicyFunc func(*http.Request, []*http.Request) error
    28  )
    29  
    30  // Apply calls f(req, via).
    31  func (f RedirectPolicyFunc) Apply(req *http.Request, via []*http.Request) error {
    32  	return f(req, via)
    33  }
    34  
    35  // NoRedirectPolicy is used to disable redirects in the HTTP client
    36  //
    37  //	resty.SetRedirectPolicy(NoRedirectPolicy())
    38  func NoRedirectPolicy() RedirectPolicy {
    39  	return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
    40  		return errors.New("auto redirect is disabled")
    41  	})
    42  }
    43  
    44  // FlexibleRedirectPolicy is convenient method to create No of redirect policy for HTTP client.
    45  //
    46  //	resty.SetRedirectPolicy(FlexibleRedirectPolicy(20))
    47  func FlexibleRedirectPolicy(noOfRedirect int) RedirectPolicy {
    48  	return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
    49  		if len(via) >= noOfRedirect {
    50  			return fmt.Errorf("stopped after %d redirects", noOfRedirect)
    51  		}
    52  		checkHostAndAddHeaders(req, via[0])
    53  		return nil
    54  	})
    55  }
    56  
    57  // DomainCheckRedirectPolicy is convenient method to define domain name redirect rule in resty client.
    58  // Redirect is allowed for only mentioned host in the policy.
    59  //
    60  //	resty.SetRedirectPolicy(DomainCheckRedirectPolicy("host1.com", "host2.org", "host3.net"))
    61  func DomainCheckRedirectPolicy(hostnames ...string) RedirectPolicy {
    62  	hosts := make(map[string]bool)
    63  	for _, h := range hostnames {
    64  		hosts[strings.ToLower(h)] = true
    65  	}
    66  
    67  	fn := RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
    68  		if ok := hosts[getHostname(req.URL.Host)]; !ok {
    69  			return errors.New("redirect is not allowed as per DomainCheckRedirectPolicy")
    70  		}
    71  
    72  		return nil
    73  	})
    74  
    75  	return fn
    76  }
    77  
    78  //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
    79  // Package Unexported methods
    80  //_______________________________________________________________________
    81  
    82  func getHostname(host string) (hostname string) {
    83  	if strings.Index(host, ":") > 0 {
    84  		host, _, _ = net.SplitHostPort(host)
    85  	}
    86  	hostname = strings.ToLower(host)
    87  	return
    88  }
    89  
    90  // By default, Golang will not redirect request headers
    91  // after go throughing various discussion comments from thread
    92  // https://github.com/golang/go/issues/4800
    93  // Resty will add all the headers during a redirect for the same host
    94  func checkHostAndAddHeaders(cur *http.Request, pre *http.Request) {
    95  	curHostname := getHostname(cur.URL.Host)
    96  	preHostname := getHostname(pre.URL.Host)
    97  	if strings.EqualFold(curHostname, preHostname) {
    98  		for key, val := range pre.Header {
    99  			cur.Header[key] = val
   100  		}
   101  	} else { // only library User-Agent header is added
   102  		cur.Header.Set(hdrUserAgentKey, hdrUserAgentValue)
   103  	}
   104  }