github.com/xmidt-org/webpa-common@v1.11.9/xhttp/redirectPolicy.go (about)

     1  package xhttp
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/textproto"
     7  
     8  	"github.com/go-kit/kit/log"
     9  	"github.com/go-kit/kit/log/level"
    10  	"github.com/xmidt-org/webpa-common/logging"
    11  )
    12  
    13  const (
    14  	DefaultMaxRedirects = 10
    15  )
    16  
// RedirectPolicy is the configurable policy for handling redirects.  Pass it to
// CheckRedirect to obtain a function with the signature of http.Client.CheckRedirect.
type RedirectPolicy struct {
	// Logger is the go-kit Logger used for logging.  If unset, the request context's logger is used.
	Logger log.Logger

	// MaxRedirects is the maximum number of redirects to follow.  If unset (or not positive),
	// DefaultMaxRedirects is used.
	MaxRedirects int

	// ExcludeHeaders is the denylist of headers that should not be copied from previous requests.
	// Configured names are canonicalized with textproto.CanonicalMIMEHeaderKey before matching,
	// so they may be given in any letter case.
	ExcludeHeaders []string
}
    28  
    29  // maxRedirects returns the maximum number of redirects to follow
    30  func (p RedirectPolicy) maxRedirects() int {
    31  	if p.MaxRedirects > 0 {
    32  		return p.MaxRedirects
    33  	}
    34  
    35  	return DefaultMaxRedirects
    36  }
    37  
    38  // headerFilter returns a closure that returns true if a header name should be included in redirected requests
    39  func (p RedirectPolicy) headerFilter() func(string) bool {
    40  	if len(p.ExcludeHeaders) > 0 {
    41  		excludes := make(map[string]bool, len(p.ExcludeHeaders))
    42  		for _, v := range p.ExcludeHeaders {
    43  			excludes[textproto.CanonicalMIMEHeaderKey(v)] = true
    44  		}
    45  
    46  		return func(h string) bool {
    47  			return !excludes[h]
    48  		}
    49  	}
    50  
    51  	return func(string) bool {
    52  		return true
    53  	}
    54  }
    55  
    56  // CheckRedirect produces a redirect policy function given a policy descriptor
    57  func CheckRedirect(p RedirectPolicy) func(*http.Request, []*http.Request) error {
    58  	var (
    59  		maxRedirects = p.maxRedirects()
    60  		headerFilter = p.headerFilter()
    61  	)
    62  
    63  	return func(r *http.Request, via []*http.Request) error {
    64  		logger := p.Logger
    65  		if logger == nil {
    66  			logger = logging.GetLogger(r.Context())
    67  		}
    68  
    69  		if len(via) >= maxRedirects {
    70  			err := fmt.Errorf("stopped after %d redirect(s)", maxRedirects)
    71  			logger.Log(level.Key(), level.ErrorValue(), logging.ErrorKey(), err)
    72  			return err
    73  		}
    74  
    75  		for k, v := range via[len(via)-1].Header {
    76  			if headerFilter(k) {
    77  				r.Header[k] = v
    78  			} else {
    79  				logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "excluding header on redirect", "header", k)
    80  			}
    81  		}
    82  
    83  		return nil
    84  	}
    85  }