github.com/searKing/golang/go@v1.2.117/net/http/interceptors.reject.server.go (about)

     1  // Copyright 2020 The searKing Author. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http
     6  
     7  import (
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"log"
    12  	"net"
    13  	"net/http"
    14  	"strings"
    15  )
    16  
    17  var _ http.Handler = &rejectInsecure{}
    18  
    19  //go:generate go-option -type "rejectInsecure"
    20  type rejectInsecure struct {
    21  	// ErrorLog specifies an optional logger for errors accepting
    22  	// connections, unexpected behavior from handlers, and
    23  	// underlying FileSystem errors.
    24  	// If nil, logging will be done via the log package's standard logger.
    25  	ErrorLog *log.Logger
    26  	// ForceHttp allows any request, as a shortcut circuit
    27  	ForceHttp bool
    28  	// AllowedTlsCidrs allows any request which client or proxy's ip included
    29  	// a cidr is a CIDR notation IP address and prefix length,
    30  	// like "192.0.2.0/24" or "2001:db8::/32", as defined in
    31  	// RFC 4632 and RFC 4291.
    32  	AllowedTlsCidrs []string
    33  
    34  	// WhitelistedPaths allows any request which http path matches
    35  	WhitelistedPaths []string
    36  
    37  	next http.Handler
    38  }
    39  
    40  // RejectInsecureServerInterceptor returns a new server interceptor with tls check.
    41  // reject the request fulfills tls's constraints,
    42  func RejectInsecureServerInterceptor(next http.Handler, opts ...RejectInsecureOption) *rejectInsecure {
    43  	r := &rejectInsecure{
    44  		next: next,
    45  	}
    46  	r.ApplyOptions(opts...)
    47  	return r
    48  }
    49  
    50  func (m *rejectInsecure) logf(format string, args ...any) {
    51  	if m.ErrorLog != nil {
    52  		m.ErrorLog.Printf(format, args...)
    53  	} else {
    54  		log.Printf(format, args...)
    55  	}
    56  }
    57  
    58  func (m *rejectInsecure) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    59  	if m == nil {
    60  		m.next.ServeHTTP(w, r)
    61  		return
    62  	}
    63  	m.rejectInsecureRequests(w, r)
    64  	return
    65  }
    66  
    67  // rejectInsecureRequests refused if tls's constraints not passed
    68  func (m *rejectInsecure) rejectInsecureRequests(w http.ResponseWriter, r *http.Request) {
    69  	if m == nil || m.ForceHttp {
    70  		m.next.ServeHTTP(w, r)
    71  		return
    72  	}
    73  
    74  	err := DoesRequestSatisfyTlsTermination(r, m.WhitelistedPaths, m.AllowedTlsCidrs)
    75  	if err != nil {
    76  		m.logf("http: could not serve http connection %v: %v", r.RemoteAddr, err)
    77  
    78  		w.Header().Set("Content-Type", "application/json")
    79  		w.WriteHeader(http.StatusBadGateway)
    80  
    81  		if err := json.NewEncoder(w).Encode(fmt.Errorf("cannot serve request over insecure http: %w", err)); err != nil {
    82  			// There was an error, but there's actually not a lot we can do except log that this happened.
    83  			m.logf("http: could not write jsonError to response writer %v: %v", http.StatusBadGateway, err)
    84  		}
    85  		return
    86  	}
    87  	m.next.ServeHTTP(w, r)
    88  	return
    89  }
    90  
    91  // DoesRequestSatisfyTlsTermination returns whether the request fulfills tls's constraints,
    92  // https, path matches any whitelisted paths or ip inclued by any cidr
    93  // whitelistedPath is http path that does not need to be checked
    94  // allowedTLSCIDR is the network includes ip.
    95  func DoesRequestSatisfyTlsTermination(r *http.Request, whitelistedPaths []string, allowedTLSCIDRs []string) error {
    96  	// pass if the request is with tls, that is https
    97  	if r.TLS != nil {
    98  		return nil
    99  	}
   100  
   101  	// check if the http request can be passed
   102  
   103  	// pass if the request belongs to whitelist
   104  	for _, p := range whitelistedPaths {
   105  		if r.URL.Path == p {
   106  			return nil
   107  		}
   108  	}
   109  
   110  	if len(allowedTLSCIDRs) == 0 {
   111  		return errors.New("TLS termination is not enabled")
   112  	}
   113  
   114  	if err := matchesAnyCidr(r, allowedTLSCIDRs); err != nil {
   115  		return err
   116  	}
   117  
   118  	proto := r.Header.Get("X-Forwarded-Proto")
   119  	if proto == "" {
   120  		return errors.New("X-Forwarded-Proto header is missing")
   121  	}
   122  	if proto != "https" {
   123  		return fmt.Errorf("expected X-Forwarded-Proto header to be https, got %s", proto)
   124  	}
   125  
   126  	return nil
   127  }
   128  
   129  // matchesAnyCidr returns true if any of client and proxy's ip matches any cidr
   130  // a cidr is a CIDR notation IP address and prefix length,
   131  // like "192.0.2.0/24" or "2001:db8::/32", as defined in
   132  // RFC 4632 and RFC 4291.
   133  func matchesAnyCidr(r *http.Request, cidrs []string) error {
   134  	remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
   135  	if err != nil {
   136  		return err
   137  	}
   138  
   139  	check := []string{remoteIP}
   140  	// X-Forwarded-For: client1, proxy1, proxy2
   141  	for _, fwd := range strings.Split(r.Header.Get("X-Forwarded-For"), ",") {
   142  		check = append(check, strings.TrimSpace(fwd))
   143  	}
   144  
   145  	for _, rn := range cidrs {
   146  		_, cidr, err := net.ParseCIDR(rn)
   147  		if err != nil {
   148  			return err
   149  		}
   150  
   151  		for _, ip := range check {
   152  			addr := net.ParseIP(ip)
   153  			if cidr.Contains(addr) {
   154  				return nil
   155  			}
   156  		}
   157  	}
   158  	return fmt.Errorf("neither remote address nor any x-forwarded-for values match CIDR cidrs %v: %v, cidrs, check)", cidrs, check)
   159  }