github.com/blend/go-sdk@v1.20240719.1/web/route_tree.go (about) 1 /* 2 3 Copyright (c) 2024 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package web 9 10 import ( 11 "net/http" 12 13 "github.com/blend/go-sdk/webutil" 14 ) 15 16 // RouteTree implements the basic logic of creating a route tree. 17 // 18 // It is embedded in *App and handles the path to handler matching based 19 // on route trees per method. 20 // 21 // A very simple example: 22 // 23 // rt := new(web.RouteTree) 24 // rt.Handle(http.MethodGet, "/", func(w http.ResponseWriter, req *http.Request, route *web.Route, params web.Params) { 25 // w.WriteHeader(http.StatusOK) 26 // fmt.Fprintf(w, "OK!") 27 // }) 28 // (&http.Server{Addr: "127.0.0.1:8080", Handler: rt}).ListenAndServe() 29 type RouteTree struct { 30 // Routes is a map between canonicalized http method 31 // (i.e. `GET` vs. `get`) and individual method 32 // route trees. 33 Routes map[string]*RouteNode 34 // SkipTrailingSlashRedirects disables matching 35 // routes that are off by a trailing slash, either because 36 // routes are registered with the '/' suffix, or because 37 // the request has a '/' suffix and the 38 // registered route does not. 39 SkipTrailingSlashRedirects bool 40 // SkipHandlingMethodOptions disables returning 41 // a result with the `ALLOWED` header for method options, 42 // and will instead 404 for `OPTIONS` methods. 43 SkipHandlingMethodOptions bool 44 // SkipMethodNotAllowed skips specific handling 45 // for methods that do not have a route tree with 46 // a specific 405 response, and will instead return a 404. 47 SkipMethodNotAllowed bool 48 // NotFoundHandler is an optional handler to set 49 // to customize not found (404) results. 50 NotFoundHandler Handler 51 // MethodNotAllowedHandler is an optional handler 52 // to set to customize method not allowed (405) results. 53 MethodNotAllowedHandler Handler 54 } 55 56 // Handle adds a handler at a given method and path. 57 func (rt *RouteTree) Handle(method, path string, handler Handler) { 58 if len(path) == 0 { 59 panic("path must not be empty") 60 } 61 if path[0] != '/' { 62 panic("path must begin with '/' in path '" + path + "'") 63 } 64 if rt.Routes == nil { 65 rt.Routes = make(map[string]*RouteNode) 66 } 67 68 root := rt.Routes[method] 69 if root == nil { 70 root = new(RouteNode) 71 rt.Routes[method] = root 72 } 73 root.AddRoute(method, path, handler) 74 } 75 76 // Route gets the route and parameters for a given request 77 // if it matches a registered handler. 78 // 79 // It will automatically resolve if a trailing slash should be appended 80 // for the input request url path, and will return the corresponding redirected 81 // route (and parameters) if there is one. 82 func (rt *RouteTree) Route(req *http.Request) (*Route, RouteParameters) { 83 path := req.URL.Path 84 methodRoot := rt.Routes[req.Method] 85 if methodRoot != nil { 86 route, params, shouldRedirectTrailingSlash := methodRoot.getValue(path) 87 if req.Method != http.MethodConnect && path != "/" { 88 if shouldRedirectTrailingSlash && !rt.SkipTrailingSlashRedirects { 89 route, params, _ = methodRoot.getValue(rt.withPathAlternateTrailingSlash(path)) 90 } 91 } 92 return route, params 93 } 94 return nil, nil 95 } 96 97 // ServeHTTP makes the router implement the http.Handler interface. 98 func (rt *RouteTree) ServeHTTP(w http.ResponseWriter, req *http.Request) { 99 path := req.URL.Path 100 if root := rt.Routes[req.Method]; root != nil { 101 route, params, trailingSlashRedirect := root.getValue(path) 102 if route != nil { 103 route.Handler(w, req, route, params) 104 return 105 } else if req.Method != http.MethodConnect && path != "/" { 106 if trailingSlashRedirect && !rt.SkipTrailingSlashRedirects { 107 rt.redirectTrailingSlash(w, req) 108 return 109 } 110 } 111 } 112 113 if req.Method == http.MethodOptions { 114 // Handle OPTIONS requests 115 if !rt.SkipHandlingMethodOptions { 116 if allow := rt.allowed(path, req.Method); allow != "" { 117 w.Header().Set(webutil.HeaderAllow, allow) 118 // just return the allowed header 119 return 120 } 121 // return a 404 below 122 } 123 } else { 124 // Handle 405 125 if !rt.SkipMethodNotAllowed { 126 if allow := rt.allowed(path, req.Method); len(allow) > 0 { 127 w.Header().Set(webutil.HeaderAllow, allow) 128 if rt.MethodNotAllowedHandler != nil { 129 rt.MethodNotAllowedHandler(w, req, nil, nil) 130 return 131 } 132 http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) 133 return 134 } 135 } 136 } 137 138 // Handle 404 139 if rt.NotFoundHandler != nil { 140 rt.NotFoundHandler(w, req, nil, nil) 141 } else { 142 http.NotFound(w, req) 143 } 144 } 145 146 // 147 // internal helpers 148 // 149 150 // withPathAlternateTrailingSlash returns the request with a `/` suffix on the url path. 151 func (rt *RouteTree) withPathAlternateTrailingSlash(path string) string { 152 // if the path has a slash already, try removing it 153 if len(path) > 1 && path[len(path)-1] == '/' { 154 // try removing the slash 155 return path[:len(path)-1] 156 } 157 if len(path) > 0 { 158 // try adding the slash 159 return path + "/" 160 } 161 return path 162 } 163 164 // redirectTrailingSlash redirects the request if a suffix trailing 165 // forward slash should be added. 166 func (rt *RouteTree) redirectTrailingSlash(w http.ResponseWriter, req *http.Request) { 167 code := http.StatusMovedPermanently // 301 // Permanent redirect, request with GET method 168 if req.Method != http.MethodGet { 169 code = http.StatusTemporaryRedirect // 307 170 } 171 req.URL.Path = rt.withPathAlternateTrailingSlash(req.URL.Path) 172 http.Redirect(w, req, req.URL.String(), code) 173 return 174 } 175 176 func (rt *RouteTree) allowed(path, reqMethod string) (allow string) { 177 if path == "*" { // server-wide 178 for method := range rt.Routes { 179 if method == http.MethodOptions { 180 continue 181 } 182 183 // add request method to list of allowed methods 184 if allow == "" { 185 allow = method 186 } else { 187 allow += ", " + method 188 } 189 } 190 return 191 } 192 for method := range rt.Routes { 193 // Skip the requested method - we already tried this one 194 if method == reqMethod || method == http.MethodOptions { 195 continue 196 } 197 198 handle, _, _ := rt.Routes[method].getValue(path) 199 if handle != nil { 200 // add request method to list of allowed methods 201 if allow == "" { 202 allow = method 203 } else { 204 allow += ", " + method 205 } 206 } 207 } 208 if allow != "" { 209 allow += ", " + http.MethodOptions 210 } 211 return 212 }