github.com/blend/go-sdk@v1.20220411.3/web/route_tree_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package web
     9  
    10  import (
    11  	"fmt"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"net/url"
    15  	"strings"
    16  	"sync/atomic"
    17  	"testing"
    18  
    19  	"github.com/blend/go-sdk/assert"
    20  	"github.com/blend/go-sdk/webutil"
    21  )
    22  
    23  func handlerNoOp(rw http.ResponseWriter, _ *http.Request, _ *Route, _ RouteParameters) {
    24  	rw.WriteHeader(http.StatusOK)
    25  	fmt.Fprintf(rw, "OK!\n")
    26  }
    27  
    28  func Test_RouteTree_allowed(t *testing.T) {
    29  	its := assert.New(t)
    30  
    31  	rt := new(RouteTree)
    32  	rt.Handle(http.MethodGet, "/test", nil)
    33  
    34  	allowed := strings.Split(rt.allowed("*", ""), ", ")
    35  	its.Len(allowed, 1)
    36  	its.Equal("GET", allowed[0])
    37  
    38  	rt.Handle(http.MethodPost, "/hello", nil)
    39  	allowed = strings.Split(rt.allowed("*", ""), ", ")
    40  	its.Len(allowed, 2)
    41  	its.Any(allowed, func(i interface{}) bool {
    42  		s, ok := i.(string)
    43  		return ok && s == http.MethodGet
    44  	})
    45  	its.Any(allowed, func(i interface{}) bool {
    46  		s, ok := i.(string)
    47  		return ok && s == http.MethodPost
    48  	})
    49  
    50  	rt = new(RouteTree)
    51  
    52  	rt.Handle(http.MethodGet, "/hello", handlerNoOp)
    53  	allowed = strings.Split(rt.allowed("/hello", ""), ", ")
    54  	its.Len(allowed, 2)
    55  	its.Any(allowed, func(i interface{}) bool {
    56  		s, ok := i.(string)
    57  		return ok && s == "GET"
    58  	})
    59  	its.Any(allowed, func(i interface{}) bool {
    60  		s, ok := i.(string)
    61  		return ok && s == "OPTIONS"
    62  	})
    63  	rt.Handle(http.MethodPost, "/hello", handlerNoOp)
    64  	allowed = strings.Split(rt.allowed("/hello", ""), ", ")
    65  	its.Len(allowed, 3)
    66  
    67  	rt.Handle(http.MethodOptions, "/hello", handlerNoOp)
    68  	rt.Handle(http.MethodHead, "/hello", handlerNoOp)
    69  	rt.Handle(http.MethodPut, "/hello", handlerNoOp)
    70  	rt.Handle(http.MethodDelete, "/hello", handlerNoOp)
    71  
    72  	rt.Handle(http.MethodPatch, "/hi", handlerNoOp)
    73  	rt.Handle(http.MethodPatch, "/there", handlerNoOp)
    74  	allowed = strings.Split(rt.allowed("/hello", ""), ", ")
    75  	its.Len(allowed, 6)
    76  
    77  	rt.Handle(http.MethodPatch, "/hello", handlerNoOp)
    78  	allowed = strings.Split(rt.allowed("/hello", ""), ", ")
    79  	its.Len(allowed, 7)
    80  }
    81  
    82  func Test_RouteTree_Route(t *testing.T) {
    83  	its := assert.New(t)
    84  
    85  	rt := new(RouteTree)
    86  
    87  	rt.Handle(http.MethodGet, "/", handlerNoOp)
    88  	rt.Handle(http.MethodGet, "/foo", handlerNoOp)
    89  	rt.Handle(http.MethodGet, "/foo/:id", handlerNoOp)
    90  	rt.Handle(http.MethodPost, "/foo", handlerNoOp)
    91  	rt.Handle(http.MethodGet, "/bar", handlerNoOp)
    92  
    93  	// explicitly register a slash suffixed url here
    94  	rt.Handle(http.MethodGet, "/slash/", handlerNoOp)
    95  
    96  	req := &http.Request{
    97  		Method: http.MethodGet,
    98  		URL: &url.URL{
    99  			Path: "/",
   100  		},
   101  	}
   102  	route, params := rt.Route(req)
   103  	its.NotNil(route)
   104  	its.Equal("/", route.Path)
   105  	its.Empty(params)
   106  	its.Equal("/", req.URL.Path)
   107  
   108  	req = &http.Request{
   109  		Method: http.MethodGet,
   110  		URL: &url.URL{
   111  			Path: "/foo",
   112  		},
   113  	}
   114  	route, params = rt.Route(req)
   115  	its.NotNil(route)
   116  	its.Equal("/foo", route.Path)
   117  	its.Equal(http.MethodGet, route.Method)
   118  	its.Empty(params)
   119  	its.Equal("/foo", req.URL.Path)
   120  
   121  	req = &http.Request{
   122  		Method: http.MethodPost,
   123  		URL: &url.URL{
   124  			Path: "/foo",
   125  		},
   126  	}
   127  	route, params = rt.Route(req)
   128  	its.NotNil(route)
   129  	its.Equal("/foo", route.Path)
   130  	its.Equal(http.MethodPost, route.Method)
   131  	its.Empty(params)
   132  	its.Equal("/foo", req.URL.Path)
   133  
   134  	// explicitly test matching with an extra slash
   135  	req = &http.Request{
   136  		Method: http.MethodGet,
   137  		URL: &url.URL{
   138  			Path: "/foo/",
   139  		},
   140  	}
   141  	route, params = rt.Route(req)
   142  	its.NotNil(route)
   143  	its.Equal("/foo", route.Path)
   144  	its.Empty(params)
   145  	its.Equal("/foo/", req.URL.Path)
   146  
   147  	req = &http.Request{
   148  		Method: http.MethodGet,
   149  		URL: &url.URL{
   150  			Path: "/foo/test",
   151  		},
   152  	}
   153  	route, params = rt.Route(req)
   154  	its.NotNil(route)
   155  	its.Equal("/foo/:id", route.Path)
   156  	its.NotEmpty(params)
   157  	its.Equal("test", params["id"])
   158  	its.Equal("/foo/test", req.URL.Path)
   159  
   160  	req = &http.Request{
   161  		Method: http.MethodGet,
   162  		URL: &url.URL{
   163  			Path: "/bar",
   164  		},
   165  	}
   166  	route, params = rt.Route(req)
   167  	its.NotNil(route)
   168  	its.Equal("/bar", route.Path)
   169  	its.Empty(params)
   170  	its.Equal("/bar", req.URL.Path)
   171  
   172  	req = &http.Request{
   173  		Method: http.MethodGet,
   174  		URL: &url.URL{
   175  			Path: "/slash",
   176  		},
   177  	}
   178  	route, params = rt.Route(req)
   179  	its.NotNil(route)
   180  	its.Equal("/slash/", route.Path)
   181  	its.Empty(params)
   182  	its.Equal("/slash", req.URL.Path)
   183  
   184  	req = &http.Request{
   185  		Method: http.MethodConnect,
   186  		URL: &url.URL{
   187  			Path: "/slash",
   188  		},
   189  	}
   190  	route, params = rt.Route(req)
   191  	its.Nil(route)
   192  	its.Empty(params)
   193  	its.Equal("/slash", req.URL.Path)
   194  
   195  	req = &http.Request{
   196  		Method: http.MethodGet,
   197  		URL: &url.URL{
   198  			Path: "/slash",
   199  		},
   200  	}
   201  	rt.SkipTrailingSlashRedirects = true
   202  	route, params = rt.Route(req)
   203  	its.Nil(route)
   204  	its.Empty(params)
   205  	its.Equal("/slash", req.URL.Path)
   206  }
   207  
   208  func Test_RouteTree_Route_slash(t *testing.T) {
   209  	its := assert.New(t)
   210  
   211  	rt := new(RouteTree)
   212  
   213  	req := &http.Request{
   214  		Method: http.MethodGet,
   215  		URL: &url.URL{
   216  			Path: "/",
   217  		},
   218  	}
   219  	route, params := rt.Route(req)
   220  	its.Nil(route)
   221  	its.Empty(params)
   222  	its.Equal("/", req.URL.Path)
   223  }
   224  
   225  func Test_RouteTree_withPathAlternateTrailingSlash(t *testing.T) {
   226  	its := assert.New(t)
   227  
   228  	its.Equal("/foo", new(RouteTree).withPathAlternateTrailingSlash("/foo/"))
   229  	its.Equal("/foo/", new(RouteTree).withPathAlternateTrailingSlash("/foo"))
   230  	its.Equal("", new(RouteTree).withPathAlternateTrailingSlash(""))
   231  }
   232  
   233  func routeExpectsPath(method, path string) Handler {
   234  	return func(rw http.ResponseWriter, req *http.Request, _ *Route, _ RouteParameters) {
   235  		if req.Method != method {
   236  			http.Error(rw, "expects method: "+method, http.StatusBadRequest)
   237  			return
   238  		}
   239  		if req.URL.Path != path {
   240  			http.Error(rw, "expects path: "+path, http.StatusBadRequest)
   241  			return
   242  		}
   243  		rw.WriteHeader(http.StatusOK)
   244  		fmt.Fprintf(rw, "OK!\n")
   245  	}
   246  }
   247  
   248  func callCounter(counter *int32, statusCode int) Handler {
   249  	return func(rw http.ResponseWriter, req *http.Request, _ *Route, _ RouteParameters) {
   250  		defer atomic.AddInt32(counter, 1)
   251  		rw.WriteHeader(statusCode)
   252  		fmt.Fprintf(rw, "counted call!\n")
   253  	}
   254  }
   255  
   256  func Test_RouteTree_ServeHTTP(t *testing.T) {
   257  	its := assert.New(t)
   258  
   259  	rt := new(RouteTree)
   260  
   261  	rt.Handle(http.MethodGet, "/", routeExpectsPath(http.MethodGet, "/"))
   262  	rt.Handle(http.MethodGet, "/foo", routeExpectsPath(http.MethodGet, "/foo"))
   263  	rt.Handle(http.MethodGet, "/foo/:id", routeExpectsPath(http.MethodGet, "/foo/test-id"))
   264  	rt.Handle(http.MethodPost, "/foo", routeExpectsPath(http.MethodPost, "/foo"))
   265  	rt.Handle(http.MethodGet, "/bar", routeExpectsPath(http.MethodGet, "/bar"))
   266  
   267  	// explicitly register a slash url here
   268  	rt.Handle(http.MethodGet, "/slash/", handlerNoOp)
   269  
   270  	mock := httptest.NewServer(rt)
   271  	defer mock.Close()
   272  
   273  	res, err := mock.Client().Get(mock.URL + "/")
   274  	its.Nil(err)
   275  	its.Equal(http.StatusOK, res.StatusCode)
   276  
   277  	res, err = mock.Client().Get(mock.URL + "/foo")
   278  	its.Nil(err)
   279  	its.Equal(http.StatusOK, res.StatusCode)
   280  
   281  	res, err = mock.Client().Get(mock.URL + "/foo/")
   282  	its.Nil(err)
   283  	its.Equal(http.StatusOK, res.StatusCode)
   284  
   285  	res, err = mock.Client().Post(mock.URL+"/foo/", "", nil)
   286  	its.Nil(err)
   287  	its.Equal(http.StatusOK, res.StatusCode)
   288  
   289  	res, err = mock.Client().Get(mock.URL + "/foo/test-id")
   290  	its.Nil(err)
   291  	its.Equal(http.StatusOK, res.StatusCode)
   292  
   293  	res, err = mock.Client().Get(mock.URL + "/foo/not-test-id")
   294  	its.Nil(err)
   295  	its.Equal(http.StatusBadRequest, res.StatusCode)
   296  
   297  	res, err = mock.Client().Get(mock.URL + "/bar/")
   298  	its.Nil(err)
   299  	its.Equal(http.StatusOK, res.StatusCode)
   300  
   301  	optionsReq, _ := http.NewRequest(http.MethodOptions, mock.URL, nil)
   302  	// now handle the super weird stuff
   303  	res, err = mock.Client().Do(optionsReq)
   304  	its.Nil(err)
   305  	its.Equal(http.StatusOK, res.StatusCode)
   306  	allowedHeader := res.Header.Get(webutil.HeaderAllow)
   307  	its.NotEmpty(allowedHeader)
   308  	its.Equal("GET, OPTIONS", allowedHeader)
   309  
   310  	rt.SkipHandlingMethodOptions = true
   311  	res, err = mock.Client().Do(optionsReq)
   312  	its.Nil(err)
   313  	its.Equal(http.StatusNotFound, res.StatusCode)
   314  	allowedHeader = res.Header.Get(webutil.HeaderAllow)
   315  	its.Empty(allowedHeader)
   316  
   317  	var notFoundCalls int32
   318  	rt.NotFoundHandler = callCounter(&notFoundCalls, http.StatusNotFound)
   319  	res, err = mock.Client().Do(optionsReq)
   320  	its.Nil(err)
   321  	its.Equal(http.StatusNotFound, res.StatusCode)
   322  	allowedHeader = res.Header.Get(webutil.HeaderAllow)
   323  	its.Empty(allowedHeader)
   324  	its.Equal(1, notFoundCalls)
   325  
   326  	headReq, _ := http.NewRequest(http.MethodHead, mock.URL, nil)
   327  	res, err = mock.Client().Do(headReq)
   328  	its.Nil(err)
   329  	its.Equal(http.StatusMethodNotAllowed, res.StatusCode)
   330  	allowedHeader = res.Header.Get(webutil.HeaderAllow)
   331  	its.NotEmpty(allowedHeader)
   332  	its.Equal("GET, OPTIONS", allowedHeader)
   333  
   334  	var methodNotAllowedCalls int32
   335  	rt.MethodNotAllowedHandler = callCounter(&methodNotAllowedCalls, http.StatusMethodNotAllowed)
   336  	res, err = mock.Client().Do(headReq)
   337  	its.Nil(err)
   338  	its.Equal(http.StatusMethodNotAllowed, res.StatusCode)
   339  	allowedHeader = res.Header.Get(webutil.HeaderAllow)
   340  	its.NotEmpty(allowedHeader)
   341  	its.Equal("GET, OPTIONS", allowedHeader)
   342  	its.Equal(1, notFoundCalls)
   343  	its.Equal(1, methodNotAllowedCalls)
   344  
   345  	rt.SkipMethodNotAllowed = true
   346  	rt.NotFoundHandler = nil
   347  	res, err = mock.Client().Do(optionsReq)
   348  	its.Nil(err)
   349  	its.Equal(http.StatusNotFound, res.StatusCode)
   350  	allowedHeader = res.Header.Get(webutil.HeaderAllow)
   351  	its.Empty(allowedHeader)
   352  
   353  	rt.NotFoundHandler = callCounter(&notFoundCalls, http.StatusNotFound)
   354  	res, err = mock.Client().Do(optionsReq)
   355  	its.Nil(err)
   356  	its.Equal(http.StatusNotFound, res.StatusCode)
   357  	allowedHeader = res.Header.Get(webutil.HeaderAllow)
   358  	its.Empty(allowedHeader)
   359  	its.Equal(2, notFoundCalls)
   360  }