github.com/cloudwego/hertz@v0.9.3/pkg/app/server/hertz_test.go (about) 1 /* 2 * Copyright 2022 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package server 18 19 import ( 20 "bytes" 21 "context" 22 "errors" 23 "fmt" 24 "html/template" 25 "io" 26 "io/ioutil" 27 "net" 28 "net/http" 29 "strings" 30 "sync" 31 "sync/atomic" 32 "testing" 33 "time" 34 35 "github.com/cloudwego/hertz/pkg/app" 36 c "github.com/cloudwego/hertz/pkg/app/client" 37 "github.com/cloudwego/hertz/pkg/app/server/binding" 38 "github.com/cloudwego/hertz/pkg/app/server/registry" 39 "github.com/cloudwego/hertz/pkg/common/config" 40 errs "github.com/cloudwego/hertz/pkg/common/errors" 41 "github.com/cloudwego/hertz/pkg/common/hlog" 42 "github.com/cloudwego/hertz/pkg/common/test/assert" 43 "github.com/cloudwego/hertz/pkg/common/test/mock" 44 "github.com/cloudwego/hertz/pkg/common/utils" 45 "github.com/cloudwego/hertz/pkg/network" 46 "github.com/cloudwego/hertz/pkg/network/standard" 47 "github.com/cloudwego/hertz/pkg/protocol" 48 "github.com/cloudwego/hertz/pkg/protocol/consts" 49 "github.com/cloudwego/hertz/pkg/protocol/http1/req" 50 "github.com/cloudwego/hertz/pkg/protocol/http1/resp" 51 "github.com/cloudwego/hertz/pkg/route/param" 52 ) 53 54 func TestHertz_Run(t *testing.T) { 55 hertz := Default(WithHostPorts("127.0.0.1:6666")) 56 hertz.GET("/test", func(c context.Context, ctx *app.RequestContext) { 57 time.Sleep(time.Second) 58 path := ctx.Request.URI().PathOriginal() 59 ctx.SetBodyString(string(path)) 60 }) 61 62 testint := uint32(0) 63 hertz.Engine.OnShutdown = append(hertz.OnShutdown, func(ctx context.Context) { 64 atomic.StoreUint32(&testint, 1) 65 }) 66 67 assert.Assert(t, len(hertz.Handlers) == 1) 68 69 go hertz.Spin() 70 time.Sleep(100 * time.Millisecond) 71 72 hertz.Close() 73 resp, err := http.Get("http://127.0.0.1:6666/test") 74 assert.NotNil(t, err) 75 assert.Nil(t, resp) 76 assert.DeepEqual(t, uint32(0), atomic.LoadUint32(&testint)) 77 } 78 79 func TestHertz_GracefulShutdown(t *testing.T) { 80 engine := New(WithHostPorts("127.0.0.1:6667")) 81 engine.GET("/test", func(c context.Context, ctx *app.RequestContext) { 82 time.Sleep(time.Second * 2) 83 path := ctx.Request.URI().PathOriginal() 84 ctx.SetBodyString(string(path)) 85 }) 86 engine.GET("/test2", func(c context.Context, ctx *app.RequestContext) {}) 87 88 testint := uint32(0) 89 testint2 := uint32(0) 90 testint3 := uint32(0) 91 engine.Engine.OnShutdown = append(engine.OnShutdown, func(ctx context.Context) { 92 atomic.StoreUint32(&testint, 1) 93 }) 94 engine.Engine.OnShutdown = append(engine.OnShutdown, func(ctx context.Context) { 95 atomic.StoreUint32(&testint2, 2) 96 }) 97 engine.Engine.OnShutdown = append(engine.OnShutdown, func(ctx context.Context) { 98 time.Sleep(2 * time.Second) 99 atomic.StoreUint32(&testint3, 3) 100 }) 101 102 go engine.Spin() 103 time.Sleep(time.Millisecond) 104 105 hc := http.Client{Timeout: time.Second} 106 var err error 107 var resp *http.Response 108 ch := make(chan struct{}) 109 ch2 := make(chan struct{}) 110 go func() { 111 ticker := time.NewTicker(time.Millisecond * 100) 112 defer ticker.Stop() 113 for range ticker.C { 114 t.Logf("[%v]begin listening\n", time.Now()) 115 _, err2 := hc.Get("http://127.0.0.1:6667/test2") 116 if err2 != nil { 117 t.Logf("[%v]listening closed: %v", time.Now(), err2) 118 ch2 <- struct{}{} 119 break 120 } 121 } 122 }() 123 go func() { 124 t.Logf("[%v]begin request\n", time.Now()) 125 resp, err = http.Get("http://127.0.0.1:6667/test") 126 t.Logf("[%v]end request\n", time.Now()) 127 ch <- struct{}{} 128 }() 129 130 time.Sleep(time.Second * 1) 131 start := time.Now() 132 ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) 133 t.Logf("[%v]begin shutdown\n", start) 134 engine.Shutdown(ctx) 135 end := time.Now() 136 t.Logf("[%v]end shutdown\n", end) 137 138 <-ch 139 assert.Nil(t, err) 140 assert.NotNil(t, resp) 141 assert.DeepEqual(t, true, resp.Close) 142 assert.DeepEqual(t, uint32(1), atomic.LoadUint32(&testint)) 143 assert.DeepEqual(t, uint32(2), atomic.LoadUint32(&testint2)) 144 assert.DeepEqual(t, uint32(3), atomic.LoadUint32(&testint3)) 145 146 <-ch2 147 148 cancel() 149 } 150 151 func TestLoadHTMLGlob(t *testing.T) { 152 engine := New(WithMaxRequestBodySize(15), WithHostPorts("127.0.0.1:8893")) 153 engine.Delims("{[{", "}]}") 154 engine.LoadHTMLGlob("../../common/testdata/template/index.tmpl") 155 engine.GET("/index", func(c context.Context, ctx *app.RequestContext) { 156 ctx.HTML(consts.StatusOK, "index.tmpl", utils.H{ 157 "title": "Main website", 158 }) 159 }) 160 go engine.Run() 161 defer func() { 162 engine.Close() 163 }() 164 time.Sleep(1 * time.Second) 165 resp, _ := http.Get("http://127.0.0.1:8893/index") 166 assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) 167 b := make([]byte, 100) 168 n, _ := resp.Body.Read(b) 169 const expected = `<html><h1>Main website</h1></html>` 170 171 assert.DeepEqual(t, expected, string(b[0:n])) 172 } 173 174 func TestLoadHTMLFiles(t *testing.T) { 175 engine := New(WithMaxRequestBodySize(15), WithHostPorts("127.0.0.1:8891")) 176 engine.Delims("{[{", "}]}") 177 engine.SetFuncMap(template.FuncMap{ 178 "formatAsDate": formatAsDate, 179 }) 180 engine.LoadHTMLFiles("../../common/testdata/template/htmltemplate.html", "../../common/testdata/template/index.tmpl") 181 182 engine.GET("/raw", func(c context.Context, ctx *app.RequestContext) { 183 ctx.HTML(consts.StatusOK, "htmltemplate.html", map[string]interface{}{ 184 "now": time.Date(2017, 0o7, 0o1, 0, 0, 0, 0, time.UTC), 185 }) 186 }) 187 go engine.Run() 188 defer func() { 189 engine.Close() 190 }() 191 time.Sleep(1 * time.Second) 192 resp, _ := http.Get("http://127.0.0.1:8891/raw") 193 assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) 194 b := make([]byte, 100) 195 n, _ := resp.Body.Read(b) 196 assert.DeepEqual(t, "<h1>Date: 2017/07/01</h1>", string(b[0:n])) 197 } 198 199 func formatAsDate(t time.Time) string { 200 year, month, day := t.Date() 201 return fmt.Sprintf("%d/%02d/%02d", year, month, day) 202 } 203 204 // copied from router 205 var ( 206 default400Body = []byte("400 bad request") 207 requiredHostBody = []byte("missing required Host header") 208 ) 209 210 func TestServer_Use(t *testing.T) { 211 router := New() 212 router.Use(func(c context.Context, ctx *app.RequestContext) {}) 213 assert.DeepEqual(t, 1, len(router.Handlers)) 214 router.Use(func(c context.Context, ctx *app.RequestContext) {}) 215 assert.DeepEqual(t, 2, len(router.Handlers)) 216 } 217 218 func Test_getServerName(t *testing.T) { 219 engine := New() 220 assert.DeepEqual(t, []byte("hertz"), engine.GetServerName()) 221 ss := New() 222 ss.Name = "test_name" 223 assert.DeepEqual(t, []byte("test_name"), ss.GetServerName()) 224 } 225 226 func TestServer_Run(t *testing.T) { 227 hertz := New(WithHostPorts("127.0.0.1:8899")) 228 hertz.GET("/test", func(c context.Context, ctx *app.RequestContext) { 229 path := ctx.Request.URI().PathOriginal() 230 ctx.SetBodyString(string(path)) 231 }) 232 hertz.POST("/redirect", func(c context.Context, ctx *app.RequestContext) { 233 ctx.Redirect(consts.StatusMovedPermanently, []byte("http://127.0.0.1:8899/test")) 234 }) 235 go hertz.Run() 236 time.Sleep(1 * time.Second) 237 resp, err := http.Get("http://127.0.0.1:8899/test") 238 assert.Nil(t, err) 239 assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) 240 b := make([]byte, 5) 241 resp.Body.Read(b) 242 assert.DeepEqual(t, "/test", string(b)) 243 244 resp, err = http.Get("http://127.0.0.1:8899/foo") 245 assert.Nil(t, err) 246 assert.DeepEqual(t, consts.StatusNotFound, resp.StatusCode) 247 248 resp, err = http.Post("http://127.0.0.1:8899/redirect", "", nil) 249 assert.Nil(t, err) 250 assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) 251 b = make([]byte, 5) 252 resp.Body.Read(b) 253 assert.DeepEqual(t, "/test", string(b)) 254 255 ctx, cancel := context.WithTimeout(context.Background(), 0) 256 defer cancel() 257 _ = hertz.Shutdown(ctx) 258 } 259 260 func TestNotAbsolutePath(t *testing.T) { 261 engine := New(WithHostPorts("127.0.0.1:9990")) 262 engine.POST("/", func(c context.Context, ctx *app.RequestContext) { 263 ctx.Write(ctx.Request.Body()) 264 }) 265 engine.POST("/a", func(c context.Context, ctx *app.RequestContext) { 266 ctx.Write(ctx.Request.Body()) 267 }) 268 go engine.Run() 269 defer func() { 270 engine.Close() 271 }() 272 time.Sleep(1 * time.Second) 273 274 s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" 275 zr := mock.NewZeroCopyReader(s) 276 277 ctx := app.NewContext(0) 278 if err := req.Read(&ctx.Request, zr); err != nil { 279 t.Fatalf("unexpected error: %s", err) 280 } 281 engine.ServeHTTP(context.Background(), ctx) 282 assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) 283 assert.DeepEqual(t, ctx.Request.Body(), ctx.Response.Body()) 284 285 s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" 286 zr = mock.NewZeroCopyReader(s) 287 288 ctx = app.NewContext(0) 289 if err := req.Read(&ctx.Request, zr); err != nil { 290 t.Fatalf("unexpected error: %s", err) 291 } 292 engine.ServeHTTP(context.Background(), ctx) 293 assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) 294 assert.DeepEqual(t, ctx.Request.Body(), ctx.Response.Body()) 295 } 296 297 func TestNotAbsolutePathWithRawPath(t *testing.T) { 298 engine := New(WithHostPorts("127.0.0.1:9991"), WithUseRawPath(true)) 299 const ( 300 MiddlewareKey = "middleware_key" 301 MiddlewareValue = "middleware_value" 302 ) 303 engine.Use(func(c context.Context, ctx *app.RequestContext) { 304 ctx.Response.Header.Set(MiddlewareKey, MiddlewareValue) 305 }) 306 engine.POST("/", func(c context.Context, ctx *app.RequestContext) { 307 }) 308 engine.POST("/a", func(c context.Context, ctx *app.RequestContext) { 309 }) 310 go engine.Run() 311 defer func() { 312 engine.Close() 313 }() 314 time.Sleep(1 * time.Second) 315 316 s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" 317 zr := mock.NewZeroCopyReader(s) 318 319 ctx := app.NewContext(0) 320 if err := req.Read(&ctx.Request, zr); err != nil { 321 t.Fatalf("unexpected error: %s", err) 322 } 323 engine.ServeHTTP(context.Background(), ctx) 324 assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) 325 assert.DeepEqual(t, default400Body, ctx.Response.Body()) 326 gh := ctx.Response.Header.Get(MiddlewareKey) 327 assert.DeepEqual(t, MiddlewareValue, gh) 328 329 s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" 330 zr = mock.NewZeroCopyReader(s) 331 332 ctx = app.NewContext(0) 333 if err := req.Read(&ctx.Request, zr); err != nil { 334 t.Fatalf("unexpected error: %s", err) 335 } 336 engine.ServeHTTP(context.Background(), ctx) 337 assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) 338 assert.DeepEqual(t, default400Body, ctx.Response.Body()) 339 gh = ctx.Response.Header.Get(MiddlewareKey) 340 assert.DeepEqual(t, MiddlewareValue, gh) 341 } 342 343 func TestNotValidHost(t *testing.T) { 344 engine := New(WithHostPorts("127.0.0.1:9992")) 345 const ( 346 MiddlewareKey = "middleware_key" 347 MiddlewareValue = "middleware_value" 348 ) 349 engine.Use(func(c context.Context, ctx *app.RequestContext) { 350 ctx.Response.Header.Set(MiddlewareKey, MiddlewareValue) 351 }) 352 engine.POST("/", func(c context.Context, ctx *app.RequestContext) { 353 }) 354 engine.POST("/a", func(c context.Context, ctx *app.RequestContext) { 355 }) 356 357 s := "POST ?a=b HTTP/1.1\r\nHost: \r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" 358 zr := mock.NewZeroCopyReader(s) 359 360 ctx := app.NewContext(0) 361 if err := req.Read(&ctx.Request, zr); err != nil { 362 t.Fatalf("unexpected error: %s", err) 363 } 364 engine.ServeHTTP(context.Background(), ctx) 365 assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) 366 assert.DeepEqual(t, requiredHostBody, ctx.Response.Body()) 367 gh := ctx.Response.Header.Get(MiddlewareKey) 368 assert.DeepEqual(t, MiddlewareValue, gh) 369 370 s = "POST a?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" 371 zr = mock.NewZeroCopyReader(s) 372 373 ctx = app.NewContext(0) 374 if err := req.Read(&ctx.Request, zr); err != nil { 375 t.Fatalf("unexpected error: %s", err) 376 } 377 engine.ServeHTTP(context.Background(), ctx) 378 assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) 379 assert.DeepEqual(t, requiredHostBody, ctx.Response.Body()) 380 gh = ctx.Response.Header.Get(MiddlewareKey) 381 assert.DeepEqual(t, MiddlewareValue, gh) 382 } 383 384 func TestWithBasePath(t *testing.T) { 385 engine := New(WithBasePath("/hertz"), WithHostPorts("127.0.0.1:19898")) 386 engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { 387 }) 388 go engine.Run() 389 defer func() { 390 engine.Close() 391 }() 392 time.Sleep(1 * time.Second) 393 var r http.Request 394 r.ParseForm() 395 r.Form.Add("xxxxxx", "xxx") 396 body := strings.NewReader(r.Form.Encode()) 397 resp, err := http.Post("http://127.0.0.1:19898/hertz/test", "application/x-www-form-urlencoded", body) 398 assert.Nil(t, err) 399 assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) 400 } 401 402 func TestNotEnoughBodySize(t *testing.T) { 403 engine := New(WithMaxRequestBodySize(5), WithHostPorts("127.0.0.1:8889")) 404 engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { 405 }) 406 go engine.Run() 407 defer func() { 408 engine.Close() 409 }() 410 time.Sleep(1 * time.Second) 411 var r http.Request 412 r.ParseForm() 413 r.Form.Add("xxxxxx", "xxx") 414 body := strings.NewReader(r.Form.Encode()) 415 resp, err := http.Post("http://127.0.0.1:8889/test", "application/x-www-form-urlencoded", body) 416 assert.Nil(t, err) 417 assert.DeepEqual(t, 413, resp.StatusCode) 418 bodyBytes, _ := ioutil.ReadAll(resp.Body) 419 assert.DeepEqual(t, "Request Entity Too Large", string(bodyBytes)) 420 } 421 422 func TestEnoughBodySize(t *testing.T) { 423 engine := New(WithMaxRequestBodySize(15), WithHostPorts("127.0.0.1:8892")) 424 engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { 425 }) 426 go engine.Run() 427 defer func() { 428 engine.Close() 429 }() 430 time.Sleep(1 * time.Second) 431 var r http.Request 432 r.ParseForm() 433 r.Form.Add("xxxxxx", "xxx") 434 body := strings.NewReader(r.Form.Encode()) 435 resp, _ := http.Post("http://127.0.0.1:8892/test", "application/x-www-form-urlencoded", body) 436 assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) 437 } 438 439 func TestRequestCtxHijack(t *testing.T) { 440 hijackStartCh := make(chan struct{}) 441 hijackStopCh := make(chan struct{}) 442 engine := New() 443 engine.Init() 444 445 engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { 446 if ctx.Hijacked() { 447 t.Error("connection mustn't be hijacked") 448 } 449 ctx.Hijack(func(c network.Conn) { 450 <-hijackStartCh 451 452 b := make([]byte, 1) 453 // ping-pong echo via hijacked conn 454 for { 455 n, err := c.Read(b) 456 if n != 1 { 457 if err == io.EOF { 458 close(hijackStopCh) 459 return 460 } 461 if err != nil { 462 t.Errorf("unexpected error: %s", err) 463 } 464 t.Errorf("unexpected number of bytes read: %d. Expecting 1", n) 465 } 466 if _, err = c.Write(b); err != nil { 467 t.Errorf("unexpected error when writing data: %s", err) 468 } 469 } 470 }) 471 if !ctx.Hijacked() { 472 t.Error("connection must be hijacked") 473 } 474 ctx.Data(consts.StatusOK, "foo/bar", []byte("hijack it!")) 475 }) 476 477 hijackedString := "foobar baz hijacked!!!" 478 479 c := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n" + hijackedString) 480 481 ch := make(chan error) 482 go func() { 483 ch <- engine.Serve(context.Background(), c) 484 }() 485 486 time.Sleep(100 * time.Millisecond) 487 488 close(hijackStartCh) 489 490 if err := <-ch; err != nil { 491 if !errors.Is(err, errs.ErrHijacked) { 492 t.Fatalf("Unexpected error from serveConn: %s", err) 493 } 494 } 495 verifyResponse(t, c.WriterRecorder(), consts.StatusOK, "foo/bar", "hijack it!") 496 497 select { 498 case <-hijackStopCh: 499 case <-time.After(100 * time.Millisecond): 500 t.Fatal("timeout") 501 } 502 503 zw := c.WriterRecorder() 504 data, err := zw.ReadBinary(zw.Len()) 505 if err != nil { 506 t.Fatalf("Unexpected error when reading remaining data: %s", err) 507 } 508 if string(data) != hijackedString { 509 t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, hijackedString) 510 } 511 } 512 513 func verifyResponse(t *testing.T, zr network.Reader, expectedStatusCode int, expectedContentType, expectedBody string) { 514 var r protocol.Response 515 if err := resp.Read(&r, zr); err != nil { 516 t.Fatalf("Unexpected error when parsing response: %s", err) 517 } 518 519 if !bytes.Equal(r.Body(), []byte(expectedBody)) { 520 t.Fatalf("Unexpected body %q. Expected %q", r.Body(), []byte(expectedBody)) 521 } 522 verifyResponseHeader(t, &r.Header, expectedStatusCode, len(r.Body()), expectedContentType, "") 523 } 524 525 func verifyResponseHeader(t *testing.T, h *protocol.ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType, expectedContentEncoding string) { 526 if h.StatusCode() != expectedStatusCode { 527 t.Fatalf("Unexpected status code %d. Expected %d", h.StatusCode(), expectedStatusCode) 528 } 529 if h.ContentLength() != expectedContentLength { 530 t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength(), expectedContentLength) 531 } 532 if string(h.ContentType()) != expectedContentType { 533 t.Fatalf("Unexpected content type %q. Expected %q", h.ContentType(), expectedContentType) 534 } 535 if string(h.ContentEncoding()) != expectedContentEncoding { 536 t.Fatalf("Unexpected content encoding %q. Expected %q", h.ContentEncoding(), expectedContentEncoding) 537 } 538 } 539 540 func TestParamInconsist(t *testing.T) { 541 mapS := sync.Map{} 542 h := New(WithHostPorts("localhost:10091")) 543 h.GET("/:label", func(c context.Context, ctx *app.RequestContext) { 544 label := ctx.Param("label") 545 x, _ := mapS.LoadOrStore(label, label) 546 labelString := x.(string) 547 if label != labelString { 548 t.Errorf("unexpected label: %s, expected return label: %s", label, labelString) 549 } 550 }) 551 go h.Run() 552 time.Sleep(time.Millisecond * 50) 553 client, _ := c.NewClient() 554 wg := sync.WaitGroup{} 555 tr := func() { 556 defer wg.Done() 557 for i := 0; i < 5000; i++ { 558 client.Get(context.Background(), nil, "http://localhost:10091/test1") 559 } 560 } 561 ti := func() { 562 defer wg.Done() 563 for i := 0; i < 5000; i++ { 564 client.Get(context.Background(), nil, "http://localhost:10091/test2") 565 } 566 } 567 568 for i := 0; i < 30; i++ { 569 go tr() 570 go ti() 571 wg.Add(2) 572 } 573 wg.Wait() 574 } 575 576 func TestDuplicateReleaseBodyStream(t *testing.T) { 577 h := New(WithStreamBody(true), WithHostPorts("localhost:10092")) 578 h.POST("/test", func(ctx context.Context, c *app.RequestContext) { 579 stream := c.RequestBodyStream() 580 c.Response.SetBodyStream(stream, -1) 581 }) 582 go h.Spin() 583 time.Sleep(time.Second) 584 client, _ := c.NewClient(c.WithMaxConnsPerHost(1000000), c.WithDialTimeout(time.Minute)) 585 bodyBytes := make([]byte, 102388) 586 index := 0 587 letterBytes := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 588 for i := 0; i < 102388; i++ { 589 bodyBytes[i] = letterBytes[index] 590 if i%1969 == 0 && i != 0 { 591 index = index + 1 592 } 593 } 594 body := string(bodyBytes) 595 596 wg := sync.WaitGroup{} 597 testFunc := func() { 598 defer wg.Done() 599 r := protocol.NewRequest("POST", "http://localhost:10092/test", nil) 600 r.SetBodyString(body) 601 resp := protocol.AcquireResponse() 602 err := client.Do(context.Background(), r, resp) 603 if err != nil { 604 t.Errorf("unexpected error: %s", err.Error()) 605 } 606 if body != string(resp.Body()) { 607 t.Errorf("unequal body") 608 } 609 } 610 611 for i := 0; i < 10; i++ { 612 wg.Add(1) 613 go testFunc() 614 } 615 wg.Wait() 616 } 617 618 func TestServiceRegisterFailed(t *testing.T) { 619 mockRegErr := errors.New("mock register error") 620 var rCount int32 621 var drCount int32 622 mockRegistry := MockRegistry{ 623 RegisterFunc: func(info *registry.Info) error { 624 atomic.AddInt32(&rCount, 1) 625 return mockRegErr 626 }, 627 DeregisterFunc: func(info *registry.Info) error { 628 atomic.AddInt32(&drCount, 1) 629 return nil 630 }, 631 } 632 var opts []config.Option 633 opts = append(opts, WithRegistry(mockRegistry, nil)) 634 opts = append(opts, WithHostPorts("127.0.0.1:9222")) 635 srv := New(opts...) 636 srv.Spin() 637 time.Sleep(2 * time.Second) 638 assert.Assert(t, atomic.LoadInt32(&rCount) == 1) 639 } 640 641 func TestServiceDeregisterFailed(t *testing.T) { 642 mockDeregErr := errors.New("mock deregister error") 643 var rCount int32 644 var drCount int32 645 mockRegistry := MockRegistry{ 646 RegisterFunc: func(info *registry.Info) error { 647 atomic.AddInt32(&rCount, 1) 648 return nil 649 }, 650 DeregisterFunc: func(info *registry.Info) error { 651 atomic.AddInt32(&drCount, 1) 652 return mockDeregErr 653 }, 654 } 655 var opts []config.Option 656 opts = append(opts, WithRegistry(mockRegistry, nil)) 657 opts = append(opts, WithHostPorts("127.0.0.1:9223")) 658 srv := New(opts...) 659 go srv.Spin() 660 time.Sleep(1 * time.Second) 661 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) 662 defer cancel() 663 _ = srv.Shutdown(ctx) 664 time.Sleep(1 * time.Second) 665 assert.Assert(t, atomic.LoadInt32(&rCount) == 1) 666 assert.Assert(t, atomic.LoadInt32(&drCount) == 1) 667 } 668 669 func TestServiceRegistryInfo(t *testing.T) { 670 registryInfo := ®istry.Info{ 671 Weight: 100, 672 Tags: map[string]string{"aa": "bb"}, 673 ServiceName: "hertz.api.test", 674 } 675 checkInfo := func(info *registry.Info) { 676 assert.Assert(t, info.Weight == registryInfo.Weight) 677 assert.Assert(t, info.ServiceName == "hertz.api.test") 678 assert.Assert(t, len(info.Tags) == len(registryInfo.Tags), info.Tags) 679 assert.Assert(t, info.Tags["aa"] == registryInfo.Tags["aa"], info.Tags) 680 } 681 var rCount int32 682 var drCount int32 683 mockRegistry := MockRegistry{ 684 RegisterFunc: func(info *registry.Info) error { 685 checkInfo(info) 686 atomic.AddInt32(&rCount, 1) 687 return nil 688 }, 689 DeregisterFunc: func(info *registry.Info) error { 690 checkInfo(info) 691 atomic.AddInt32(&drCount, 1) 692 return nil 693 }, 694 } 695 var opts []config.Option 696 opts = append(opts, WithRegistry(mockRegistry, registryInfo)) 697 opts = append(opts, WithHostPorts("127.0.0.1:9225")) 698 srv := New(opts...) 699 go srv.Spin() 700 time.Sleep(2 * time.Second) 701 ctx, cancel := context.WithTimeout(context.Background(), 0) 702 defer cancel() 703 _ = srv.Shutdown(ctx) 704 time.Sleep(2 * time.Second) 705 assert.Assert(t, atomic.LoadInt32(&rCount) == 1) 706 assert.Assert(t, atomic.LoadInt32(&drCount) == 1) 707 } 708 709 func TestServiceRegistryNoInitInfo(t *testing.T) { 710 checkInfo := func(info *registry.Info) { 711 assert.Assert(t, info == nil) 712 } 713 var rCount int32 714 var drCount int32 715 mockRegistry := MockRegistry{ 716 RegisterFunc: func(info *registry.Info) error { 717 checkInfo(info) 718 atomic.AddInt32(&rCount, 1) 719 return nil 720 }, 721 DeregisterFunc: func(info *registry.Info) error { 722 checkInfo(info) 723 atomic.AddInt32(&drCount, 1) 724 return nil 725 }, 726 } 727 var opts []config.Option 728 opts = append(opts, WithRegistry(mockRegistry, nil)) 729 opts = append(opts, WithHostPorts("127.0.0.1:9227")) 730 srv := New(opts...) 731 go srv.Spin() 732 time.Sleep(2 * time.Second) 733 ctx, cancel := context.WithTimeout(context.Background(), 0) 734 defer cancel() 735 _ = srv.Shutdown(ctx) 736 time.Sleep(2 * time.Second) 737 assert.Assert(t, atomic.LoadInt32(&rCount) == 1) 738 assert.Assert(t, atomic.LoadInt32(&drCount) == 1) 739 } 740 741 type testTracer struct{} 742 743 func (t testTracer) Start(ctx context.Context, c *app.RequestContext) context.Context { 744 value := 0 745 if v := ctx.Value("testKey"); v != nil { 746 value = v.(int) 747 value++ 748 } 749 return context.WithValue(ctx, "testKey", value) 750 } 751 752 func (t testTracer) Finish(ctx context.Context, c *app.RequestContext) {} 753 754 func TestReuseCtx(t *testing.T) { 755 h := New(WithTracer(testTracer{}), WithHostPorts("localhost:9228")) 756 h.GET("/ping", func(ctx context.Context, c *app.RequestContext) { 757 assert.DeepEqual(t, 0, ctx.Value("testKey").(int)) 758 }) 759 760 go h.Spin() 761 time.Sleep(time.Second) 762 for i := 0; i < 1000; i++ { 763 _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:9228/ping") 764 assert.Nil(t, err) 765 } 766 } 767 768 type CloseWithoutResetBuffer interface { 769 CloseNoResetBuffer() error 770 } 771 772 func TestOnprepare(t *testing.T) { 773 h1 := New( 774 WithHostPorts("localhost:9333"), 775 WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { 776 b, err := conn.Peek(3) 777 assert.Nil(t, err) 778 assert.DeepEqual(t, string(b), "GET") 779 if c, ok := conn.(CloseWithoutResetBuffer); ok { 780 c.CloseNoResetBuffer() 781 } else { 782 conn.Close() 783 } 784 return ctx 785 })) 786 h1.GET("/ping", func(ctx context.Context, c *app.RequestContext) { 787 c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) 788 }) 789 790 go h1.Spin() 791 time.Sleep(time.Second) 792 _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:9333/ping") 793 assert.DeepEqual(t, "the server closed connection before returning the first response byte. Make sure the server returns 'Connection: close' response header before closing the connection", err.Error()) 794 795 h2 := New( 796 WithOnAccept(func(conn net.Conn) context.Context { 797 conn.Close() 798 return context.Background() 799 }), 800 WithHostPorts("localhost:9331")) 801 h2.GET("/ping", func(ctx context.Context, c *app.RequestContext) { 802 c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) 803 }) 804 go h2.Spin() 805 time.Sleep(time.Second) 806 _, _, err = c.Get(context.Background(), nil, "http://127.0.0.1:9331/ping") 807 if err == nil { 808 t.Fatalf("err should not be nil") 809 } 810 811 h3 := New( 812 WithOnAccept(func(conn net.Conn) context.Context { 813 assert.DeepEqual(t, conn.LocalAddr().String(), "127.0.0.1:9231") 814 return context.Background() 815 }), 816 WithHostPorts("localhost:9231"), 817 WithTransport(standard.NewTransporter)) 818 h3.GET("/ping", func(ctx context.Context, c *app.RequestContext) { 819 c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) 820 }) 821 go h3.Spin() 822 time.Sleep(time.Second) 823 c.Get(context.Background(), nil, "http://127.0.0.1:9231/ping") 824 } 825 826 type lockBuffer struct { 827 sync.Mutex 828 b bytes.Buffer 829 } 830 831 func (l *lockBuffer) Write(p []byte) (int, error) { 832 l.Lock() 833 defer l.Unlock() 834 return l.b.Write(p) 835 } 836 837 func (l *lockBuffer) String() string { 838 l.Lock() 839 defer l.Unlock() 840 return l.b.String() 841 } 842 843 func TestSilentMode(t *testing.T) { 844 hlog.SetSilentMode(true) 845 b := &lockBuffer{b: bytes.Buffer{}} 846 847 hlog.SetOutput(b) 848 849 h := New(WithHostPorts("localhost:9232"), WithTransport(standard.NewTransporter)) 850 h.GET("/ping", func(c context.Context, ctx *app.RequestContext) { 851 ctx.Write([]byte("hello, world")) 852 }) 853 go h.Spin() 854 time.Sleep(time.Second) 855 856 d := standard.NewDialer() 857 conn, _ := d.DialConnection("tcp", "127.0.0.1:9232", 0, nil) 858 conn.Write([]byte("aaa")) 859 conn.Close() 860 861 if strings.Contains(b.String(), "Error") { 862 t.Fatalf("unexpected error in log: %s", b.String()) 863 } 864 } 865 866 func TestHertzDisableHeaderNamesNormalizing(t *testing.T) { 867 h := New( 868 WithHostPorts("localhost:9212"), 869 WithDisableHeaderNamesNormalizing(true), 870 ) 871 headerName := "CASE-senSITive-HEAder-NAME" 872 headerValue := "foobar-baz" 873 succeed := false 874 h.GET("/test", func(c context.Context, ctx *app.RequestContext) { 875 ctx.VisitAllHeaders(func(key, value []byte) { 876 if string(key) == headerName && string(value) == headerValue { 877 succeed = true 878 return 879 } 880 }) 881 if !succeed { 882 t.Fatalf("DisableHeaderNamesNormalizing failed") 883 } else { 884 ctx.Header(headerName, headerValue) 885 } 886 }) 887 888 go h.Spin() 889 time.Sleep(100 * time.Millisecond) 890 891 cli, _ := c.NewClient(c.WithDisableHeaderNamesNormalizing(true)) 892 893 r := protocol.NewRequest("GET", "http://localhost:9212/test", nil) 894 r.Header.DisableNormalizing() 895 r.Header.Set(headerName, headerValue) 896 res := protocol.AcquireResponse() 897 err := cli.Do(context.Background(), r, res) 898 assert.Nil(t, err) 899 assert.DeepEqual(t, headerValue, res.Header.Get(headerName)) 900 } 901 902 func TestBindConfig(t *testing.T) { 903 type Req struct { 904 A int `query:"a"` 905 } 906 bindConfig := binding.NewBindConfig() 907 bindConfig.LooseZeroMode = true 908 h := New( 909 WithHostPorts("localhost:9332"), 910 WithBindConfig(bindConfig)) 911 h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { 912 var req Req 913 err := ctx.BindAndValidate(&req) 914 if err != nil { 915 t.Fatal("unexpected error") 916 } 917 }) 918 919 go h.Spin() 920 time.Sleep(100 * time.Millisecond) 921 hc := http.Client{Timeout: time.Second} 922 _, err := hc.Get("http://127.0.0.1:9332/bind?a=") 923 assert.Nil(t, err) 924 925 bindConfig = binding.NewBindConfig() 926 bindConfig.LooseZeroMode = false 927 h2 := New( 928 WithHostPorts("localhost:9448"), 929 WithBindConfig(bindConfig)) 930 h2.GET("/bind", func(c context.Context, ctx *app.RequestContext) { 931 var req Req 932 err := ctx.BindAndValidate(&req) 933 if err == nil { 934 t.Fatal("expect an error") 935 } 936 }) 937 938 go h2.Spin() 939 time.Sleep(100 * time.Millisecond) 940 941 _, err = hc.Get("http://127.0.0.1:9448/bind?a=") 942 assert.Nil(t, err) 943 time.Sleep(100 * time.Millisecond) 944 } 945 946 type mockBinder struct{} 947 948 func (m *mockBinder) Name() string { 949 return "test binder" 950 } 951 952 func (m *mockBinder) Bind(request *protocol.Request, i interface{}, params param.Params) error { 953 return nil 954 } 955 956 func (m *mockBinder) BindAndValidate(request *protocol.Request, i interface{}, params param.Params) error { 957 return fmt.Errorf("test binder") 958 } 959 960 func (m *mockBinder) BindQuery(request *protocol.Request, i interface{}) error { 961 return nil 962 } 963 964 func (m *mockBinder) BindHeader(request *protocol.Request, i interface{}) error { 965 return nil 966 } 967 968 func (m *mockBinder) BindPath(request *protocol.Request, i interface{}, params param.Params) error { 969 return nil 970 } 971 972 func (m *mockBinder) BindForm(request *protocol.Request, i interface{}) error { 973 return nil 974 } 975 976 func (m *mockBinder) BindJSON(request *protocol.Request, i interface{}) error { 977 return nil 978 } 979 980 func (m *mockBinder) BindProtobuf(request *protocol.Request, i interface{}) error { 981 return nil 982 } 983 984 func TestCustomBinder(t *testing.T) { 985 type Req struct { 986 A int `query:"a"` 987 } 988 h := New( 989 WithHostPorts("localhost:9334"), 990 WithCustomBinder(&mockBinder{})) 991 h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { 992 var req Req 993 err := ctx.BindAndValidate(&req) 994 if err == nil { 995 t.Fatal("expect an error") 996 } 997 assert.DeepEqual(t, "test binder", err.Error()) 998 }) 999 1000 go h.Spin() 1001 time.Sleep(100 * time.Millisecond) 1002 hc := http.Client{Timeout: time.Second} 1003 _, err := hc.Get("http://127.0.0.1:9334/bind?a=") 1004 assert.Nil(t, err) 1005 time.Sleep(100 * time.Millisecond) 1006 } 1007 1008 func TestValidateConfigRegValidateFunc(t *testing.T) { 1009 type Req struct { 1010 A int `query:"a" vd:"f($)"` 1011 } 1012 validateConfig := &binding.ValidateConfig{} 1013 validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { 1014 return fmt.Errorf("test validator") 1015 }) 1016 h := New( 1017 WithHostPorts("localhost:9229")) 1018 h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { 1019 var req Req 1020 err := ctx.BindAndValidate(&req) 1021 if err == nil { 1022 t.Fatal("expect an error") 1023 } 1024 assert.DeepEqual(t, "test validator", err.Error()) 1025 }) 1026 1027 go h.Spin() 1028 time.Sleep(100 * time.Millisecond) 1029 hc := http.Client{Timeout: time.Second} 1030 _, err := hc.Get("http://127.0.0.1:9229/bind?a=2") 1031 assert.Nil(t, err) 1032 time.Sleep(100 * time.Millisecond) 1033 } 1034 1035 type mockValidator struct{} 1036 1037 func (m *mockValidator) ValidateStruct(interface{}) error { 1038 return fmt.Errorf("test mock validator") 1039 } 1040 1041 func (m *mockValidator) Engine() interface{} { 1042 return nil 1043 } 1044 1045 func (m *mockValidator) ValidateTag() string { 1046 return "vd" 1047 } 1048 1049 func TestCustomValidator(t *testing.T) { 1050 type Req struct { 1051 A int `query:"a" vd:"f($)"` 1052 } 1053 h := New( 1054 WithHostPorts("localhost:9555"), 1055 WithCustomValidator(&mockValidator{})) 1056 h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { 1057 var req Req 1058 err := ctx.BindAndValidate(&req) 1059 if err == nil { 1060 t.Fatal("expect an error") 1061 } 1062 assert.DeepEqual(t, "test mock validator", err.Error()) 1063 }) 1064 1065 go h.Spin() 1066 time.Sleep(100 * time.Millisecond) 1067 hc := http.Client{Timeout: time.Second} 1068 _, err := hc.Get("http://127.0.0.1:9555/bind?a=2") 1069 assert.Nil(t, err) 1070 time.Sleep(100 * time.Millisecond) 1071 } 1072 1073 type ValidateError struct { 1074 ErrType, FailField, Msg string 1075 } 1076 1077 // Error implements error interface. 1078 func (e *ValidateError) Error() string { 1079 if e.Msg != "" { 1080 return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg 1081 } 1082 return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" 1083 } 1084 1085 func TestValidateConfigSetSetErrorFactory(t *testing.T) { 1086 type TestValidate struct { 1087 B int `query:"b" vd:"$>100"` 1088 } 1089 CustomValidateErrFunc := func(failField, msg string) error { 1090 err := ValidateError{ 1091 ErrType: "validateErr", 1092 FailField: "[validateFailField]: " + failField, 1093 Msg: "[validateErrMsg]: " + msg, 1094 } 1095 1096 return &err 1097 } 1098 validateConfig := binding.NewValidateConfig() 1099 validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) 1100 h := New( 1101 WithHostPorts("localhost:9666"), 1102 WithValidateConfig(validateConfig)) 1103 h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { 1104 var req TestValidate 1105 err := ctx.BindAndValidate(&req) 1106 if err == nil { 1107 t.Fatal("expect an error") 1108 } 1109 assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) 1110 }) 1111 1112 go h.Spin() 1113 time.Sleep(100 * time.Millisecond) 1114 hc := http.Client{Timeout: time.Second} 1115 _, err := hc.Get("http://127.0.0.1:9666/bind?b=1") 1116 assert.Nil(t, err) 1117 time.Sleep(100 * time.Millisecond) 1118 } 1119 1120 func TestValidateConfigAndBindConfig(t *testing.T) { 1121 type Req struct { 1122 A int `query:"a" vt:"$>=0&&$<=130"` 1123 } 1124 validateConfig := binding.NewValidateConfig() 1125 validateConfig.ValidateTag = "vt" 1126 h := New( 1127 WithHostPorts("localhost:9876"), 1128 WithValidateConfig(validateConfig)) 1129 h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { 1130 var req Req 1131 err := ctx.BindAndValidate(&req) 1132 if err == nil { 1133 t.Fatal("expect an error") 1134 } 1135 t.Log(err) 1136 }) 1137 1138 go h.Spin() 1139 time.Sleep(100 * time.Millisecond) 1140 hc := http.Client{Timeout: time.Second} 1141 _, err := hc.Get("http://127.0.0.1:9876/bind?a=135") 1142 assert.Nil(t, err) 1143 time.Sleep(100 * time.Millisecond) 1144 } 1145 1146 func TestWithDisableDefaultDate(t *testing.T) { 1147 h := New( 1148 WithHostPorts("localhost:8321"), 1149 WithDisableDefaultDate(true), 1150 ) 1151 h.GET("/", func(_ context.Context, c *app.RequestContext) {}) 1152 go h.Spin() 1153 time.Sleep(100 * time.Millisecond) 1154 hc := http.Client{Timeout: time.Second} 1155 r, _ := hc.Get("http://127.0.0.1:8321") //nolint:errcheck 1156 assert.DeepEqual(t, "", r.Header.Get("Date")) 1157 } 1158 1159 func TestWithDisableDefaultContentType(t *testing.T) { 1160 h := New( 1161 WithHostPorts("localhost:8324"), 1162 WithDisableDefaultContentType(true), 1163 ) 1164 h.GET("/", func(_ context.Context, c *app.RequestContext) {}) 1165 go h.Spin() 1166 time.Sleep(100 * time.Millisecond) 1167 hc := http.Client{Timeout: time.Second} 1168 r, _ := hc.Get("http://127.0.0.1:8324") //nolint:errcheck 1169 assert.DeepEqual(t, "", r.Header.Get("Content-Type")) 1170 }