github.com/blend/go-sdk@v1.20220411.3/web/route_tree.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package web
     9  
    10  import (
    11  	"net/http"
    12  
    13  	"github.com/blend/go-sdk/webutil"
    14  )
    15  
    16  // RouteTree implements the basic logic of creating a route tree.
    17  //
    18  // It is embedded in *App and handles the path to handler matching based
    19  // on route trees per method.
    20  //
    21  // A very simple example:
    22  //
    23  //    rt := new(web.RouteTree)
    24  //    rt.Handle(http.MethodGet, "/", func(w http.ResponseWriter, req *http.Request, route *web.Route, params web.Params) {
    25  //        w.WriteHeader(http.StatusOK)
    26  //        fmt.Fprintf(w, "OK!")
    27  //    })
    28  //    (&http.Server{Addr: "127.0.0.1:8080", Handler: rt}).ListenAndServe()
    29  //
    30  type RouteTree struct {
    31  	// Routes is a map between canonicalized http method
    32  	// (i.e. `GET` vs. `get`) and individual method
    33  	// route trees.
    34  	Routes map[string]*RouteNode
    35  	// SkipTrailingSlashRedirects disables matching
    36  	// routes that are off by a trailing slash, either because
    37  	// routes are registered with the '/' suffix, or because
    38  	// the request has a '/' suffix and the
    39  	// registered route does not.
    40  	SkipTrailingSlashRedirects bool
    41  	// SkipHandlingMethodOptions disables returning
    42  	// a result with the `ALLOWED` header for method options,
    43  	// and will instead 404 for `OPTIONS` methods.
    44  	SkipHandlingMethodOptions bool
    45  	// SkipMethodNotAllowed skips specific handling
    46  	// for methods that do not have a route tree with
    47  	// a specific 405 response, and will instead return a 404.
    48  	SkipMethodNotAllowed bool
    49  	// NotFoundHandler is an optional handler to set
    50  	// to customize not found (404) results.
    51  	NotFoundHandler Handler
    52  	// MethodNotAllowedHandler is an optional handler
    53  	// to set to customize method not allowed (405) results.
    54  	MethodNotAllowedHandler Handler
    55  }
    56  
    57  // Handle adds a handler at a given method and path.
    58  func (rt *RouteTree) Handle(method, path string, handler Handler) {
    59  	if len(path) == 0 {
    60  		panic("path must not be empty")
    61  	}
    62  	if path[0] != '/' {
    63  		panic("path must begin with '/' in path '" + path + "'")
    64  	}
    65  	if rt.Routes == nil {
    66  		rt.Routes = make(map[string]*RouteNode)
    67  	}
    68  
    69  	root := rt.Routes[method]
    70  	if root == nil {
    71  		root = new(RouteNode)
    72  		rt.Routes[method] = root
    73  	}
    74  	root.AddRoute(method, path, handler)
    75  }
    76  
    77  // Route gets the route and parameters for a given request
    78  // if it matches a registered handler.
    79  //
    80  // It will automatically resolve if a trailing slash should be appended
    81  // for the input request url path, and will return the corresponding redirected
    82  // route (and parameters) if there is one.
    83  func (rt *RouteTree) Route(req *http.Request) (*Route, RouteParameters) {
    84  	path := req.URL.Path
    85  	methodRoot := rt.Routes[req.Method]
    86  	if methodRoot != nil {
    87  		route, params, shouldRedirectTrailingSlash := methodRoot.getValue(path)
    88  		if req.Method != http.MethodConnect && path != "/" {
    89  			if shouldRedirectTrailingSlash && !rt.SkipTrailingSlashRedirects {
    90  				route, params, _ = methodRoot.getValue(rt.withPathAlternateTrailingSlash(path))
    91  			}
    92  		}
    93  		return route, params
    94  	}
    95  	return nil, nil
    96  }
    97  
    98  // ServeHTTP makes the router implement the http.Handler interface.
    99  func (rt *RouteTree) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   100  	path := req.URL.Path
   101  	if root := rt.Routes[req.Method]; root != nil {
   102  		route, params, trailingSlashRedirect := root.getValue(path)
   103  		if route != nil {
   104  			route.Handler(w, req, route, params)
   105  			return
   106  		} else if req.Method != http.MethodConnect && path != "/" {
   107  			if trailingSlashRedirect && !rt.SkipTrailingSlashRedirects {
   108  				rt.redirectTrailingSlash(w, req)
   109  				return
   110  			}
   111  		}
   112  	}
   113  
   114  	if req.Method == http.MethodOptions {
   115  		// Handle OPTIONS requests
   116  		if !rt.SkipHandlingMethodOptions {
   117  			if allow := rt.allowed(path, req.Method); allow != "" {
   118  				w.Header().Set(webutil.HeaderAllow, allow)
   119  				// just return the allowed header
   120  				return
   121  			}
   122  			// return a 404 below
   123  		}
   124  	} else {
   125  		// Handle 405
   126  		if !rt.SkipMethodNotAllowed {
   127  			if allow := rt.allowed(path, req.Method); len(allow) > 0 {
   128  				w.Header().Set(webutil.HeaderAllow, allow)
   129  				if rt.MethodNotAllowedHandler != nil {
   130  					rt.MethodNotAllowedHandler(w, req, nil, nil)
   131  					return
   132  				}
   133  				http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
   134  				return
   135  			}
   136  		}
   137  	}
   138  
   139  	// Handle 404
   140  	if rt.NotFoundHandler != nil {
   141  		rt.NotFoundHandler(w, req, nil, nil)
   142  	} else {
   143  		http.NotFound(w, req)
   144  	}
   145  }
   146  
   147  //
   148  // internal helpers
   149  //
   150  
   151  // withPathAlternateTrailingSlash returns the request with a `/` suffix on the url path.
   152  func (rt *RouteTree) withPathAlternateTrailingSlash(path string) string {
   153  	// if the path has a slash already, try removing it
   154  	if len(path) > 1 && path[len(path)-1] == '/' {
   155  		// try removing the slash
   156  		return path[:len(path)-1]
   157  	}
   158  	if len(path) > 0 {
   159  		// try adding the slash
   160  		return path + "/"
   161  	}
   162  	return path
   163  }
   164  
   165  // redirectTrailingSlash redirects the request if a suffix trailing
   166  // forward slash should be added.
   167  func (rt *RouteTree) redirectTrailingSlash(w http.ResponseWriter, req *http.Request) {
   168  	code := http.StatusMovedPermanently // 301 // Permanent redirect, request with GET method
   169  	if req.Method != http.MethodGet {
   170  		code = http.StatusTemporaryRedirect // 307
   171  	}
   172  	req.URL.Path = rt.withPathAlternateTrailingSlash(req.URL.Path)
   173  	http.Redirect(w, req, req.URL.String(), code)
   174  	return
   175  }
   176  
   177  func (rt *RouteTree) allowed(path, reqMethod string) (allow string) {
   178  	if path == "*" { // server-wide
   179  		for method := range rt.Routes {
   180  			if method == http.MethodOptions {
   181  				continue
   182  			}
   183  
   184  			// add request method to list of allowed methods
   185  			if allow == "" {
   186  				allow = method
   187  			} else {
   188  				allow += ", " + method
   189  			}
   190  		}
   191  		return
   192  	}
   193  	for method := range rt.Routes {
   194  		// Skip the requested method - we already tried this one
   195  		if method == reqMethod || method == http.MethodOptions {
   196  			continue
   197  		}
   198  
   199  		handle, _, _ := rt.Routes[method].getValue(path)
   200  		if handle != nil {
   201  			// add request method to list of allowed methods
   202  			if allow == "" {
   203  				allow = method
   204  			} else {
   205  				allow += ", " + method
   206  			}
   207  		}
   208  	}
   209  	if allow != "" {
   210  		allow += ", " + http.MethodOptions
   211  	}
   212  	return
   213  }