go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/web/route_tree_test.go (about) 1 /* 2 3 Copyright (c) 2023 - Present. Will Charczuk. All rights reserved. 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository. 5 6 */ 7 8 package web 9 10 import ( 11 "fmt" 12 "net/http" 13 "net/http/httptest" 14 "net/url" 15 "strings" 16 "sync/atomic" 17 "testing" 18 19 . "go.charczuk.com/sdk/assert" 20 ) 21 22 func handlerNoOp(rw http.ResponseWriter, _ *http.Request, _ *Route, _ RouteParameters) { 23 rw.WriteHeader(http.StatusOK) 24 fmt.Fprintf(rw, "OK!\n") 25 } 26 27 func Test_RouteTree_allowed(t *testing.T) { 28 rt := new(RouteTree) 29 rt.Handle(http.MethodGet, "/test", nil) 30 31 allowed := strings.Split(rt.allowed("*", ""), ", ") 32 ItsLen(t, allowed, 1) 33 ItsEqual(t, "GET", allowed[0]) 34 35 rt.Handle(http.MethodPost, "/hello", nil) 36 allowed = strings.Split(rt.allowed("*", ""), ", ") 37 ItsLen(t, allowed, 2) 38 ItsAny(t, allowed, func(s string) bool { 39 return s == http.MethodGet 40 }) 41 ItsAny(t, allowed, func(s string) bool { 42 return s == http.MethodPost 43 }) 44 45 rt = new(RouteTree) 46 47 rt.Handle(http.MethodGet, "/hello", handlerNoOp) 48 allowed = strings.Split(rt.allowed("/hello", ""), ", ") 49 ItsLen(t, allowed, 2) 50 ItsAny(t, allowed, func(s string) bool { 51 return s == "GET" 52 }) 53 ItsAny(t, allowed, func(s string) bool { 54 return s == "OPTIONS" 55 }) 56 rt.Handle(http.MethodPost, "/hello", handlerNoOp) 57 allowed = strings.Split(rt.allowed("/hello", ""), ", ") 58 ItsLen(t, allowed, 3) 59 60 rt.Handle(http.MethodOptions, "/hello", handlerNoOp) 61 rt.Handle(http.MethodHead, "/hello", handlerNoOp) 62 rt.Handle(http.MethodPut, "/hello", handlerNoOp) 63 rt.Handle(http.MethodDelete, "/hello", handlerNoOp) 64 65 rt.Handle(http.MethodPatch, "/hi", handlerNoOp) 66 rt.Handle(http.MethodPatch, "/there", handlerNoOp) 67 allowed = strings.Split(rt.allowed("/hello", ""), ", ") 68 ItsLen(t, allowed, 6) 69 70 rt.Handle(http.MethodPatch, "/hello", handlerNoOp) 71 allowed = strings.Split(rt.allowed("/hello", ""), ", ") 72 ItsLen(t, allowed, 7) 73 } 74 75 func Test_RouteTree_Route(t *testing.T) { 76 rt := new(RouteTree) 77 78 rt.Handle(http.MethodGet, "/", handlerNoOp) 79 rt.Handle(http.MethodGet, "/foo", handlerNoOp) 80 rt.Handle(http.MethodGet, "/foo/:id", handlerNoOp) 81 rt.Handle(http.MethodPost, "/foo", handlerNoOp) 82 rt.Handle(http.MethodGet, "/bar", handlerNoOp) 83 84 // explicitly register a slash suffixed url here 85 rt.Handle(http.MethodGet, "/slash/", handlerNoOp) 86 87 req := &http.Request{ 88 Method: http.MethodGet, 89 URL: &url.URL{ 90 Path: "/", 91 }, 92 } 93 route, params := rt.Route(req) 94 ItsNotNil(t, route) 95 ItsEqual(t, "/", route.Path) 96 ItsEmpty(t, params) 97 ItsEqual(t, "/", req.URL.Path) 98 99 req = &http.Request{ 100 Method: http.MethodGet, 101 URL: &url.URL{ 102 Path: "/foo", 103 }, 104 } 105 route, params = rt.Route(req) 106 ItsNotNil(t, route) 107 ItsEqual(t, "/foo", route.Path) 108 ItsEqual(t, http.MethodGet, route.Method) 109 ItsEmpty(t, params) 110 ItsEqual(t, "/foo", req.URL.Path) 111 112 req = &http.Request{ 113 Method: http.MethodPost, 114 URL: &url.URL{ 115 Path: "/foo", 116 }, 117 } 118 route, params = rt.Route(req) 119 ItsNotNil(t, route) 120 ItsEqual(t, "/foo", route.Path) 121 ItsEqual(t, http.MethodPost, route.Method) 122 ItsEmpty(t, params) 123 ItsEqual(t, "/foo", req.URL.Path) 124 125 // explicitly test matching with an extra slash 126 req = &http.Request{ 127 Method: http.MethodGet, 128 URL: &url.URL{ 129 Path: "/foo/", 130 }, 131 } 132 route, params = rt.Route(req) 133 ItsNotNil(t, route) 134 ItsEqual(t, "/foo", route.Path) 135 ItsEmpty(t, params) 136 ItsEqual(t, "/foo/", req.URL.Path) 137 138 req = &http.Request{ 139 Method: http.MethodGet, 140 URL: &url.URL{ 141 Path: "/foo/test", 142 }, 143 } 144 route, params = rt.Route(req) 145 ItsNotNil(t, route) 146 ItsEqual(t, "/foo/:id", route.Path) 147 ItsNotEmpty(t, params) 148 id, _ := params.Get("id") 149 ItsEqual(t, "test", id) 150 ItsEqual(t, "/foo/test", req.URL.Path) 151 152 req = &http.Request{ 153 Method: http.MethodGet, 154 URL: &url.URL{ 155 Path: "/bar", 156 }, 157 } 158 route, params = rt.Route(req) 159 ItsNotNil(t, route) 160 ItsEqual(t, "/bar", route.Path) 161 ItsEmpty(t, params) 162 ItsEqual(t, "/bar", req.URL.Path) 163 164 req = &http.Request{ 165 Method: http.MethodGet, 166 URL: &url.URL{ 167 Path: "/slash", 168 }, 169 } 170 route, params = rt.Route(req) 171 ItsNotNil(t, route) 172 ItsEqual(t, "/slash/", route.Path) 173 ItsEmpty(t, params) 174 ItsEqual(t, "/slash", req.URL.Path) 175 176 req = &http.Request{ 177 Method: http.MethodConnect, 178 URL: &url.URL{ 179 Path: "/slash", 180 }, 181 } 182 route, params = rt.Route(req) 183 ItsNil(t, route) 184 ItsEmpty(t, params) 185 ItsEqual(t, "/slash", req.URL.Path) 186 187 req = &http.Request{ 188 Method: http.MethodGet, 189 URL: &url.URL{ 190 Path: "/slash", 191 }, 192 } 193 rt.SkipTrailingSlashRedirects = true 194 route, params = rt.Route(req) 195 ItsNil(t, route) 196 ItsEmpty(t, params) 197 ItsEqual(t, "/slash", req.URL.Path) 198 } 199 200 func Test_RouteTree_Route_slash(t *testing.T) { 201 rt := new(RouteTree) 202 203 req := &http.Request{ 204 Method: http.MethodGet, 205 URL: &url.URL{ 206 Path: "/", 207 }, 208 } 209 route, params := rt.Route(req) 210 ItsNil(t, route) 211 ItsEmpty(t, params) 212 ItsEqual(t, "/", req.URL.Path) 213 } 214 215 func Test_RouteTree_withPathAlternateTrailingSlash(t *testing.T) { 216 ItsEqual(t, "/foo", new(RouteTree).withPathAlternateTrailingSlash("/foo/")) 217 ItsEqual(t, "/foo/", new(RouteTree).withPathAlternateTrailingSlash("/foo")) 218 ItsEqual(t, "", new(RouteTree).withPathAlternateTrailingSlash("")) 219 } 220 221 func routeExpectsPath(method, path string) Handler { 222 return func(rw http.ResponseWriter, req *http.Request, _ *Route, _ RouteParameters) { 223 if req.Method != method { 224 http.Error(rw, "expects method: "+method, http.StatusBadRequest) 225 return 226 } 227 if req.URL.Path != path { 228 http.Error(rw, "expects path: "+path, http.StatusBadRequest) 229 return 230 } 231 rw.WriteHeader(http.StatusOK) 232 fmt.Fprintf(rw, "OK!\n") 233 } 234 } 235 236 func callCounter(counter *int32, statusCode int) Handler { 237 return func(rw http.ResponseWriter, req *http.Request, _ *Route, _ RouteParameters) { 238 defer atomic.AddInt32(counter, 1) 239 rw.WriteHeader(statusCode) 240 fmt.Fprintf(rw, "counted call!\n") 241 } 242 } 243 244 func Test_RouteTree_ServeHTTP(t *testing.T) { 245 rt := new(RouteTree) 246 247 rt.Handle(http.MethodGet, "/", routeExpectsPath(http.MethodGet, "/")) 248 rt.Handle(http.MethodGet, "/foo", routeExpectsPath(http.MethodGet, "/foo")) 249 rt.Handle(http.MethodGet, "/foo/:id", routeExpectsPath(http.MethodGet, "/foo/test-id")) 250 rt.Handle(http.MethodPost, "/foo", routeExpectsPath(http.MethodPost, "/foo")) 251 rt.Handle(http.MethodGet, "/bar", routeExpectsPath(http.MethodGet, "/bar")) 252 253 // explicitly register a slash url here 254 rt.Handle(http.MethodGet, "/slash/", handlerNoOp) 255 256 mock := httptest.NewServer(rt) 257 defer mock.Close() 258 259 res, err := mock.Client().Get(mock.URL + "/") 260 ItsNil(t, err) 261 ItsEqual(t, http.StatusOK, res.StatusCode) 262 263 res, err = mock.Client().Get(mock.URL + "/foo") 264 ItsNil(t, err) 265 ItsEqual(t, http.StatusOK, res.StatusCode) 266 267 res, err = mock.Client().Get(mock.URL + "/foo/") 268 ItsNil(t, err) 269 ItsEqual(t, http.StatusOK, res.StatusCode) 270 271 res, err = mock.Client().Post(mock.URL+"/foo/", "", nil) 272 ItsNil(t, err) 273 ItsEqual(t, http.StatusOK, res.StatusCode) 274 275 res, err = mock.Client().Get(mock.URL + "/foo/test-id") 276 ItsNil(t, err) 277 ItsEqual(t, http.StatusOK, res.StatusCode) 278 279 res, err = mock.Client().Get(mock.URL + "/foo/not-test-id") 280 ItsNil(t, err) 281 ItsEqual(t, http.StatusBadRequest, res.StatusCode) 282 283 res, err = mock.Client().Get(mock.URL + "/bar/") 284 ItsNil(t, err) 285 ItsEqual(t, http.StatusOK, res.StatusCode) 286 287 optionsReq, _ := http.NewRequest(http.MethodOptions, mock.URL, nil) 288 // now handle the super weird stuff 289 res, err = mock.Client().Do(optionsReq) 290 ItsNil(t, err) 291 ItsEqual(t, http.StatusOK, res.StatusCode) 292 allowedHeader := res.Header.Get(HeaderAllow) 293 ItsNotEmpty(t, allowedHeader) 294 ItsEqual(t, "GET, OPTIONS", allowedHeader) 295 296 rt.SkipHandlingMethodOptions = true 297 res, err = mock.Client().Do(optionsReq) 298 ItsNil(t, err) 299 ItsEqual(t, http.StatusNotFound, res.StatusCode) 300 allowedHeader = res.Header.Get(HeaderAllow) 301 ItsEmpty(t, allowedHeader) 302 303 var notFoundCalls int32 304 rt.NotFoundHandler = callCounter(¬FoundCalls, http.StatusNotFound) 305 res, err = mock.Client().Do(optionsReq) 306 ItsNil(t, err) 307 ItsEqual(t, http.StatusNotFound, res.StatusCode) 308 allowedHeader = res.Header.Get(HeaderAllow) 309 ItsEmpty(t, allowedHeader) 310 ItsEqual(t, 1, notFoundCalls) 311 312 headReq, _ := http.NewRequest(http.MethodHead, mock.URL, nil) 313 res, err = mock.Client().Do(headReq) 314 ItsNil(t, err) 315 ItsEqual(t, http.StatusMethodNotAllowed, res.StatusCode) 316 allowedHeader = res.Header.Get(HeaderAllow) 317 ItsNotEmpty(t, allowedHeader) 318 ItsEqual(t, "GET, OPTIONS", allowedHeader) 319 320 var methodNotAllowedCalls int32 321 rt.MethodNotAllowedHandler = callCounter(&methodNotAllowedCalls, http.StatusMethodNotAllowed) 322 res, err = mock.Client().Do(headReq) 323 ItsNil(t, err) 324 ItsEqual(t, http.StatusMethodNotAllowed, res.StatusCode) 325 allowedHeader = res.Header.Get(HeaderAllow) 326 ItsNotEmpty(t, allowedHeader) 327 ItsEqual(t, "GET, OPTIONS", allowedHeader) 328 ItsEqual(t, 1, notFoundCalls) 329 ItsEqual(t, 1, methodNotAllowedCalls) 330 331 rt.SkipMethodNotAllowed = true 332 rt.NotFoundHandler = nil 333 res, err = mock.Client().Do(optionsReq) 334 ItsNil(t, err) 335 ItsEqual(t, http.StatusNotFound, res.StatusCode) 336 allowedHeader = res.Header.Get(HeaderAllow) 337 ItsEmpty(t, allowedHeader) 338 339 rt.NotFoundHandler = callCounter(¬FoundCalls, http.StatusNotFound) 340 res, err = mock.Client().Do(optionsReq) 341 ItsNil(t, err) 342 ItsEqual(t, http.StatusNotFound, res.StatusCode) 343 allowedHeader = res.Header.Get(HeaderAllow) 344 ItsEmpty(t, allowedHeader) 345 ItsEqual(t, 2, notFoundCalls) 346 }