github.com/cloudwego/hertz@v0.9.3/pkg/route/engine_test.go (about) 1 /* 2 * Copyright 2022 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * The MIT License (MIT) 16 * 17 * Copyright (c) 2014 Manuel Martínez-Almeida 18 * 19 * Permission is hereby granted, free of charge, to any person obtaining a copy 20 * of this software and associated documentation files (the "Software"), to deal 21 * in the Software without restriction, including without limitation the rights 22 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 23 * copies of the Software, and to permit persons to whom the Software is 24 * furnished to do so, subject to the following conditions: 25 * 26 * The above copyright notice and this permission notice shall be included in 27 * all copies or substantial portions of the Software. 28 * 29 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 30 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 31 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 32 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 33 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 34 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 35 * THE SOFTWARE. 36 * 37 * This file may have been modified by CloudWeGo authors. All CloudWeGo 38 * Modifications are Copyright 2022 CloudWeGo Authors 39 */ 40 41 package route 42 43 import ( 44 "context" 45 "crypto/tls" 46 "errors" 47 "fmt" 48 "html/template" 49 "io/ioutil" 50 "net" 51 "net/http" 52 "sync/atomic" 53 "testing" 54 "time" 55 56 "github.com/cloudwego/hertz/pkg/app" 57 "github.com/cloudwego/hertz/pkg/app/server/binding" 58 "github.com/cloudwego/hertz/pkg/app/server/registry" 59 "github.com/cloudwego/hertz/pkg/common/config" 60 errs "github.com/cloudwego/hertz/pkg/common/errors" 61 "github.com/cloudwego/hertz/pkg/common/test/assert" 62 "github.com/cloudwego/hertz/pkg/common/test/mock" 63 "github.com/cloudwego/hertz/pkg/network" 64 "github.com/cloudwego/hertz/pkg/network/standard" 65 "github.com/cloudwego/hertz/pkg/protocol" 66 "github.com/cloudwego/hertz/pkg/protocol/consts" 67 "github.com/cloudwego/hertz/pkg/protocol/suite" 68 "github.com/cloudwego/hertz/pkg/route/param" 69 ) 70 71 func TestNew_Engine(t *testing.T) { 72 defaultTransporter = standard.NewTransporter 73 opt := config.NewOptions([]config.Option{}) 74 router := NewEngine(opt) 75 assert.DeepEqual(t, "standard", router.GetTransporterName()) 76 assert.DeepEqual(t, "/", router.basePath) 77 assert.DeepEqual(t, router.engine, router) 78 assert.DeepEqual(t, 0, len(router.Handlers)) 79 } 80 81 func TestNew_Engine_WithTransporter(t *testing.T) { 82 defaultTransporter = newMockTransporter 83 opt := config.NewOptions([]config.Option{}) 84 router := NewEngine(opt) 85 assert.DeepEqual(t, "route", router.GetTransporterName()) 86 87 defaultTransporter = newMockTransporter 88 opt.TransporterNewer = standard.NewTransporter 89 router = NewEngine(opt) 90 assert.DeepEqual(t, "standard", router.GetTransporterName()) 91 assert.DeepEqual(t, "route", GetTransporterName()) 92 } 93 94 func TestGetTransporterName(t *testing.T) { 95 name := getTransporterName(&fakeTransporter{}) 96 assert.DeepEqual(t, "route", name) 97 } 98 99 func TestEngineUnescape(t *testing.T) { 100 e := NewEngine(config.NewOptions(nil)) 101 102 routes := []string{ 103 "/*all", 104 "/cmd/:tool/", 105 "/src/*filepath", 106 "/search/:query", 107 "/info/:user/project/:project", 108 "/info/:user", 109 } 110 111 for _, r := range routes { 112 e.GET(r, func(c context.Context, ctx *app.RequestContext) { 113 ctx.String(consts.StatusOK, ctx.Param(ctx.Query("key"))) 114 }) 115 } 116 117 testRoutes := []struct { 118 route string 119 key string 120 want string 121 }{ 122 {"/", "", ""}, 123 {"/cmd/%E4%BD%A0%E5%A5%BD/", "tool", "你好"}, 124 {"/src/some/%E4%B8%96%E7%95%8C.png", "filepath", "some/世界.png"}, 125 {"/info/%E4%BD%A0%E5%A5%BD/project/%E4%B8%96%E7%95%8C", "user", "你好"}, 126 {"/info/%E4%BD%A0%E5%A5%BD/project/%E4%B8%96%E7%95%8C", "project", "世界"}, 127 } 128 for _, tr := range testRoutes { 129 w := performRequest(e, http.MethodGet, tr.route+"?key="+tr.key) 130 assert.DeepEqual(t, consts.StatusOK, w.Code) 131 assert.DeepEqual(t, tr.want, w.Body.String()) 132 } 133 } 134 135 func TestEngineUnescapeRaw(t *testing.T) { 136 e := NewEngine(config.NewOptions(nil)) 137 e.options.UseRawPath = true 138 139 routes := []string{ 140 "/*all", 141 "/cmd/:tool/", 142 "/src/*filepath", 143 "/search/:query", 144 "/info/:user/project/:project", 145 "/info/:user", 146 } 147 148 for _, r := range routes { 149 e.GET(r, func(c context.Context, ctx *app.RequestContext) { 150 ctx.String(consts.StatusOK, ctx.Param(ctx.Query("key"))) 151 }) 152 } 153 154 testRoutes := []struct { 155 route string 156 key string 157 want string 158 }{ 159 {"/", "", ""}, 160 {"/cmd/test/", "tool", "test"}, 161 {"/src/some/file.png", "filepath", "some/file.png"}, 162 {"/src/some/file+test.png", "filepath", "some/file test.png"}, 163 {"/src/some/file++++%%%%test.png", "filepath", "some/file++++%%%%test.png"}, 164 {"/src/some/file%2Ftest.png", "filepath", "some/file/test.png"}, 165 {"/search/someth!ng+in+ünìcodé", "query", "someth!ng in ünìcodé"}, 166 {"/info/gordon/project/go", "user", "gordon"}, 167 {"/info/gordon/project/go", "project", "go"}, 168 {"/info/slash%2Fgordon", "user", "slash/gordon"}, 169 {"/info/slash%2Fgordon/project/Project%20%231", "user", "slash/gordon"}, 170 {"/info/slash%2Fgordon/project/Project%20%231", "project", "Project #1"}, 171 {"/info/slash%%%%", "user", "slash%%%%"}, 172 {"/info/slash%%%%2Fgordon/project/Project%%%%20%231", "user", "slash%%%%2Fgordon"}, 173 {"/info/slash%%%%2Fgordon/project/Project%%%%20%231", "project", "Project%%%%20%231"}, 174 } 175 for _, tr := range testRoutes { 176 w := performRequest(e, http.MethodGet, tr.route+"?key="+tr.key) 177 assert.DeepEqual(t, consts.StatusOK, w.Code) 178 assert.DeepEqual(t, tr.want, w.Body.String()) 179 } 180 } 181 182 func TestConnectionClose(t *testing.T) { 183 engine := NewEngine(config.NewOptions(nil)) 184 atomic.StoreUint32(&engine.status, statusRunning) 185 engine.Init() 186 engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { 187 ctx.String(consts.StatusOK, "ok") 188 }) 189 conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\nConnection: close\r\n\r\n") 190 err := engine.Serve(context.Background(), conn) 191 assert.True(t, errors.Is(err, errs.ErrShortConnection)) 192 } 193 194 func TestConnectionClose01(t *testing.T) { 195 engine := NewEngine(config.NewOptions(nil)) 196 atomic.StoreUint32(&engine.status, statusRunning) 197 engine.Init() 198 engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { 199 ctx.SetConnectionClose() 200 ctx.String(consts.StatusOK, "ok") 201 }) 202 conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") 203 err := engine.Serve(context.Background(), conn) 204 assert.True(t, errors.Is(err, errs.ErrShortConnection)) 205 } 206 207 func TestIdleTimeout(t *testing.T) { 208 engine := NewEngine(config.NewOptions(nil)) 209 engine.options.IdleTimeout = 0 210 atomic.StoreUint32(&engine.status, statusRunning) 211 engine.Init() 212 engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { 213 time.Sleep(100 * time.Millisecond) 214 ctx.String(consts.StatusOK, "ok") 215 }) 216 217 conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") 218 219 ch := make(chan error) 220 startCh := make(chan error) 221 go func() { 222 <-startCh 223 ch <- engine.Serve(context.Background(), conn) 224 }() 225 close(startCh) 226 select { 227 case err := <-ch: 228 if err != nil { 229 t.Errorf("err happened: %s", err) 230 } 231 return 232 case <-time.Tick(120 * time.Millisecond): 233 t.Errorf("timeout! should have been finished in 120ms...") 234 } 235 } 236 237 func TestIdleTimeout01(t *testing.T) { 238 engine := NewEngine(config.NewOptions(nil)) 239 engine.options.IdleTimeout = 1 * time.Second 240 atomic.StoreUint32(&engine.status, statusRunning) 241 engine.Init() 242 atomic.StoreUint32(&engine.status, statusRunning) 243 engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { 244 time.Sleep(10 * time.Millisecond) 245 ctx.String(consts.StatusOK, "ok") 246 }) 247 248 conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") 249 250 ch := make(chan error) 251 startCh := make(chan error) 252 go func() { 253 <-startCh 254 ch <- engine.Serve(context.Background(), conn) 255 }() 256 close(startCh) 257 select { 258 case <-ch: 259 t.Errorf("cannot return this early! should wait for at least 1s...") 260 case <-time.Tick(1 * time.Second): 261 return 262 } 263 } 264 265 func TestIdleTimeout03(t *testing.T) { 266 engine := NewEngine(config.NewOptions(nil)) 267 engine.options.IdleTimeout = 0 268 engine.transport = standard.NewTransporter(engine.options) 269 atomic.StoreUint32(&engine.status, statusRunning) 270 engine.Init() 271 atomic.StoreUint32(&engine.status, statusRunning) 272 engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { 273 time.Sleep(50 * time.Millisecond) 274 ctx.String(consts.StatusOK, "ok") 275 }) 276 277 conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n" + 278 "GET /foo HTTP/1.1\r\nHost: google.com\r\nConnection: close\r\n\r\n") 279 280 ch := make(chan error) 281 startCh := make(chan error) 282 go func() { 283 <-startCh 284 ch <- engine.Serve(context.Background(), conn) 285 }() 286 close(startCh) 287 select { 288 case err := <-ch: 289 if !errors.Is(err, errs.ErrShortConnection) { 290 t.Errorf("err should be ErrShortConnection, but got %s", err) 291 } 292 return 293 case <-time.Tick(200 * time.Millisecond): 294 t.Errorf("timeout! should have been finished in 200ms...") 295 } 296 } 297 298 func TestEngine_Routes(t *testing.T) { 299 engine := NewEngine(config.NewOptions(nil)) 300 engine.GET("/", handlerTest1) 301 engine.GET("/user", handlerTest2) 302 engine.GET("/user/:name/*action", handlerTest1) 303 engine.GET("/anonymous1", func(c context.Context, ctx *app.RequestContext) {}) // TestEngine_Routes.func1 304 engine.POST("/user", handlerTest2) 305 engine.POST("/user/:name/*action", handlerTest2) 306 engine.POST("/anonymous2", func(c context.Context, ctx *app.RequestContext) {}) // TestEngine_Routes.func2 307 group := engine.Group("/v1") 308 { 309 group.GET("/user", handlerTest1) 310 group.POST("/login", handlerTest2) 311 } 312 engine.Static("/static", ".") 313 314 list := engine.Routes() 315 316 assert.DeepEqual(t, 11, len(list)) 317 318 assertRoutePresent(t, list, RouteInfo{ 319 Method: "GET", 320 Path: "/", 321 Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest1", 322 }) 323 assertRoutePresent(t, list, RouteInfo{ 324 Method: "GET", 325 Path: "/user", 326 Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest2", 327 }) 328 assertRoutePresent(t, list, RouteInfo{ 329 Method: "GET", 330 Path: "/user/:name/*action", 331 Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest1", 332 }) 333 assertRoutePresent(t, list, RouteInfo{ 334 Method: "GET", 335 Path: "/v1/user", 336 Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest1", 337 }) 338 assertRoutePresent(t, list, RouteInfo{ 339 Method: "GET", 340 Path: "/static/*filepath", 341 Handler: "github.com/cloudwego/hertz/pkg/app.(*fsHandler).handleRequest-fm", 342 }) 343 assertRoutePresent(t, list, RouteInfo{ 344 Method: "GET", 345 Path: "/anonymous1", 346 Handler: "github.com/cloudwego/hertz/pkg/route.TestEngine_Routes.func1", 347 }) 348 assertRoutePresent(t, list, RouteInfo{ 349 Method: "POST", 350 Path: "/user", 351 Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest2", 352 }) 353 assertRoutePresent(t, list, RouteInfo{ 354 Method: "POST", 355 Path: "/user/:name/*action", 356 Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest2", 357 }) 358 assertRoutePresent(t, list, RouteInfo{ 359 Method: "POST", 360 Path: "/anonymous2", 361 Handler: "github.com/cloudwego/hertz/pkg/route.TestEngine_Routes.func2", 362 }) 363 assertRoutePresent(t, list, RouteInfo{ 364 Method: "POST", 365 Path: "/v1/login", 366 Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest2", 367 }) 368 assertRoutePresent(t, list, RouteInfo{ 369 Method: "HEAD", 370 Path: "/static/*filepath", 371 Handler: "github.com/cloudwego/hertz/pkg/app.(*fsHandler).handleRequest-fm", 372 }) 373 } 374 375 func handlerTest1(c context.Context, ctx *app.RequestContext) {} 376 377 func handlerTest2(c context.Context, ctx *app.RequestContext) {} 378 379 func assertRoutePresent(t *testing.T, gets RoutesInfo, want RouteInfo) { 380 for _, get := range gets { 381 if get.Path == want.Path && get.Method == want.Method && get.Handler == want.Handler { 382 return 383 } 384 } 385 386 t.Errorf("route not found: %v", want) 387 } 388 389 func TestGetNextProto(t *testing.T) { 390 e := NewEngine(config.NewOptions(nil)) 391 conn := &mockConn{} 392 proto, err := e.getNextProto(conn) 393 if proto != "h2" { 394 t.Errorf("unexpected proto: %#v, expected: %#v", proto, "h2") 395 } 396 if err != nil { 397 t.Errorf("unexpected error: %s", err.Error()) 398 } 399 } 400 401 func formatAsDate(t time.Time) string { 402 year, month, day := t.Date() 403 return fmt.Sprintf("%d/%02d/%02d", year, month, day) 404 } 405 406 func TestRenderHtml(t *testing.T) { 407 e := NewEngine(config.NewOptions(nil)) 408 e.Delims("{[{", "}]}") 409 e.SetFuncMap(template.FuncMap{ 410 "formatAsDate": formatAsDate, 411 }) 412 e.LoadHTMLGlob("../common/testdata/template/htmltemplate.html") 413 e.GET("/templateName", func(c context.Context, ctx *app.RequestContext) { 414 ctx.HTML(http.StatusOK, "htmltemplate.html", map[string]interface{}{ 415 "now": time.Date(2017, 0o7, 0o1, 0, 0, 0, 0, time.UTC), 416 }) 417 }) 418 rr := performRequest(e, "GET", "/templateName") 419 b, _ := ioutil.ReadAll(rr.Body) 420 assert.DeepEqual(t, consts.StatusOK, rr.Code) 421 assert.DeepEqual(t, []byte("<h1>Date: 2017/07/01</h1>"), b) 422 assert.DeepEqual(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type")) 423 } 424 425 func TestTransporterName(t *testing.T) { 426 SetTransporter(standard.NewTransporter) 427 assert.DeepEqual(t, "standard", GetTransporterName()) 428 429 SetTransporter(newMockTransporter) 430 assert.DeepEqual(t, "route", GetTransporterName()) 431 } 432 433 func newMockTransporter(options *config.Options) network.Transporter { 434 return &mockTransporter{} 435 } 436 437 type mockTransporter struct{} 438 439 func (m *mockTransporter) ListenAndServe(onData network.OnData) (err error) { 440 panic("implement me") 441 } 442 443 func (m *mockTransporter) Close() error { 444 panic("implement me") 445 } 446 447 func (m *mockTransporter) Shutdown(ctx context.Context) error { 448 panic("implement me") 449 } 450 451 func TestRenderHtmlOfGlobWithAutoRender(t *testing.T) { 452 opt := config.NewOptions([]config.Option{}) 453 opt.AutoReloadRender = true 454 e := NewEngine(opt) 455 e.Delims("{[{", "}]}") 456 e.SetFuncMap(template.FuncMap{ 457 "formatAsDate": formatAsDate, 458 }) 459 e.LoadHTMLGlob("../common/testdata/template/htmltemplate.html") 460 e.GET("/templateName", func(c context.Context, ctx *app.RequestContext) { 461 ctx.HTML(http.StatusOK, "htmltemplate.html", map[string]interface{}{ 462 "now": time.Date(2017, 0o7, 0o1, 0, 0, 0, 0, time.UTC), 463 }) 464 }) 465 rr := performRequest(e, "GET", "/templateName") 466 b, _ := ioutil.ReadAll(rr.Body) 467 assert.DeepEqual(t, consts.StatusOK, rr.Code) 468 assert.DeepEqual(t, []byte("<h1>Date: 2017/07/01</h1>"), b) 469 assert.DeepEqual(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type")) 470 } 471 472 func TestSetClientIPAndSetFormValue(t *testing.T) { 473 opt := config.NewOptions([]config.Option{}) 474 e := NewEngine(opt) 475 e.SetClientIPFunc(func(ctx *app.RequestContext) string { 476 return "1.1.1.1" 477 }) 478 e.SetFormValueFunc(func(requestContext *app.RequestContext, s string) []byte { 479 return []byte(s) 480 }) 481 e.GET("/ping", func(c context.Context, ctx *app.RequestContext) { 482 assert.DeepEqual(t, ctx.ClientIP(), "1.1.1.1") 483 assert.DeepEqual(t, string(ctx.FormValue("key")), "key") 484 }) 485 486 _ = performRequest(e, "GET", "/ping") 487 } 488 489 func TestRenderHtmlOfFilesWithAutoRender(t *testing.T) { 490 opt := config.NewOptions([]config.Option{}) 491 opt.AutoReloadRender = true 492 e := NewEngine(opt) 493 e.Delims("{[{", "}]}") 494 e.SetFuncMap(template.FuncMap{ 495 "formatAsDate": formatAsDate, 496 }) 497 e.LoadHTMLFiles("../common/testdata/template/htmltemplate.html") 498 e.GET("/templateName", func(c context.Context, ctx *app.RequestContext) { 499 ctx.HTML(http.StatusOK, "htmltemplate.html", map[string]interface{}{ 500 "now": time.Date(2017, 0o7, 0o1, 0, 0, 0, 0, time.UTC), 501 }) 502 }) 503 rr := performRequest(e, "GET", "/templateName") 504 b, _ := ioutil.ReadAll(rr.Body) 505 assert.DeepEqual(t, consts.StatusOK, rr.Code) 506 assert.DeepEqual(t, []byte("<h1>Date: 2017/07/01</h1>"), b) 507 assert.DeepEqual(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type")) 508 } 509 510 func TestSetEngineRun(t *testing.T) { 511 e := NewEngine(config.NewOptions(nil)) 512 e.Init() 513 assert.True(t, !e.IsRunning()) 514 e.MarkAsRunning() 515 assert.True(t, e.IsRunning()) 516 } 517 518 type mockConn struct{} 519 520 func (m *mockConn) SetWriteTimeout(t time.Duration) error { 521 // TODO implement me 522 panic("implement me") 523 } 524 525 func (m *mockConn) ReadBinary(n int) (p []byte, err error) { 526 panic("implement me") 527 } 528 529 func (m *mockConn) Handshake() error { 530 return nil 531 } 532 533 func (m *mockConn) ConnectionState() tls.ConnectionState { 534 return tls.ConnectionState{ 535 NegotiatedProtocol: "h2", 536 } 537 } 538 539 func (m *mockConn) SetReadTimeout(t time.Duration) error { 540 return nil 541 } 542 543 func (m *mockConn) Read(b []byte) (n int, err error) { 544 panic("implement me") 545 } 546 547 func (m *mockConn) Write(b []byte) (n int, err error) { 548 panic("implement me") 549 } 550 551 func (m *mockConn) Close() error { 552 panic("implement me") 553 } 554 555 func (m *mockConn) LocalAddr() net.Addr { 556 panic("implement me") 557 } 558 559 func (m *mockConn) RemoteAddr() net.Addr { 560 return &net.TCPAddr{ 561 IP: net.ParseIP("126.0.0.5"), 562 Port: 8888, 563 Zone: "", 564 } 565 } 566 567 func (m *mockConn) SetDeadline(t time.Time) error { 568 panic("implement me") 569 } 570 571 func (m *mockConn) SetReadDeadline(t time.Time) error { 572 panic("implement me") 573 } 574 575 func (m *mockConn) SetWriteDeadline(t time.Time) error { 576 panic("implement me") 577 } 578 579 func (m *mockConn) Release() error { 580 panic("implement me") 581 } 582 583 func (m *mockConn) Peek(i int) ([]byte, error) { 584 panic("implement me") 585 } 586 587 func (m *mockConn) Skip(n int) error { 588 panic("implement me") 589 } 590 591 func (m *mockConn) ReadByte() (byte, error) { 592 panic("implement me") 593 } 594 595 func (m *mockConn) Next(i int) ([]byte, error) { 596 panic("implement me") 597 } 598 599 func (m *mockConn) Len() int { 600 panic("implement me") 601 } 602 603 func (m *mockConn) Malloc(n int) (buf []byte, err error) { 604 panic("implement me") 605 } 606 607 func (m *mockConn) WriteBinary(b []byte) (n int, err error) { 608 panic("implement me") 609 } 610 611 func (m *mockConn) Flush() error { 612 panic("implement me") 613 } 614 615 type fakeTransporter struct{} 616 617 func (f *fakeTransporter) Close() error { 618 // TODO implement me 619 panic("implement me") 620 } 621 622 func (f *fakeTransporter) Shutdown(ctx context.Context) error { 623 // TODO implement me 624 panic("implement me") 625 } 626 627 func (f *fakeTransporter) ListenAndServe(onData network.OnData) error { 628 // TODO implement me 629 panic("implement me") 630 } 631 632 type mockBinder struct{} 633 634 func (m *mockBinder) Name() string { 635 return "test binder" 636 } 637 638 func (m *mockBinder) Bind(request *protocol.Request, i interface{}, params param.Params) error { 639 return nil 640 } 641 642 func (m *mockBinder) BindAndValidate(request *protocol.Request, i interface{}, params param.Params) error { 643 return nil 644 } 645 646 func (m *mockBinder) BindQuery(request *protocol.Request, i interface{}) error { 647 return nil 648 } 649 650 func (m *mockBinder) BindHeader(request *protocol.Request, i interface{}) error { 651 return nil 652 } 653 654 func (m *mockBinder) BindPath(request *protocol.Request, i interface{}, params param.Params) error { 655 return nil 656 } 657 658 func (m *mockBinder) BindForm(request *protocol.Request, i interface{}) error { 659 return nil 660 } 661 662 func (m *mockBinder) BindJSON(request *protocol.Request, i interface{}) error { 663 return nil 664 } 665 666 func (m *mockBinder) BindProtobuf(request *protocol.Request, i interface{}) error { 667 return nil 668 } 669 670 type mockValidator struct{} 671 672 func (m *mockValidator) ValidateStruct(interface{}) error { 673 return fmt.Errorf("test mock") 674 } 675 676 func (m *mockValidator) Engine() interface{} { 677 return nil 678 } 679 680 func (m *mockValidator) ValidateTag() string { 681 return "vd" 682 } 683 684 type mockNonValidator struct{} 685 686 func (m *mockNonValidator) ValidateStruct(interface{}) error { 687 return fmt.Errorf("test mock") 688 } 689 690 func TestInitBinderAndValidator(t *testing.T) { 691 defer func() { 692 if r := recover(); r != nil { 693 t.Errorf("unexpected panic, %v", r) 694 } 695 }() 696 opt := config.NewOptions([]config.Option{}) 697 bindConfig := binding.NewBindConfig() 698 bindConfig.LooseZeroMode = true 699 opt.BindConfig = bindConfig 700 binder := &mockBinder{} 701 opt.CustomBinder = binder 702 validator := &mockValidator{} 703 opt.CustomValidator = validator 704 NewEngine(opt) 705 validateConfig := binding.NewValidateConfig() 706 opt.ValidateConfig = validateConfig 707 opt.CustomValidator = nil 708 NewEngine(opt) 709 } 710 711 func TestInitBinderAndValidatorForPanic(t *testing.T) { 712 defer func() { 713 if r := recover(); r == nil { 714 t.Errorf("expect a panic, but get nil") 715 } 716 }() 717 opt := config.NewOptions([]config.Option{}) 718 bindConfig := binding.NewBindConfig() 719 bindConfig.LooseZeroMode = true 720 opt.BindConfig = bindConfig 721 binder := &mockBinder{} 722 opt.CustomBinder = binder 723 nonValidator := &mockNonValidator{} 724 opt.CustomValidator = nonValidator 725 NewEngine(opt) 726 } 727 728 func TestBindConfig(t *testing.T) { 729 type Req struct { 730 A int `query:"a"` 731 } 732 opt := config.NewOptions([]config.Option{}) 733 bindConfig := binding.NewBindConfig() 734 bindConfig.LooseZeroMode = false 735 opt.BindConfig = bindConfig 736 e := NewEngine(opt) 737 e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { 738 var req Req 739 err := ctx.BindAndValidate(&req) 740 if err == nil { 741 t.Fatal("expect an error") 742 } 743 }) 744 performRequest(e, "GET", "/bind?a=") 745 746 bindConfig = binding.NewBindConfig() 747 bindConfig.LooseZeroMode = true 748 opt.BindConfig = bindConfig 749 e = NewEngine(opt) 750 e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { 751 var req Req 752 err := ctx.BindAndValidate(&req) 753 if err != nil { 754 t.Fatal("unexpected error") 755 } 756 assert.DeepEqual(t, 0, req.A) 757 }) 758 performRequest(e, "GET", "/bind?a=") 759 } 760 761 type ValidateError struct { 762 ErrType, FailField, Msg string 763 } 764 765 // Error implements error interface. 766 func (e *ValidateError) Error() string { 767 if e.Msg != "" { 768 return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg 769 } 770 return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" 771 } 772 773 func TestValidateConfigSetErrorFactory(t *testing.T) { 774 type TestValidate struct { 775 B int `query:"b" vd:"$>100"` 776 } 777 opt := config.NewOptions([]config.Option{}) 778 CustomValidateErrFunc := func(failField, msg string) error { 779 err := ValidateError{ 780 ErrType: "validateErr", 781 FailField: "[validateFailField]: " + failField, 782 Msg: "[validateErrMsg]: " + msg, 783 } 784 785 return &err 786 } 787 788 validateConfig := binding.NewValidateConfig() 789 validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) 790 opt.ValidateConfig = validateConfig 791 e := NewEngine(opt) 792 e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { 793 var req TestValidate 794 err := ctx.BindAndValidate(&req) 795 if err == nil { 796 t.Fatal("expect an error") 797 } 798 assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) 799 }) 800 performRequest(e, "GET", "/bind?b=1") 801 } 802 803 func TestCustomBinder(t *testing.T) { 804 type Req struct { 805 A int `query:"a"` 806 } 807 opt := config.NewOptions([]config.Option{}) 808 opt.CustomBinder = &mockBinder{} 809 e := NewEngine(opt) 810 e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { 811 var req Req 812 err := ctx.BindAndValidate(&req) 813 if err != nil { 814 t.Fatal("unexpected error") 815 } 816 assert.NotEqual(t, 2, req.A) 817 }) 818 performRequest(e, "GET", "/bind?a=2") 819 } 820 821 func TestValidateRegValidateFunc(t *testing.T) { 822 type Req struct { 823 A int `query:"a" vd:"f($)"` 824 } 825 opt := config.NewOptions([]config.Option{}) 826 validateConfig := &binding.ValidateConfig{} 827 validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { 828 return fmt.Errorf("test error") 829 }) 830 e := NewEngine(opt) 831 e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { 832 var req Req 833 err := ctx.BindAndValidate(&req) 834 assert.NotNil(t, err) 835 assert.DeepEqual(t, "test error", err.Error()) 836 }) 837 performRequest(e, "GET", "/validate?a=2") 838 } 839 840 func TestCustomValidator(t *testing.T) { 841 type Req struct { 842 A int `query:"a" vd:"d($)"` 843 } 844 opt := config.NewOptions([]config.Option{}) 845 validateConfig := &binding.ValidateConfig{} 846 validateConfig.MustRegValidateFunc("d", func(args ...interface{}) error { 847 return fmt.Errorf("test error") 848 }) 849 opt.CustomValidator = &mockValidator{} 850 e := NewEngine(opt) 851 e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { 852 var req Req 853 err := ctx.BindAndValidate(&req) 854 assert.NotNil(t, err) 855 assert.DeepEqual(t, "test mock", err.Error()) 856 }) 857 performRequest(e, "GET", "/validate?a=2") 858 } 859 860 var errTestDeregsitry = fmt.Errorf("test deregsitry error") 861 862 type mockDeregsitryErr struct{} 863 864 var _ registry.Registry = &mockDeregsitryErr{} 865 866 func (e mockDeregsitryErr) Register(*registry.Info) error { 867 return nil 868 } 869 870 func (e mockDeregsitryErr) Deregister(*registry.Info) error { 871 return errTestDeregsitry 872 } 873 874 func TestEngineShutdown(t *testing.T) { 875 defaultTransporter = standard.NewTransporter 876 mockCtxCallback := func(ctx context.Context) {} 877 // Test case 1: serve not running error 878 opt := config.NewOptions(nil) 879 opt.Addr = "127.0.0.1:10027" 880 engine := NewEngine(opt) 881 ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) 882 defer cancel1() 883 err := engine.Shutdown(ctx1) 884 assert.DeepEqual(t, errStatusNotRunning, err) 885 886 // Test case 2: serve successfully running and shutdown 887 engine = NewEngine(opt) 888 engine.OnShutdown = []CtxCallback{mockCtxCallback} 889 go func() { 890 engine.Run() 891 }() 892 // wait for engine to start 893 time.Sleep(1 * time.Second) 894 895 ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) 896 defer cancel2() 897 err = engine.Shutdown(ctx2) 898 assert.Nil(t, err) 899 assert.DeepEqual(t, statusClosed, atomic.LoadUint32(&engine.status)) 900 901 // Test case 3: serve successfully running and shutdown with deregistry error 902 engine = NewEngine(opt) 903 engine.OnShutdown = []CtxCallback{mockCtxCallback} 904 engine.options.Registry = &mockDeregsitryErr{} 905 go func() { 906 engine.Run() 907 }() 908 // wait for engine to start 909 time.Sleep(1 * time.Second) 910 911 ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second) 912 defer cancel3() 913 err = engine.Shutdown(ctx3) 914 assert.DeepEqual(t, errTestDeregsitry, err) 915 assert.DeepEqual(t, statusShutdown, atomic.LoadUint32(&engine.status)) 916 } 917 918 type mockStreamer struct{} 919 920 type mockProtocolServer struct{} 921 922 func (s *mockStreamer) Serve(c context.Context, conn network.StreamConn) error { 923 return nil 924 } 925 926 func (s *mockProtocolServer) Serve(c context.Context, conn network.Conn) error { 927 return nil 928 } 929 930 type mockStreamConn struct { 931 network.StreamConn 932 version string 933 } 934 935 var _ network.StreamConn = &mockStreamConn{} 936 937 func (m *mockStreamConn) GetVersion() uint32 { 938 return network.Version1 939 } 940 941 func TestEngineServeStream(t *testing.T) { 942 engine := &Engine{ 943 options: &config.Options{ 944 ALPN: true, 945 TLS: &tls.Config{}, 946 }, 947 protocolStreamServers: map[string]protocol.StreamServer{ 948 suite.HTTP3: &mockStreamer{}, 949 }, 950 } 951 952 // Test ALPN path 953 conn := &mockStreamConn{version: suite.HTTP3} 954 err := engine.ServeStream(context.Background(), conn) 955 assert.Nil(t, err) 956 957 // Test default path 958 engine.options.ALPN = false 959 conn = &mockStreamConn{} 960 err = engine.ServeStream(context.Background(), conn) 961 assert.Nil(t, err) 962 963 // Test unsupported protocol 964 engine.protocolStreamServers = map[string]protocol.StreamServer{} 965 conn = &mockStreamConn{} 966 err = engine.ServeStream(context.Background(), conn) 967 assert.DeepEqual(t, errs.ErrNotSupportProtocol, err) 968 } 969 970 func TestEngineServe(t *testing.T) { 971 engine := NewEngine(config.NewOptions(nil)) 972 engine.protocolServers[suite.HTTP1] = &mockProtocolServer{} 973 engine.protocolServers[suite.HTTP2] = &mockProtocolServer{} 974 975 // test H2C path 976 ctx := context.Background() 977 conn := mock.NewConn("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") 978 engine.options.H2C = true 979 err := engine.Serve(ctx, conn) 980 assert.Nil(t, err) 981 982 // test ALPN path 983 ctx = context.Background() 984 conn = mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") 985 engine.options.H2C = false 986 engine.options.ALPN = true 987 engine.options.TLS = &tls.Config{} 988 err = engine.Serve(ctx, conn) 989 assert.Nil(t, err) 990 991 // test HTTP1 path 992 engine.options.ALPN = false 993 err = engine.Serve(ctx, conn) 994 assert.Nil(t, err) 995 } 996 997 func TestOndata(t *testing.T) { 998 ctx := context.Background() 999 engine := NewEngine(config.NewOptions(nil)) 1000 1001 // test stream conn 1002 streamConn := &mockStreamConn{version: suite.HTTP3} 1003 engine.protocolStreamServers[suite.HTTP3] = &mockStreamer{} 1004 err := engine.onData(ctx, streamConn) 1005 assert.Nil(t, err) 1006 1007 // test conn 1008 conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") 1009 engine.protocolServers[suite.HTTP1] = &mockProtocolServer{} 1010 err = engine.onData(ctx, conn) 1011 assert.Nil(t, err) 1012 } 1013 1014 func TestAcquireHijackConn(t *testing.T) { 1015 engine := &Engine{ 1016 NoHijackConnPool: false, 1017 } 1018 // test conn pool 1019 conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") 1020 hijackConn := engine.acquireHijackConn(conn) 1021 assert.NotNil(t, hijackConn) 1022 assert.NotNil(t, hijackConn.Conn) 1023 assert.DeepEqual(t, engine, hijackConn.e) 1024 assert.DeepEqual(t, conn, hijackConn.Conn) 1025 1026 // test no conn pool 1027 engine.NoHijackConnPool = true 1028 hijackConn = engine.acquireHijackConn(conn) 1029 assert.NotNil(t, hijackConn) 1030 assert.NotNil(t, hijackConn.Conn) 1031 assert.DeepEqual(t, engine, hijackConn.e) 1032 assert.DeepEqual(t, conn, hijackConn.Conn) 1033 } 1034 1035 func TestHandleParamsReassignInHandleFunc(t *testing.T) { 1036 e := NewEngine(config.NewOptions(nil)) 1037 routes := []string{ 1038 "/:a/:b/:c", 1039 } 1040 for _, r := range routes { 1041 e.GET(r, func(c context.Context, ctx *app.RequestContext) { 1042 ctx.Params = make([]param.Param, 1) 1043 ctx.String(consts.StatusOK, "") 1044 }) 1045 } 1046 testRoutes := []string{ 1047 "/aaa/bbb/ccc", 1048 "/asd/alskja/alkdjad", 1049 "/asd/alskja/alkdjad", 1050 "/asd/alskja/alkdjad", 1051 "/asd/alskja/alkdjad", 1052 "/alksjdlakjd/ooo/askda", 1053 "/alksjdlakjd/ooo/askda", 1054 "/alksjdlakjd/ooo/askda", 1055 } 1056 ctx := e.ctxPool.Get().(*app.RequestContext) 1057 for _, tr := range testRoutes { 1058 r := protocol.NewRequest(http.MethodGet, tr, nil) 1059 r.CopyTo(&ctx.Request) 1060 e.ServeHTTP(context.Background(), ctx) 1061 ctx.ResetWithoutConn() 1062 } 1063 }