github.com/lingyao2333/mo-zero@v1.4.1/rest/server_test.go (about) 1 package rest 2 3 import ( 4 "crypto/tls" 5 "fmt" 6 "io" 7 "net/http" 8 "net/http/httptest" 9 "os" 10 "strings" 11 "sync/atomic" 12 "testing" 13 "time" 14 15 "github.com/lingyao2333/mo-zero/core/conf" 16 "github.com/lingyao2333/mo-zero/core/logx" 17 "github.com/lingyao2333/mo-zero/rest/chain" 18 "github.com/lingyao2333/mo-zero/rest/httpx" 19 "github.com/lingyao2333/mo-zero/rest/internal/cors" 20 "github.com/lingyao2333/mo-zero/rest/router" 21 "github.com/stretchr/testify/assert" 22 ) 23 24 func TestNewServer(t *testing.T) { 25 writer := logx.Reset() 26 defer logx.SetWriter(writer) 27 logx.SetWriter(logx.NewWriter(io.Discard)) 28 29 const configYaml = ` 30 Name: foo 31 Port: 54321 32 ` 33 var cnf RestConf 34 assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) 35 36 tests := []struct { 37 c RestConf 38 opts []RunOption 39 fail bool 40 }{ 41 { 42 c: RestConf{}, 43 opts: []RunOption{WithRouter(mockedRouter{}), WithCors()}, 44 }, 45 { 46 c: cnf, 47 opts: []RunOption{WithRouter(mockedRouter{})}, 48 }, 49 { 50 c: cnf, 51 opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)}, 52 }, 53 { 54 c: cnf, 55 opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})}, 56 }, 57 { 58 c: cnf, 59 opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})}, 60 }, 61 { 62 c: cnf, 63 opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})}, 64 }, 65 } 66 67 for _, test := range tests { 68 var svr *Server 69 var err error 70 if test.fail { 71 _, err = NewServer(test.c, test.opts...) 72 assert.NotNil(t, err) 73 continue 74 } else { 75 svr = MustNewServer(test.c, test.opts...) 76 } 77 78 svr.Use(ToMiddleware(func(next http.Handler) http.Handler { 79 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 80 next.ServeHTTP(w, r) 81 }) 82 })) 83 svr.AddRoute(Route{ 84 Method: http.MethodGet, 85 Path: "/", 86 Handler: nil, 87 }, WithJwt("thesecret"), WithSignature(SignatureConf{}), 88 WithJwtTransition("preivous", "thenewone")) 89 90 func() { 91 defer func() { 92 p := recover() 93 switch v := p.(type) { 94 case error: 95 assert.Equal(t, "foo", v.Error()) 96 default: 97 t.Fail() 98 } 99 }() 100 101 svr.Start() 102 svr.Stop() 103 }() 104 } 105 } 106 107 func TestWithMaxBytes(t *testing.T) { 108 const maxBytes = 1000 109 var fr featuredRoutes 110 WithMaxBytes(maxBytes)(&fr) 111 assert.Equal(t, int64(maxBytes), fr.maxBytes) 112 } 113 114 func TestWithMiddleware(t *testing.T) { 115 m := make(map[string]string) 116 rt := router.NewRouter() 117 handler := func(w http.ResponseWriter, r *http.Request) { 118 var v struct { 119 Nickname string `form:"nickname"` 120 Zipcode int64 `form:"zipcode"` 121 } 122 123 err := httpx.Parse(r, &v) 124 assert.Nil(t, err) 125 _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode)) 126 assert.Nil(t, err) 127 } 128 rs := WithMiddleware(func(next http.HandlerFunc) http.HandlerFunc { 129 return func(w http.ResponseWriter, r *http.Request) { 130 var v struct { 131 Name string `path:"name"` 132 Year string `path:"year"` 133 } 134 assert.Nil(t, httpx.ParsePath(r, &v)) 135 m[v.Name] = v.Year 136 next.ServeHTTP(w, r) 137 } 138 }, Route{ 139 Method: http.MethodGet, 140 Path: "/first/:name/:year", 141 Handler: handler, 142 }, Route{ 143 Method: http.MethodGet, 144 Path: "/second/:name/:year", 145 Handler: handler, 146 }) 147 148 urls := []string{ 149 "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000", 150 "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000", 151 } 152 for _, route := range rs { 153 assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler)) 154 } 155 for _, url := range urls { 156 r, err := http.NewRequest(http.MethodGet, url, nil) 157 assert.Nil(t, err) 158 159 rr := httptest.NewRecorder() 160 rt.ServeHTTP(rr, r) 161 162 assert.Equal(t, "whatever:200000", rr.Body.String()) 163 } 164 165 assert.EqualValues(t, map[string]string{ 166 "kevin": "2017", 167 "wan": "2020", 168 }, m) 169 } 170 171 func TestMultiMiddlewares(t *testing.T) { 172 m := make(map[string]string) 173 rt := router.NewRouter() 174 handler := func(w http.ResponseWriter, r *http.Request) { 175 var v struct { 176 Nickname string `form:"nickname"` 177 Zipcode int64 `form:"zipcode"` 178 } 179 180 err := httpx.Parse(r, &v) 181 assert.Nil(t, err) 182 _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname])) 183 assert.Nil(t, err) 184 } 185 rs := WithMiddlewares([]Middleware{ 186 func(next http.HandlerFunc) http.HandlerFunc { 187 return func(w http.ResponseWriter, r *http.Request) { 188 var v struct { 189 Name string `path:"name"` 190 Year string `path:"year"` 191 } 192 assert.Nil(t, httpx.ParsePath(r, &v)) 193 m[v.Name] = v.Year 194 next.ServeHTTP(w, r) 195 } 196 }, 197 func(next http.HandlerFunc) http.HandlerFunc { 198 return func(w http.ResponseWriter, r *http.Request) { 199 var v struct { 200 Name string `form:"nickname"` 201 Zipcode string `form:"zipcode"` 202 } 203 assert.Nil(t, httpx.ParseForm(r, &v)) 204 assert.NotEmpty(t, m) 205 m[v.Name] = v.Zipcode + v.Zipcode 206 next.ServeHTTP(w, r) 207 } 208 }, 209 ToMiddleware(func(next http.Handler) http.Handler { 210 return next 211 }), 212 }, Route{ 213 Method: http.MethodGet, 214 Path: "/first/:name/:year", 215 Handler: handler, 216 }, Route{ 217 Method: http.MethodGet, 218 Path: "/second/:name/:year", 219 Handler: handler, 220 }) 221 222 urls := []string{ 223 "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000", 224 "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000", 225 } 226 for _, route := range rs { 227 assert.Nil(t, rt.Handle(route.Method, route.Path, route.Handler)) 228 } 229 for _, url := range urls { 230 r, err := http.NewRequest(http.MethodGet, url, nil) 231 assert.Nil(t, err) 232 233 rr := httptest.NewRecorder() 234 rt.ServeHTTP(rr, r) 235 236 assert.Equal(t, "whatever:200000200000", rr.Body.String()) 237 } 238 239 assert.EqualValues(t, map[string]string{ 240 "kevin": "2017", 241 "wan": "2020", 242 "whatever": "200000200000", 243 }, m) 244 } 245 246 func TestWithPrefix(t *testing.T) { 247 fr := featuredRoutes{ 248 routes: []Route{ 249 { 250 Path: "/hello", 251 }, 252 { 253 Path: "/world", 254 }, 255 }, 256 } 257 WithPrefix("/api")(&fr) 258 vals := make([]string, 0, len(fr.routes)) 259 for _, r := range fr.routes { 260 vals = append(vals, r.Path) 261 } 262 assert.EqualValues(t, []string{"/api/hello", "/api/world"}, vals) 263 } 264 265 func TestWithPriority(t *testing.T) { 266 var fr featuredRoutes 267 WithPriority()(&fr) 268 assert.True(t, fr.priority) 269 } 270 271 func TestWithTimeout(t *testing.T) { 272 var fr featuredRoutes 273 WithTimeout(time.Hour)(&fr) 274 assert.Equal(t, time.Hour, fr.timeout) 275 } 276 277 func TestWithTLSConfig(t *testing.T) { 278 const configYaml = ` 279 Name: foo 280 Port: 54321 281 ` 282 var cnf RestConf 283 assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) 284 285 testConfig := &tls.Config{ 286 CipherSuites: []uint16{ 287 tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 288 }, 289 } 290 291 testCases := []struct { 292 c RestConf 293 opts []RunOption 294 res *tls.Config 295 }{ 296 { 297 c: cnf, 298 opts: []RunOption{WithTLSConfig(testConfig)}, 299 res: testConfig, 300 }, 301 { 302 c: cnf, 303 opts: []RunOption{WithUnsignedCallback(nil)}, 304 res: nil, 305 }, 306 } 307 308 for _, testCase := range testCases { 309 svr, err := NewServer(testCase.c, testCase.opts...) 310 assert.Nil(t, err) 311 assert.Equal(t, svr.ngin.tlsConfig, testCase.res) 312 } 313 } 314 315 func TestWithCors(t *testing.T) { 316 const configYaml = ` 317 Name: foo 318 Port: 54321 319 ` 320 var cnf RestConf 321 assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) 322 rt := router.NewRouter() 323 svr, err := NewServer(cnf, WithRouter(rt)) 324 assert.Nil(t, err) 325 defer svr.Stop() 326 327 opt := WithCors("local") 328 opt(svr) 329 } 330 331 func TestWithCustomCors(t *testing.T) { 332 const configYaml = ` 333 Name: foo 334 Port: 54321 335 ` 336 var cnf RestConf 337 assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) 338 rt := router.NewRouter() 339 svr, err := NewServer(cnf, WithRouter(rt)) 340 assert.Nil(t, err) 341 342 opt := WithCustomCors(func(header http.Header) { 343 header.Set("foo", "bar") 344 }, func(w http.ResponseWriter) { 345 w.WriteHeader(http.StatusOK) 346 }, "local") 347 opt(svr) 348 } 349 350 func TestServer_PrintRoutes(t *testing.T) { 351 const ( 352 configYaml = ` 353 Name: foo 354 Port: 54321 355 ` 356 expect = `Routes: 357 GET /bar 358 GET /foo 359 GET /foo/:bar 360 GET /foo/:bar/baz 361 ` 362 ) 363 364 var cnf RestConf 365 assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) 366 367 svr, err := NewServer(cnf) 368 assert.Nil(t, err) 369 370 svr.AddRoutes([]Route{ 371 { 372 Method: http.MethodGet, 373 Path: "/foo", 374 Handler: http.NotFound, 375 }, 376 { 377 Method: http.MethodGet, 378 Path: "/bar", 379 Handler: http.NotFound, 380 }, 381 { 382 Method: http.MethodGet, 383 Path: "/foo/:bar", 384 Handler: http.NotFound, 385 }, 386 { 387 Method: http.MethodGet, 388 Path: "/foo/:bar/baz", 389 Handler: http.NotFound, 390 }, 391 }) 392 393 old := os.Stdout 394 r, w, err := os.Pipe() 395 assert.Nil(t, err) 396 os.Stdout = w 397 defer func() { 398 os.Stdout = old 399 }() 400 401 svr.PrintRoutes() 402 ch := make(chan string) 403 404 go func() { 405 var buf strings.Builder 406 io.Copy(&buf, r) 407 ch <- buf.String() 408 }() 409 410 w.Close() 411 out := <-ch 412 assert.Equal(t, expect, out) 413 } 414 415 func TestServer_Routes(t *testing.T) { 416 const ( 417 configYaml = ` 418 Name: foo 419 Port: 54321 420 ` 421 expect = `GET /foo GET /bar GET /foo/:bar GET /foo/:bar/baz` 422 ) 423 424 var cnf RestConf 425 assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) 426 427 svr, err := NewServer(cnf) 428 assert.Nil(t, err) 429 430 svr.AddRoutes([]Route{ 431 { 432 Method: http.MethodGet, 433 Path: "/foo", 434 Handler: http.NotFound, 435 }, 436 { 437 Method: http.MethodGet, 438 Path: "/bar", 439 Handler: http.NotFound, 440 }, 441 { 442 Method: http.MethodGet, 443 Path: "/foo/:bar", 444 Handler: http.NotFound, 445 }, 446 { 447 Method: http.MethodGet, 448 Path: "/foo/:bar/baz", 449 Handler: http.NotFound, 450 }, 451 }) 452 453 routes := svr.Routes() 454 var buf strings.Builder 455 for i := 0; i < len(routes); i++ { 456 buf.WriteString(routes[i].Method) 457 buf.WriteString(" ") 458 buf.WriteString(routes[i].Path) 459 buf.WriteString(" ") 460 } 461 462 assert.Equal(t, expect, strings.Trim(buf.String(), " ")) 463 } 464 465 func TestHandleError(t *testing.T) { 466 assert.NotPanics(t, func() { 467 handleError(nil) 468 handleError(http.ErrServerClosed) 469 }) 470 } 471 472 func TestValidateSecret(t *testing.T) { 473 assert.Panics(t, func() { 474 validateSecret("short") 475 }) 476 } 477 478 func TestServer_WithChain(t *testing.T) { 479 var called int32 480 middleware1 := func() func(http.Handler) http.Handler { 481 return func(next http.Handler) http.Handler { 482 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 483 atomic.AddInt32(&called, 1) 484 next.ServeHTTP(w, r) 485 atomic.AddInt32(&called, 1) 486 }) 487 } 488 } 489 middleware2 := func() func(http.Handler) http.Handler { 490 return func(next http.Handler) http.Handler { 491 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 492 atomic.AddInt32(&called, 1) 493 next.ServeHTTP(w, r) 494 atomic.AddInt32(&called, 1) 495 }) 496 } 497 } 498 499 server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2()))) 500 server.AddRoutes( 501 []Route{ 502 { 503 Method: http.MethodGet, 504 Path: "/", 505 Handler: func(_ http.ResponseWriter, _ *http.Request) { 506 atomic.AddInt32(&called, 1) 507 }, 508 }, 509 }, 510 ) 511 rt := router.NewRouter() 512 assert.Nil(t, server.ngin.bindRoutes(rt)) 513 req, err := http.NewRequest(http.MethodGet, "/", http.NoBody) 514 assert.Nil(t, err) 515 rt.ServeHTTP(httptest.NewRecorder(), req) 516 assert.Equal(t, int32(5), atomic.LoadInt32(&called)) 517 } 518 519 func TestServer_WithCors(t *testing.T) { 520 var called int32 521 middleware := func(next http.Handler) http.Handler { 522 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 523 atomic.AddInt32(&called, 1) 524 next.ServeHTTP(w, r) 525 }) 526 } 527 r := router.NewRouter() 528 assert.Nil(t, r.Handle(http.MethodOptions, "/", middleware(http.NotFoundHandler()))) 529 530 cr := &corsRouter{ 531 Router: r, 532 middleware: cors.Middleware(nil, "*"), 533 } 534 req := httptest.NewRequest(http.MethodOptions, "/", http.NoBody) 535 cr.ServeHTTP(httptest.NewRecorder(), req) 536 assert.Equal(t, int32(0), atomic.LoadInt32(&called)) 537 }