github.com/xmidt-org/webpa-common@v1.11.9/xhttp/redirectPolicy.go (about) 1 package xhttp 2 3 import ( 4 "fmt" 5 "net/http" 6 "net/textproto" 7 8 "github.com/go-kit/kit/log" 9 "github.com/go-kit/kit/log/level" 10 "github.com/xmidt-org/webpa-common/logging" 11 ) 12 13 const ( 14 DefaultMaxRedirects = 10 15 ) 16 17 // RedirectPolicy is the configurable policy for handling redirects 18 type RedirectPolicy struct { 19 // Logger is the go-kit Logger used for logging. If unset, the request context's logger is used. 20 Logger log.Logger 21 22 // MaxRedirects is the maximum number of redirects to follow. If unset, DefaultMaxRedirects is used. 23 MaxRedirects int 24 25 // ExcludeHeaders is the denylist of headers that should not be copied from previous requests. 26 ExcludeHeaders []string 27 } 28 29 // maxRedirects returns the maximum number of redirects to follow 30 func (p RedirectPolicy) maxRedirects() int { 31 if p.MaxRedirects > 0 { 32 return p.MaxRedirects 33 } 34 35 return DefaultMaxRedirects 36 } 37 38 // headerFilter returns a closure that returns true if a header name should be included in redirected requests 39 func (p RedirectPolicy) headerFilter() func(string) bool { 40 if len(p.ExcludeHeaders) > 0 { 41 excludes := make(map[string]bool, len(p.ExcludeHeaders)) 42 for _, v := range p.ExcludeHeaders { 43 excludes[textproto.CanonicalMIMEHeaderKey(v)] = true 44 } 45 46 return func(h string) bool { 47 return !excludes[h] 48 } 49 } 50 51 return func(string) bool { 52 return true 53 } 54 } 55 56 // CheckRedirect produces a redirect policy function given a policy descriptor 57 func CheckRedirect(p RedirectPolicy) func(*http.Request, []*http.Request) error { 58 var ( 59 maxRedirects = p.maxRedirects() 60 headerFilter = p.headerFilter() 61 ) 62 63 return func(r *http.Request, via []*http.Request) error { 64 logger := p.Logger 65 if logger == nil { 66 logger = logging.GetLogger(r.Context()) 67 } 68 69 if len(via) >= maxRedirects { 70 err := fmt.Errorf("stopped after %d redirect(s)", maxRedirects) 71 logger.Log(level.Key(), level.ErrorValue(), logging.ErrorKey(), err) 72 return err 73 } 74 75 for k, v := range via[len(via)-1].Header { 76 if headerFilter(k) { 77 r.Header[k] = v 78 } else { 79 logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "excluding header on redirect", "header", k) 80 } 81 } 82 83 return nil 84 } 85 }