github.com/searKing/golang/go@v1.2.117/net/http/interceptors.reject.server.go (about) 1 // Copyright 2020 The searKing Author. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package http 6 7 import ( 8 "encoding/json" 9 "errors" 10 "fmt" 11 "log" 12 "net" 13 "net/http" 14 "strings" 15 ) 16 17 var _ http.Handler = &rejectInsecure{} 18 19 //go:generate go-option -type "rejectInsecure" 20 type rejectInsecure struct { 21 // ErrorLog specifies an optional logger for errors accepting 22 // connections, unexpected behavior from handlers, and 23 // underlying FileSystem errors. 24 // If nil, logging will be done via the log package's standard logger. 25 ErrorLog *log.Logger 26 // ForceHttp allows any request, as a shortcut circuit 27 ForceHttp bool 28 // AllowedTlsCidrs allows any request which client or proxy's ip included 29 // a cidr is a CIDR notation IP address and prefix length, 30 // like "192.0.2.0/24" or "2001:db8::/32", as defined in 31 // RFC 4632 and RFC 4291. 32 AllowedTlsCidrs []string 33 34 // WhitelistedPaths allows any request which http path matches 35 WhitelistedPaths []string 36 37 next http.Handler 38 } 39 40 // RejectInsecureServerInterceptor returns a new server interceptor with tls check. 41 // reject the request fulfills tls's constraints, 42 func RejectInsecureServerInterceptor(next http.Handler, opts ...RejectInsecureOption) *rejectInsecure { 43 r := &rejectInsecure{ 44 next: next, 45 } 46 r.ApplyOptions(opts...) 47 return r 48 } 49 50 func (m *rejectInsecure) logf(format string, args ...any) { 51 if m.ErrorLog != nil { 52 m.ErrorLog.Printf(format, args...) 53 } else { 54 log.Printf(format, args...) 
55 } 56 } 57 58 func (m *rejectInsecure) ServeHTTP(w http.ResponseWriter, r *http.Request) { 59 if m == nil { 60 m.next.ServeHTTP(w, r) 61 return 62 } 63 m.rejectInsecureRequests(w, r) 64 return 65 } 66 67 // rejectInsecureRequests refused if tls's constraints not passed 68 func (m *rejectInsecure) rejectInsecureRequests(w http.ResponseWriter, r *http.Request) { 69 if m == nil || m.ForceHttp { 70 m.next.ServeHTTP(w, r) 71 return 72 } 73 74 err := DoesRequestSatisfyTlsTermination(r, m.WhitelistedPaths, m.AllowedTlsCidrs) 75 if err != nil { 76 m.logf("http: could not serve http connection %v: %v", r.RemoteAddr, err) 77 78 w.Header().Set("Content-Type", "application/json") 79 w.WriteHeader(http.StatusBadGateway) 80 81 if err := json.NewEncoder(w).Encode(fmt.Errorf("cannot serve request over insecure http: %w", err)); err != nil { 82 // There was an error, but there's actually not a lot we can do except log that this happened. 83 m.logf("http: could not write jsonError to response writer %v: %v", http.StatusBadGateway, err) 84 } 85 return 86 } 87 m.next.ServeHTTP(w, r) 88 return 89 } 90 91 // DoesRequestSatisfyTlsTermination returns whether the request fulfills tls's constraints, 92 // https, path matches any whitelisted paths or ip inclued by any cidr 93 // whitelistedPath is http path that does not need to be checked 94 // allowedTLSCIDR is the network includes ip. 
func DoesRequestSatisfyTlsTermination(r *http.Request, whitelistedPaths []string, allowedTLSCIDRs []string) error {
	// A request that already carries TLS state (https) always passes.
	if r.TLS != nil {
		return nil
	}

	// From here on the request is plain http.

	// Whitelisted paths are exempt from every further check (exact match).
	for _, p := range whitelistedPaths {
		if r.URL.Path == p {
			return nil
		}
	}

	// Without any allowed network, plain http can never be accepted.
	if len(allowedTLSCIDRs) == 0 {
		return errors.New("TLS termination is not enabled")
	}

	if err := matchesAnyCidr(r, allowedTLSCIDRs); err != nil {
		return err
	}

	// The request came through an allowed network; it must still declare
	// that the original hop was https.
	proto := r.Header.Get("X-Forwarded-Proto")
	if proto == "" {
		return errors.New("X-Forwarded-Proto header is missing")
	}
	if proto != "https" {
		return fmt.Errorf("expected X-Forwarded-Proto header to be https, got %s", proto)
	}

	return nil
}

// matchesAnyCidr returns nil if the remote ip or any ip listed in the
// X-Forwarded-For header is included by any cidr, and an error otherwise.
// It also returns an error when r.RemoteAddr is not a host:port pair or when
// a cidr that is reached before a match cannot be parsed.
// a cidr is a CIDR notation IP address and prefix length,
// like "192.0.2.0/24" or "2001:db8::/32", as defined in
// RFC 4632 and RFC 4291.
func matchesAnyCidr(r *http.Request, cidrs []string) error {
	remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return err
	}

	check := []string{remoteIP}
	// X-Forwarded-For: client1, proxy1, proxy2
	for _, fwd := range strings.Split(r.Header.Get("X-Forwarded-For"), ",") {
		// An absent header splits into a single empty entry; skip blanks so
		// they neither count as candidates nor clutter the error message.
		if fwd = strings.TrimSpace(fwd); fwd != "" {
			check = append(check, fwd)
		}
	}

	// Parse every candidate once instead of once per cidr; entries that are
	// not valid ips can never be contained in a network and are dropped.
	addrs := make([]net.IP, 0, len(check))
	for _, candidate := range check {
		if ip := net.ParseIP(candidate); ip != nil {
			addrs = append(addrs, ip)
		}
	}

	for _, rn := range cidrs {
		_, cidr, err := net.ParseCIDR(rn)
		if err != nil {
			return err
		}

		for _, ip := range addrs {
			if cidr.Contains(ip) {
				return nil
			}
		}
	}
	return fmt.Errorf("neither remote address nor any x-forwarded-for values match CIDRs %v: %v", cidrs, check)
}