git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/httpx/hostrouter/hostrouter.go (about)

     1  package hostrouter
     2  
     3  import (
     4  	"net/http"
     5  	"strings"
     6  
     7  	"github.com/go-chi/chi/v5"
     8  )
     9  
    10  type Routes map[string]chi.Router
    11  
    12  var _ chi.Routes = Routes{}
    13  
    14  func New() Routes {
    15  	return Routes{}
    16  }
    17  
    18  func (hr Routes) Match(rctx *chi.Context, method, path string) bool {
    19  	return true
    20  }
    21  
    22  func (hr Routes) Map(host string, h chi.Router) {
    23  	hr[strings.ToLower(host)] = h
    24  }
    25  
    26  func (hr Routes) Unmap(host string) {
    27  	delete(hr, strings.ToLower(host))
    28  }
    29  
    30  func (hr Routes) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    31  	host := requestHost(r)
    32  	if router, ok := hr[strings.ToLower(host)]; ok {
    33  		router.ServeHTTP(w, r)
    34  		return
    35  	}
    36  	if router, ok := hr[strings.ToLower(getWildcardHost(host))]; ok {
    37  		router.ServeHTTP(w, r)
    38  		return
    39  	}
    40  	if router, ok := hr["*"]; ok {
    41  		router.ServeHTTP(w, r)
    42  		return
    43  	}
    44  	http.Error(w, http.StatusText(404), 404)
    45  }
    46  
    47  func (hr Routes) Routes() []chi.Route {
    48  	return hr[""].Routes()
    49  }
    50  
    51  func (hr Routes) Middlewares() chi.Middlewares {
    52  	return chi.Middlewares{}
    53  }
    54  
    55  func requestHost(r *http.Request) (host string) {
    56  	// not standard, but most popular
    57  	host = r.Header.Get("X-Forwarded-Host")
    58  	if host != "" {
    59  		return
    60  	}
    61  
    62  	// RFC 7239
    63  	host = r.Header.Get("Forwarded")
    64  	_, _, host = parseForwarded(host)
    65  	if host != "" {
    66  		return
    67  	}
    68  
    69  	// if all else fails fall back to request host
    70  	host = r.Host
    71  	return
    72  }
    73  
    74  func parseForwarded(forwarded string) (addr, proto, host string) {
    75  	if forwarded == "" {
    76  		return
    77  	}
    78  	for _, forwardedPair := range strings.Split(forwarded, ";") {
    79  		if tv := strings.SplitN(forwardedPair, "=", 2); len(tv) == 2 {
    80  			token, value := tv[0], tv[1]
    81  			token = strings.TrimSpace(token)
    82  			value = strings.TrimSpace(strings.Trim(value, `"`))
    83  			switch strings.ToLower(token) {
    84  			case "for":
    85  				addr = value
    86  			case "proto":
    87  				proto = value
    88  			case "host":
    89  				host = value
    90  			}
    91  
    92  		}
    93  	}
    94  	return
    95  }
    96  
    97  func getWildcardHost(host string) string {
    98  	parts := strings.Split(host, ".")
    99  	if len(parts) > 1 {
   100  		wildcard := append([]string{"*"}, parts[1:]...)
   101  		return strings.Join(wildcard, ".")
   102  	}
   103  	return strings.Join(parts, ".")
   104  }