github.com/blend/go-sdk@v1.20220411.3/web/route_tree_test.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package web 9 10 import ( 11 "fmt" 12 "net/http" 13 "net/http/httptest" 14 "net/url" 15 "strings" 16 "sync/atomic" 17 "testing" 18 19 "github.com/blend/go-sdk/assert" 20 "github.com/blend/go-sdk/webutil" 21 ) 22 23 func handlerNoOp(rw http.ResponseWriter, _ *http.Request, _ *Route, _ RouteParameters) { 24 rw.WriteHeader(http.StatusOK) 25 fmt.Fprintf(rw, "OK!\n") 26 } 27 28 func Test_RouteTree_allowed(t *testing.T) { 29 its := assert.New(t) 30 31 rt := new(RouteTree) 32 rt.Handle(http.MethodGet, "/test", nil) 33 34 allowed := strings.Split(rt.allowed("*", ""), ", ") 35 its.Len(allowed, 1) 36 its.Equal("GET", allowed[0]) 37 38 rt.Handle(http.MethodPost, "/hello", nil) 39 allowed = strings.Split(rt.allowed("*", ""), ", ") 40 its.Len(allowed, 2) 41 its.Any(allowed, func(i interface{}) bool { 42 s, ok := i.(string) 43 return ok && s == http.MethodGet 44 }) 45 its.Any(allowed, func(i interface{}) bool { 46 s, ok := i.(string) 47 return ok && s == http.MethodPost 48 }) 49 50 rt = new(RouteTree) 51 52 rt.Handle(http.MethodGet, "/hello", handlerNoOp) 53 allowed = strings.Split(rt.allowed("/hello", ""), ", ") 54 its.Len(allowed, 2) 55 its.Any(allowed, func(i interface{}) bool { 56 s, ok := i.(string) 57 return ok && s == "GET" 58 }) 59 its.Any(allowed, func(i interface{}) bool { 60 s, ok := i.(string) 61 return ok && s == "OPTIONS" 62 }) 63 rt.Handle(http.MethodPost, "/hello", handlerNoOp) 64 allowed = strings.Split(rt.allowed("/hello", ""), ", ") 65 its.Len(allowed, 3) 66 67 rt.Handle(http.MethodOptions, "/hello", handlerNoOp) 68 rt.Handle(http.MethodHead, "/hello", handlerNoOp) 69 rt.Handle(http.MethodPut, "/hello", handlerNoOp) 70 rt.Handle(http.MethodDelete, "/hello", handlerNoOp) 71 72 rt.Handle(http.MethodPatch, "/hi", handlerNoOp) 73 rt.Handle(http.MethodPatch, "/there", handlerNoOp) 74 allowed = strings.Split(rt.allowed("/hello", ""), ", ") 75 its.Len(allowed, 6) 76 77 rt.Handle(http.MethodPatch, "/hello", handlerNoOp) 78 allowed = strings.Split(rt.allowed("/hello", ""), ", ") 79 its.Len(allowed, 7) 80 } 81 82 func Test_RouteTree_Route(t *testing.T) { 83 its := assert.New(t) 84 85 rt := new(RouteTree) 86 87 rt.Handle(http.MethodGet, "/", handlerNoOp) 88 rt.Handle(http.MethodGet, "/foo", handlerNoOp) 89 rt.Handle(http.MethodGet, "/foo/:id", handlerNoOp) 90 rt.Handle(http.MethodPost, "/foo", handlerNoOp) 91 rt.Handle(http.MethodGet, "/bar", handlerNoOp) 92 93 // explicitly register a slash suffixed url here 94 rt.Handle(http.MethodGet, "/slash/", handlerNoOp) 95 96 req := &http.Request{ 97 Method: http.MethodGet, 98 URL: &url.URL{ 99 Path: "/", 100 }, 101 } 102 route, params := rt.Route(req) 103 its.NotNil(route) 104 its.Equal("/", route.Path) 105 its.Empty(params) 106 its.Equal("/", req.URL.Path) 107 108 req = &http.Request{ 109 Method: http.MethodGet, 110 URL: &url.URL{ 111 Path: "/foo", 112 }, 113 } 114 route, params = rt.Route(req) 115 its.NotNil(route) 116 its.Equal("/foo", route.Path) 117 its.Equal(http.MethodGet, route.Method) 118 its.Empty(params) 119 its.Equal("/foo", req.URL.Path) 120 121 req = &http.Request{ 122 Method: http.MethodPost, 123 URL: &url.URL{ 124 Path: "/foo", 125 }, 126 } 127 route, params = rt.Route(req) 128 its.NotNil(route) 129 its.Equal("/foo", route.Path) 130 its.Equal(http.MethodPost, route.Method) 131 its.Empty(params) 132 its.Equal("/foo", req.URL.Path) 133 134 // explicitly test matching with an extra slash 135 req = &http.Request{ 136 Method: http.MethodGet, 137 URL: &url.URL{ 138 Path: "/foo/", 139 }, 140 } 141 route, params = rt.Route(req) 142 its.NotNil(route) 143 its.Equal("/foo", route.Path) 144 its.Empty(params) 145 its.Equal("/foo/", req.URL.Path) 146 147 req = &http.Request{ 148 Method: http.MethodGet, 149 URL: &url.URL{ 150 Path: "/foo/test", 151 }, 152 } 153 route, params = rt.Route(req) 154 its.NotNil(route) 155 its.Equal("/foo/:id", route.Path) 156 its.NotEmpty(params) 157 its.Equal("test", params["id"]) 158 its.Equal("/foo/test", req.URL.Path) 159 160 req = &http.Request{ 161 Method: http.MethodGet, 162 URL: &url.URL{ 163 Path: "/bar", 164 }, 165 } 166 route, params = rt.Route(req) 167 its.NotNil(route) 168 its.Equal("/bar", route.Path) 169 its.Empty(params) 170 its.Equal("/bar", req.URL.Path) 171 172 req = &http.Request{ 173 Method: http.MethodGet, 174 URL: &url.URL{ 175 Path: "/slash", 176 }, 177 } 178 route, params = rt.Route(req) 179 its.NotNil(route) 180 its.Equal("/slash/", route.Path) 181 its.Empty(params) 182 its.Equal("/slash", req.URL.Path) 183 184 req = &http.Request{ 185 Method: http.MethodConnect, 186 URL: &url.URL{ 187 Path: "/slash", 188 }, 189 } 190 route, params = rt.Route(req) 191 its.Nil(route) 192 its.Empty(params) 193 its.Equal("/slash", req.URL.Path) 194 195 req = &http.Request{ 196 Method: http.MethodGet, 197 URL: &url.URL{ 198 Path: "/slash", 199 }, 200 } 201 rt.SkipTrailingSlashRedirects = true 202 route, params = rt.Route(req) 203 its.Nil(route) 204 its.Empty(params) 205 its.Equal("/slash", req.URL.Path) 206 } 207 208 func Test_RouteTree_Route_slash(t *testing.T) { 209 its := assert.New(t) 210 211 rt := new(RouteTree) 212 213 req := &http.Request{ 214 Method: http.MethodGet, 215 URL: &url.URL{ 216 Path: "/", 217 }, 218 } 219 route, params := rt.Route(req) 220 its.Nil(route) 221 its.Empty(params) 222 its.Equal("/", req.URL.Path) 223 } 224 225 func Test_RouteTree_withPathAlternateTrailingSlash(t *testing.T) { 226 its := assert.New(t) 227 228 its.Equal("/foo", new(RouteTree).withPathAlternateTrailingSlash("/foo/")) 229 its.Equal("/foo/", new(RouteTree).withPathAlternateTrailingSlash("/foo")) 230 its.Equal("", new(RouteTree).withPathAlternateTrailingSlash("")) 231 } 232 233 func routeExpectsPath(method, path string) Handler { 234 return func(rw http.ResponseWriter, req *http.Request, _ *Route, _ RouteParameters) { 235 if req.Method != method { 236 http.Error(rw, "expects method: "+method, http.StatusBadRequest) 237 return 238 } 239 if req.URL.Path != path { 240 http.Error(rw, "expects path: "+path, http.StatusBadRequest) 241 return 242 } 243 rw.WriteHeader(http.StatusOK) 244 fmt.Fprintf(rw, "OK!\n") 245 } 246 } 247 248 func callCounter(counter *int32, statusCode int) Handler { 249 return func(rw http.ResponseWriter, req *http.Request, _ *Route, _ RouteParameters) { 250 defer atomic.AddInt32(counter, 1) 251 rw.WriteHeader(statusCode) 252 fmt.Fprintf(rw, "counted call!\n") 253 } 254 } 255 256 func Test_RouteTree_ServeHTTP(t *testing.T) { 257 its := assert.New(t) 258 259 rt := new(RouteTree) 260 261 rt.Handle(http.MethodGet, "/", routeExpectsPath(http.MethodGet, "/")) 262 rt.Handle(http.MethodGet, "/foo", routeExpectsPath(http.MethodGet, "/foo")) 263 rt.Handle(http.MethodGet, "/foo/:id", routeExpectsPath(http.MethodGet, "/foo/test-id")) 264 rt.Handle(http.MethodPost, "/foo", routeExpectsPath(http.MethodPost, "/foo")) 265 rt.Handle(http.MethodGet, "/bar", routeExpectsPath(http.MethodGet, "/bar")) 266 267 // explicitly register a slash url here 268 rt.Handle(http.MethodGet, "/slash/", handlerNoOp) 269 270 mock := httptest.NewServer(rt) 271 defer mock.Close() 272 273 res, err := mock.Client().Get(mock.URL + "/") 274 its.Nil(err) 275 its.Equal(http.StatusOK, res.StatusCode) 276 277 res, err = mock.Client().Get(mock.URL + "/foo") 278 its.Nil(err) 279 its.Equal(http.StatusOK, res.StatusCode) 280 281 res, err = mock.Client().Get(mock.URL + "/foo/") 282 its.Nil(err) 283 its.Equal(http.StatusOK, res.StatusCode) 284 285 res, err = mock.Client().Post(mock.URL+"/foo/", "", nil) 286 its.Nil(err) 287 its.Equal(http.StatusOK, res.StatusCode) 288 289 res, err = mock.Client().Get(mock.URL + "/foo/test-id") 290 its.Nil(err) 291 its.Equal(http.StatusOK, res.StatusCode) 292 293 res, err = mock.Client().Get(mock.URL + "/foo/not-test-id") 294 its.Nil(err) 295 its.Equal(http.StatusBadRequest, res.StatusCode) 296 297 res, err = mock.Client().Get(mock.URL + "/bar/") 298 its.Nil(err) 299 its.Equal(http.StatusOK, res.StatusCode) 300 301 optionsReq, _ := http.NewRequest(http.MethodOptions, mock.URL, nil) 302 // now handle the super weird stuff 303 res, err = mock.Client().Do(optionsReq) 304 its.Nil(err) 305 its.Equal(http.StatusOK, res.StatusCode) 306 allowedHeader := res.Header.Get(webutil.HeaderAllow) 307 its.NotEmpty(allowedHeader) 308 its.Equal("GET, OPTIONS", allowedHeader) 309 310 rt.SkipHandlingMethodOptions = true 311 res, err = mock.Client().Do(optionsReq) 312 its.Nil(err) 313 its.Equal(http.StatusNotFound, res.StatusCode) 314 allowedHeader = res.Header.Get(webutil.HeaderAllow) 315 its.Empty(allowedHeader) 316 317 var notFoundCalls int32 318 rt.NotFoundHandler = callCounter(¬FoundCalls, http.StatusNotFound) 319 res, err = mock.Client().Do(optionsReq) 320 its.Nil(err) 321 its.Equal(http.StatusNotFound, res.StatusCode) 322 allowedHeader = res.Header.Get(webutil.HeaderAllow) 323 its.Empty(allowedHeader) 324 its.Equal(1, notFoundCalls) 325 326 headReq, _ := http.NewRequest(http.MethodHead, mock.URL, nil) 327 res, err = mock.Client().Do(headReq) 328 its.Nil(err) 329 its.Equal(http.StatusMethodNotAllowed, res.StatusCode) 330 allowedHeader = res.Header.Get(webutil.HeaderAllow) 331 its.NotEmpty(allowedHeader) 332 its.Equal("GET, OPTIONS", allowedHeader) 333 334 var methodNotAllowedCalls int32 335 rt.MethodNotAllowedHandler = callCounter(&methodNotAllowedCalls, http.StatusMethodNotAllowed) 336 res, err = mock.Client().Do(headReq) 337 its.Nil(err) 338 its.Equal(http.StatusMethodNotAllowed, res.StatusCode) 339 allowedHeader = res.Header.Get(webutil.HeaderAllow) 340 its.NotEmpty(allowedHeader) 341 its.Equal("GET, OPTIONS", allowedHeader) 342 its.Equal(1, notFoundCalls) 343 its.Equal(1, methodNotAllowedCalls) 344 345 rt.SkipMethodNotAllowed = true 346 rt.NotFoundHandler = nil 347 res, err = mock.Client().Do(optionsReq) 348 its.Nil(err) 349 its.Equal(http.StatusNotFound, res.StatusCode) 350 allowedHeader = res.Header.Get(webutil.HeaderAllow) 351 its.Empty(allowedHeader) 352 353 rt.NotFoundHandler = callCounter(¬FoundCalls, http.StatusNotFound) 354 res, err = mock.Client().Do(optionsReq) 355 its.Nil(err) 356 its.Equal(http.StatusNotFound, res.StatusCode) 357 allowedHeader = res.Header.Get(webutil.HeaderAllow) 358 its.Empty(allowedHeader) 359 its.Equal(2, notFoundCalls) 360 }