github.com/splucs/witchcraft-go-server@v1.7.0/wrouter/router_root.go (about)

     1  // Copyright (c) 2018 Palantir Technologies. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package wrouter
    16  
    17  import (
    18  	"context"
    19  	"net/http"
    20  	"sort"
    21  )
    22  
// RootRouter is a Router that can also serve HTTP requests directly (it is an http.Handler) and that accepts
// middleware applied either to every request or to requests that match a registered route.
type RootRouter interface {
	http.Handler
	Router

	// AddRequestHandlerMiddleware adds the provided middleware to the handlers that are run for every request
	// received by the router.
	AddRequestHandlerMiddleware(handlers ...RequestHandlerMiddleware)
	// AddRouteHandlerMiddleware adds the provided middleware to the handlers that are run for requests that are
	// routed to a route registered on the router.
	AddRouteHandlerMiddleware(handlers ...RouteHandlerMiddleware)
}
    30  
// rootRouter is the RootRouter implementation returned by New. It delegates routing to a RouterImpl and layers
// request-level and route-level middleware on top of it.
type rootRouter struct {
	// impl stores the underlying RouterImpl used to route requests.
	impl RouterImpl

	// reqHandlers specifies the handlers that run for every request received by the router. Every request received by
	// the router (including requests to methods/paths that are not registered on the router) is handled by these
	// handlers in order before being handled by the underlying RouterImpl.
	reqHandlers []RequestHandlerMiddleware

	// routeHandlers specifies the handlers that are run for all of the routes that are registered on this router.
	// Requests that are routed to a registered route on this router are handled by these handlers in order before being
	// handled by the registered handler.
	routeHandlers []RouteHandlerMiddleware

	// routes stores all of the routes that are registered on this router.
	routes []RouteSpec

	// cachedHandler stores the http.Handler created by chaining all of the request handlers in reqHandlers with the
	// request handler provided by impl. This is done because this http.Handler is called on every request and
	// reqHandlers rarely changes, so it is much more efficient to cache the handler rather than creating a chained one
	// on every request. It is rebuilt by updateCachedHandler.
	cachedHandler http.Handler
}
    54  
    55  func New(impl RouterImpl, params ...RootRouterParam) RootRouter {
    56  	r := &rootRouter{
    57  		impl: impl,
    58  	}
    59  	for _, p := range params {
    60  		if p == nil {
    61  			continue
    62  		}
    63  		p.configure(r)
    64  	}
    65  	r.updateCachedHandler()
    66  	return r
    67  }
    68  
// RootRouterParam is an option that can be passed to New to configure the router being constructed.
type RootRouterParam interface {
	// configure applies this option to the provided router.
	configure(*rootRouter)
}
    72  
    73  type rootRouterParamsFunc func(*rootRouter)
    74  
    75  func (f rootRouterParamsFunc) configure(r *rootRouter) {
    76  	f(r)
    77  }
    78  
    79  func RootRouterParamAddRequestHandlerMiddleware(reqHandler ...RequestHandlerMiddleware) RootRouterParam {
    80  	return rootRouterParamsFunc(func(r *rootRouter) {
    81  		r.AddRequestHandlerMiddleware(reqHandler...)
    82  	})
    83  }
    84  
    85  func RootRouterParamAddRouteHandlerMiddleware(routeReqHandler ...RouteHandlerMiddleware) RootRouterParam {
    86  	return rootRouterParamsFunc(func(r *rootRouter) {
    87  		r.AddRouteHandlerMiddleware(routeReqHandler...)
    88  	})
    89  }
    90  
    91  func (r *rootRouter) updateCachedHandler() {
    92  	r.cachedHandler = createRequestHandler(r.impl, r.reqHandlers)
    93  }
    94  
    95  func (r *rootRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    96  	r.cachedHandler.ServeHTTP(w, req)
    97  }
    98  
    99  func (r *rootRouter) Register(method, path string, handler http.Handler, params ...RouteParam) error {
   100  	pathTemplate, err := NewPathTemplate(path)
   101  	if err != nil {
   102  		return err
   103  	}
   104  
   105  	var pathVarNames []string
   106  	for _, segment := range pathTemplate.Segments() {
   107  		if segment.Type == LiteralSegment {
   108  			continue
   109  		}
   110  		pathVarNames = append(pathVarNames, segment.Value)
   111  	}
   112  
   113  	routeSpec := RouteSpec{
   114  		Method:       method,
   115  		PathTemplate: pathTemplate.Template(),
   116  	}
   117  	r.routes = append(r.routes, routeSpec)
   118  	sort.Sort(routeSpecs(r.routes))
   119  
   120  	metricTags := toMetricTags(params)
   121  
   122  	// wrap provided handler with a handler that registers the path parameter information in the context
   123  	r.impl.Register(method, pathTemplate.Segments(), http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   124  		// register path parameters in context
   125  		pathParamVals := r.impl.PathParams(req, pathVarNames)
   126  		req = req.WithContext(context.WithValue(req.Context(), pathParamsContextKey, pathParamVals))
   127  
   128  		wrappedHandlerFn := createRouteRequestHandler(func(rw http.ResponseWriter, r *http.Request, reqVals RequestVals) {
   129  			handler.ServeHTTP(rw, r)
   130  		}, r.routeHandlers)
   131  
   132  		wrappedHandlerFn(w, req, RequestVals{
   133  			Spec:          routeSpec,
   134  			PathParamVals: pathParamVals,
   135  			ParamPerms:    toRequestParamPerms(params),
   136  			MetricTags:    metricTags,
   137  		})
   138  	}))
   139  	return nil
   140  }
   141  
   142  func (r *routeRequestHandlerWithNext) HandleRequest(rw http.ResponseWriter, req *http.Request, reqVals RequestVals) {
   143  	r.handler(rw, req, reqVals, r.next)
   144  }
   145  
   146  func (r *rootRouter) RegisteredRoutes() []RouteSpec {
   147  	ris := make([]RouteSpec, len(r.routes))
   148  	copy(ris, r.routes)
   149  	return ris
   150  }
   151  
   152  func (r *rootRouter) Get(path string, handler http.Handler, params ...RouteParam) error {
   153  	return r.Register(http.MethodGet, path, handler, params...)
   154  }
   155  
   156  func (r *rootRouter) Head(path string, handler http.Handler, params ...RouteParam) error {
   157  	return r.Register(http.MethodHead, path, handler, params...)
   158  }
   159  
   160  func (r *rootRouter) Post(path string, handler http.Handler, params ...RouteParam) error {
   161  	return r.Register(http.MethodPost, path, handler, params...)
   162  }
   163  
   164  func (r *rootRouter) Put(path string, handler http.Handler, params ...RouteParam) error {
   165  	return r.Register(http.MethodPut, path, handler, params...)
   166  }
   167  
   168  func (r *rootRouter) Patch(path string, handler http.Handler, params ...RouteParam) error {
   169  	return r.Register(http.MethodPatch, path, handler, params...)
   170  }
   171  
   172  func (r *rootRouter) Delete(path string, handler http.Handler, params ...RouteParam) error {
   173  	return r.Register(http.MethodDelete, path, handler, params...)
   174  }
   175  
   176  func (r *rootRouter) Subrouter(path string, params ...RouteParam) Router {
   177  	return &subrouter{
   178  		rPath:   path,
   179  		rParent: r,
   180  		params:  params,
   181  	}
   182  }
   183  
   184  func (r *rootRouter) Path() string {
   185  	return ""
   186  }
   187  
   188  func (r *rootRouter) Parent() Router {
   189  	return nil
   190  }
   191  
   192  func (r *rootRouter) RootRouter() RootRouter {
   193  	return r
   194  }
   195  
   196  func (r *rootRouter) AddRequestHandlerMiddleware(handlers ...RequestHandlerMiddleware) {
   197  	r.reqHandlers = append(r.reqHandlers, handlers...)
   198  	r.updateCachedHandler()
   199  }
   200  
   201  func (r *rootRouter) AddRouteHandlerMiddleware(handlers ...RouteHandlerMiddleware) {
   202  	r.routeHandlers = append(r.routeHandlers, handlers...)
   203  }
   204  
// requestHandlerWithNext is one link in a chain of request middleware: it pairs a middleware with the http.Handler
// that the middleware is given as the next handler.
type requestHandlerWithNext struct {
	// handler is the middleware invoked for this link.
	handler RequestHandlerMiddleware
	// next is the handler passed to the middleware as the remainder of the chain.
	next http.Handler
}
   209  
   210  func createRequestHandler(baseHandler http.Handler, handlers []RequestHandlerMiddleware) http.Handler {
   211  	if len(handlers) == 0 {
   212  		return baseHandler
   213  	}
   214  	return &requestHandlerWithNext{
   215  		handler: handlers[0],
   216  		next:    createRequestHandler(baseHandler, handlers[1:]),
   217  	}
   218  }
   219  
   220  func (r *requestHandlerWithNext) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
   221  	r.handler(rw, req, r.next)
   222  }
   223  
// routeRequestHandlerWithNext is one link in a chain of route middleware: it pairs a middleware with the
// RouteRequestHandler that the middleware is given as the next handler.
type routeRequestHandlerWithNext struct {
	// handler is the middleware invoked for this link.
	handler RouteHandlerMiddleware
	// next is the handler passed to the middleware as the remainder of the chain.
	next RouteRequestHandler
}
   228  
   229  func createRouteRequestHandler(baseHandler RouteRequestHandler, handlers []RouteHandlerMiddleware) RouteRequestHandler {
   230  	if len(handlers) == 0 {
   231  		return baseHandler
   232  	}
   233  	return (&routeRequestHandlerWithNext{
   234  		handler: handlers[0],
   235  		next:    createRouteRequestHandler(baseHandler, handlers[1:]),
   236  	}).HandleRequest
   237  }