git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/httpx/hostrouter/hostrouter.go (about) 1 package hostrouter 2 3 import ( 4 "net/http" 5 "strings" 6 7 "github.com/go-chi/chi/v5" 8 ) 9 10 type Routes map[string]chi.Router 11 12 var _ chi.Routes = Routes{} 13 14 func New() Routes { 15 return Routes{} 16 } 17 18 func (hr Routes) Match(rctx *chi.Context, method, path string) bool { 19 return true 20 } 21 22 func (hr Routes) Map(host string, h chi.Router) { 23 hr[strings.ToLower(host)] = h 24 } 25 26 func (hr Routes) Unmap(host string) { 27 delete(hr, strings.ToLower(host)) 28 } 29 30 func (hr Routes) ServeHTTP(w http.ResponseWriter, r *http.Request) { 31 host := requestHost(r) 32 if router, ok := hr[strings.ToLower(host)]; ok { 33 router.ServeHTTP(w, r) 34 return 35 } 36 if router, ok := hr[strings.ToLower(getWildcardHost(host))]; ok { 37 router.ServeHTTP(w, r) 38 return 39 } 40 if router, ok := hr["*"]; ok { 41 router.ServeHTTP(w, r) 42 return 43 } 44 http.Error(w, http.StatusText(404), 404) 45 } 46 47 func (hr Routes) Routes() []chi.Route { 48 return hr[""].Routes() 49 } 50 51 func (hr Routes) Middlewares() chi.Middlewares { 52 return chi.Middlewares{} 53 } 54 55 func requestHost(r *http.Request) (host string) { 56 // not standard, but most popular 57 host = r.Header.Get("X-Forwarded-Host") 58 if host != "" { 59 return 60 } 61 62 // RFC 7239 63 host = r.Header.Get("Forwarded") 64 _, _, host = parseForwarded(host) 65 if host != "" { 66 return 67 } 68 69 // if all else fails fall back to request host 70 host = r.Host 71 return 72 } 73 74 func parseForwarded(forwarded string) (addr, proto, host string) { 75 if forwarded == "" { 76 return 77 } 78 for _, forwardedPair := range strings.Split(forwarded, ";") { 79 if tv := strings.SplitN(forwardedPair, "=", 2); len(tv) == 2 { 80 token, value := tv[0], tv[1] 81 token = strings.TrimSpace(token) 82 value = strings.TrimSpace(strings.Trim(value, `"`)) 83 switch strings.ToLower(token) { 84 case "for": 85 addr = value 86 case "proto": 87 proto = value 88 case "host": 89 host = value 90 } 91 92 } 93 } 94 return 95 } 96 97 func getWildcardHost(host string) string { 98 parts := strings.Split(host, ".") 99 if len(parts) > 1 { 100 wildcard := append([]string{"*"}, parts[1:]...) 101 return strings.Join(wildcard, ".") 102 } 103 return strings.Join(parts, ".") 104 }