go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/web/route_tree_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package web
     9  
    10  import (
    11  	"fmt"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"net/url"
    15  	"strings"
    16  	"sync/atomic"
    17  	"testing"
    18  
    19  	. "go.charczuk.com/sdk/assert"
    20  )
    21  
    22  func handlerNoOp(rw http.ResponseWriter, _ *http.Request, _ *Route, _ RouteParameters) {
    23  	rw.WriteHeader(http.StatusOK)
    24  	fmt.Fprintf(rw, "OK!\n")
    25  }
    26  
    27  func Test_RouteTree_allowed(t *testing.T) {
    28  	rt := new(RouteTree)
    29  	rt.Handle(http.MethodGet, "/test", nil)
    30  
    31  	allowed := strings.Split(rt.allowed("*", ""), ", ")
    32  	ItsLen(t, allowed, 1)
    33  	ItsEqual(t, "GET", allowed[0])
    34  
    35  	rt.Handle(http.MethodPost, "/hello", nil)
    36  	allowed = strings.Split(rt.allowed("*", ""), ", ")
    37  	ItsLen(t, allowed, 2)
    38  	ItsAny(t, allowed, func(s string) bool {
    39  		return s == http.MethodGet
    40  	})
    41  	ItsAny(t, allowed, func(s string) bool {
    42  		return s == http.MethodPost
    43  	})
    44  
    45  	rt = new(RouteTree)
    46  
    47  	rt.Handle(http.MethodGet, "/hello", handlerNoOp)
    48  	allowed = strings.Split(rt.allowed("/hello", ""), ", ")
    49  	ItsLen(t, allowed, 2)
    50  	ItsAny(t, allowed, func(s string) bool {
    51  		return s == "GET"
    52  	})
    53  	ItsAny(t, allowed, func(s string) bool {
    54  		return s == "OPTIONS"
    55  	})
    56  	rt.Handle(http.MethodPost, "/hello", handlerNoOp)
    57  	allowed = strings.Split(rt.allowed("/hello", ""), ", ")
    58  	ItsLen(t, allowed, 3)
    59  
    60  	rt.Handle(http.MethodOptions, "/hello", handlerNoOp)
    61  	rt.Handle(http.MethodHead, "/hello", handlerNoOp)
    62  	rt.Handle(http.MethodPut, "/hello", handlerNoOp)
    63  	rt.Handle(http.MethodDelete, "/hello", handlerNoOp)
    64  
    65  	rt.Handle(http.MethodPatch, "/hi", handlerNoOp)
    66  	rt.Handle(http.MethodPatch, "/there", handlerNoOp)
    67  	allowed = strings.Split(rt.allowed("/hello", ""), ", ")
    68  	ItsLen(t, allowed, 6)
    69  
    70  	rt.Handle(http.MethodPatch, "/hello", handlerNoOp)
    71  	allowed = strings.Split(rt.allowed("/hello", ""), ", ")
    72  	ItsLen(t, allowed, 7)
    73  }
    74  
    75  func Test_RouteTree_Route(t *testing.T) {
    76  	rt := new(RouteTree)
    77  
    78  	rt.Handle(http.MethodGet, "/", handlerNoOp)
    79  	rt.Handle(http.MethodGet, "/foo", handlerNoOp)
    80  	rt.Handle(http.MethodGet, "/foo/:id", handlerNoOp)
    81  	rt.Handle(http.MethodPost, "/foo", handlerNoOp)
    82  	rt.Handle(http.MethodGet, "/bar", handlerNoOp)
    83  
    84  	// explicitly register a slash suffixed url here
    85  	rt.Handle(http.MethodGet, "/slash/", handlerNoOp)
    86  
    87  	req := &http.Request{
    88  		Method: http.MethodGet,
    89  		URL: &url.URL{
    90  			Path: "/",
    91  		},
    92  	}
    93  	route, params := rt.Route(req)
    94  	ItsNotNil(t, route)
    95  	ItsEqual(t, "/", route.Path)
    96  	ItsEmpty(t, params)
    97  	ItsEqual(t, "/", req.URL.Path)
    98  
    99  	req = &http.Request{
   100  		Method: http.MethodGet,
   101  		URL: &url.URL{
   102  			Path: "/foo",
   103  		},
   104  	}
   105  	route, params = rt.Route(req)
   106  	ItsNotNil(t, route)
   107  	ItsEqual(t, "/foo", route.Path)
   108  	ItsEqual(t, http.MethodGet, route.Method)
   109  	ItsEmpty(t, params)
   110  	ItsEqual(t, "/foo", req.URL.Path)
   111  
   112  	req = &http.Request{
   113  		Method: http.MethodPost,
   114  		URL: &url.URL{
   115  			Path: "/foo",
   116  		},
   117  	}
   118  	route, params = rt.Route(req)
   119  	ItsNotNil(t, route)
   120  	ItsEqual(t, "/foo", route.Path)
   121  	ItsEqual(t, http.MethodPost, route.Method)
   122  	ItsEmpty(t, params)
   123  	ItsEqual(t, "/foo", req.URL.Path)
   124  
   125  	// explicitly test matching with an extra slash
   126  	req = &http.Request{
   127  		Method: http.MethodGet,
   128  		URL: &url.URL{
   129  			Path: "/foo/",
   130  		},
   131  	}
   132  	route, params = rt.Route(req)
   133  	ItsNotNil(t, route)
   134  	ItsEqual(t, "/foo", route.Path)
   135  	ItsEmpty(t, params)
   136  	ItsEqual(t, "/foo/", req.URL.Path)
   137  
   138  	req = &http.Request{
   139  		Method: http.MethodGet,
   140  		URL: &url.URL{
   141  			Path: "/foo/test",
   142  		},
   143  	}
   144  	route, params = rt.Route(req)
   145  	ItsNotNil(t, route)
   146  	ItsEqual(t, "/foo/:id", route.Path)
   147  	ItsNotEmpty(t, params)
   148  	id, _ := params.Get("id")
   149  	ItsEqual(t, "test", id)
   150  	ItsEqual(t, "/foo/test", req.URL.Path)
   151  
   152  	req = &http.Request{
   153  		Method: http.MethodGet,
   154  		URL: &url.URL{
   155  			Path: "/bar",
   156  		},
   157  	}
   158  	route, params = rt.Route(req)
   159  	ItsNotNil(t, route)
   160  	ItsEqual(t, "/bar", route.Path)
   161  	ItsEmpty(t, params)
   162  	ItsEqual(t, "/bar", req.URL.Path)
   163  
   164  	req = &http.Request{
   165  		Method: http.MethodGet,
   166  		URL: &url.URL{
   167  			Path: "/slash",
   168  		},
   169  	}
   170  	route, params = rt.Route(req)
   171  	ItsNotNil(t, route)
   172  	ItsEqual(t, "/slash/", route.Path)
   173  	ItsEmpty(t, params)
   174  	ItsEqual(t, "/slash", req.URL.Path)
   175  
   176  	req = &http.Request{
   177  		Method: http.MethodConnect,
   178  		URL: &url.URL{
   179  			Path: "/slash",
   180  		},
   181  	}
   182  	route, params = rt.Route(req)
   183  	ItsNil(t, route)
   184  	ItsEmpty(t, params)
   185  	ItsEqual(t, "/slash", req.URL.Path)
   186  
   187  	req = &http.Request{
   188  		Method: http.MethodGet,
   189  		URL: &url.URL{
   190  			Path: "/slash",
   191  		},
   192  	}
   193  	rt.SkipTrailingSlashRedirects = true
   194  	route, params = rt.Route(req)
   195  	ItsNil(t, route)
   196  	ItsEmpty(t, params)
   197  	ItsEqual(t, "/slash", req.URL.Path)
   198  }
   199  
   200  func Test_RouteTree_Route_slash(t *testing.T) {
   201  	rt := new(RouteTree)
   202  
   203  	req := &http.Request{
   204  		Method: http.MethodGet,
   205  		URL: &url.URL{
   206  			Path: "/",
   207  		},
   208  	}
   209  	route, params := rt.Route(req)
   210  	ItsNil(t, route)
   211  	ItsEmpty(t, params)
   212  	ItsEqual(t, "/", req.URL.Path)
   213  }
   214  
   215  func Test_RouteTree_withPathAlternateTrailingSlash(t *testing.T) {
   216  	ItsEqual(t, "/foo", new(RouteTree).withPathAlternateTrailingSlash("/foo/"))
   217  	ItsEqual(t, "/foo/", new(RouteTree).withPathAlternateTrailingSlash("/foo"))
   218  	ItsEqual(t, "", new(RouteTree).withPathAlternateTrailingSlash(""))
   219  }
   220  
   221  func routeExpectsPath(method, path string) Handler {
   222  	return func(rw http.ResponseWriter, req *http.Request, _ *Route, _ RouteParameters) {
   223  		if req.Method != method {
   224  			http.Error(rw, "expects method: "+method, http.StatusBadRequest)
   225  			return
   226  		}
   227  		if req.URL.Path != path {
   228  			http.Error(rw, "expects path: "+path, http.StatusBadRequest)
   229  			return
   230  		}
   231  		rw.WriteHeader(http.StatusOK)
   232  		fmt.Fprintf(rw, "OK!\n")
   233  	}
   234  }
   235  
   236  func callCounter(counter *int32, statusCode int) Handler {
   237  	return func(rw http.ResponseWriter, req *http.Request, _ *Route, _ RouteParameters) {
   238  		defer atomic.AddInt32(counter, 1)
   239  		rw.WriteHeader(statusCode)
   240  		fmt.Fprintf(rw, "counted call!\n")
   241  	}
   242  }
   243  
   244  func Test_RouteTree_ServeHTTP(t *testing.T) {
   245  	rt := new(RouteTree)
   246  
   247  	rt.Handle(http.MethodGet, "/", routeExpectsPath(http.MethodGet, "/"))
   248  	rt.Handle(http.MethodGet, "/foo", routeExpectsPath(http.MethodGet, "/foo"))
   249  	rt.Handle(http.MethodGet, "/foo/:id", routeExpectsPath(http.MethodGet, "/foo/test-id"))
   250  	rt.Handle(http.MethodPost, "/foo", routeExpectsPath(http.MethodPost, "/foo"))
   251  	rt.Handle(http.MethodGet, "/bar", routeExpectsPath(http.MethodGet, "/bar"))
   252  
   253  	// explicitly register a slash url here
   254  	rt.Handle(http.MethodGet, "/slash/", handlerNoOp)
   255  
   256  	mock := httptest.NewServer(rt)
   257  	defer mock.Close()
   258  
   259  	res, err := mock.Client().Get(mock.URL + "/")
   260  	ItsNil(t, err)
   261  	ItsEqual(t, http.StatusOK, res.StatusCode)
   262  
   263  	res, err = mock.Client().Get(mock.URL + "/foo")
   264  	ItsNil(t, err)
   265  	ItsEqual(t, http.StatusOK, res.StatusCode)
   266  
   267  	res, err = mock.Client().Get(mock.URL + "/foo/")
   268  	ItsNil(t, err)
   269  	ItsEqual(t, http.StatusOK, res.StatusCode)
   270  
   271  	res, err = mock.Client().Post(mock.URL+"/foo/", "", nil)
   272  	ItsNil(t, err)
   273  	ItsEqual(t, http.StatusOK, res.StatusCode)
   274  
   275  	res, err = mock.Client().Get(mock.URL + "/foo/test-id")
   276  	ItsNil(t, err)
   277  	ItsEqual(t, http.StatusOK, res.StatusCode)
   278  
   279  	res, err = mock.Client().Get(mock.URL + "/foo/not-test-id")
   280  	ItsNil(t, err)
   281  	ItsEqual(t, http.StatusBadRequest, res.StatusCode)
   282  
   283  	res, err = mock.Client().Get(mock.URL + "/bar/")
   284  	ItsNil(t, err)
   285  	ItsEqual(t, http.StatusOK, res.StatusCode)
   286  
   287  	optionsReq, _ := http.NewRequest(http.MethodOptions, mock.URL, nil)
   288  	// now handle the super weird stuff
   289  	res, err = mock.Client().Do(optionsReq)
   290  	ItsNil(t, err)
   291  	ItsEqual(t, http.StatusOK, res.StatusCode)
   292  	allowedHeader := res.Header.Get(HeaderAllow)
   293  	ItsNotEmpty(t, allowedHeader)
   294  	ItsEqual(t, "GET, OPTIONS", allowedHeader)
   295  
   296  	rt.SkipHandlingMethodOptions = true
   297  	res, err = mock.Client().Do(optionsReq)
   298  	ItsNil(t, err)
   299  	ItsEqual(t, http.StatusNotFound, res.StatusCode)
   300  	allowedHeader = res.Header.Get(HeaderAllow)
   301  	ItsEmpty(t, allowedHeader)
   302  
   303  	var notFoundCalls int32
   304  	rt.NotFoundHandler = callCounter(&notFoundCalls, http.StatusNotFound)
   305  	res, err = mock.Client().Do(optionsReq)
   306  	ItsNil(t, err)
   307  	ItsEqual(t, http.StatusNotFound, res.StatusCode)
   308  	allowedHeader = res.Header.Get(HeaderAllow)
   309  	ItsEmpty(t, allowedHeader)
   310  	ItsEqual(t, 1, notFoundCalls)
   311  
   312  	headReq, _ := http.NewRequest(http.MethodHead, mock.URL, nil)
   313  	res, err = mock.Client().Do(headReq)
   314  	ItsNil(t, err)
   315  	ItsEqual(t, http.StatusMethodNotAllowed, res.StatusCode)
   316  	allowedHeader = res.Header.Get(HeaderAllow)
   317  	ItsNotEmpty(t, allowedHeader)
   318  	ItsEqual(t, "GET, OPTIONS", allowedHeader)
   319  
   320  	var methodNotAllowedCalls int32
   321  	rt.MethodNotAllowedHandler = callCounter(&methodNotAllowedCalls, http.StatusMethodNotAllowed)
   322  	res, err = mock.Client().Do(headReq)
   323  	ItsNil(t, err)
   324  	ItsEqual(t, http.StatusMethodNotAllowed, res.StatusCode)
   325  	allowedHeader = res.Header.Get(HeaderAllow)
   326  	ItsNotEmpty(t, allowedHeader)
   327  	ItsEqual(t, "GET, OPTIONS", allowedHeader)
   328  	ItsEqual(t, 1, notFoundCalls)
   329  	ItsEqual(t, 1, methodNotAllowedCalls)
   330  
   331  	rt.SkipMethodNotAllowed = true
   332  	rt.NotFoundHandler = nil
   333  	res, err = mock.Client().Do(optionsReq)
   334  	ItsNil(t, err)
   335  	ItsEqual(t, http.StatusNotFound, res.StatusCode)
   336  	allowedHeader = res.Header.Get(HeaderAllow)
   337  	ItsEmpty(t, allowedHeader)
   338  
   339  	rt.NotFoundHandler = callCounter(&notFoundCalls, http.StatusNotFound)
   340  	res, err = mock.Client().Do(optionsReq)
   341  	ItsNil(t, err)
   342  	ItsEqual(t, http.StatusNotFound, res.StatusCode)
   343  	allowedHeader = res.Header.Get(HeaderAllow)
   344  	ItsEmpty(t, allowedHeader)
   345  	ItsEqual(t, 2, notFoundCalls)
   346  }