github.com/cloudwego/hertz@v0.9.3/pkg/app/client/client_test.go (about) 1 /* 2 * Copyright 2022 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * 16 * The MIT License (MIT) 17 * 18 * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors 19 * 20 * Permission is hereby granted, free of charge, to any person obtaining a copy 21 * of this software and associated documentation files (the "Software"), to deal 22 * in the Software without restriction, including without limitation the rights 23 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 24 * copies of the Software, and to permit persons to whom the Software is 25 * furnished to do so, subject to the following conditions: 26 * 27 * The above copyright notice and this permission notice shall be included in 28 * all copies or substantial portions of the Software. 29 * 30 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 31 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 32 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 33 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 34 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 35 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 36 * THE SOFTWARE. 
37 * 38 * This file may have been modified by CloudWeGo authors. All CloudWeGo 39 * Modifications are Copyright 2022 CloudWeGo Authors. 40 */ 41 42 package client 43 44 import ( 45 "context" 46 "crypto/tls" 47 "encoding/base64" 48 "errors" 49 "fmt" 50 "io" 51 "io/ioutil" 52 "net" 53 "net/http" 54 "net/http/httptest" 55 "net/url" 56 "os" 57 "path/filepath" 58 "reflect" 59 "regexp" 60 "strings" 61 "sync" 62 "sync/atomic" 63 "testing" 64 "time" 65 66 "github.com/cloudwego/hertz/internal/bytestr" 67 "github.com/cloudwego/hertz/pkg/app" 68 "github.com/cloudwego/hertz/pkg/app/client/retry" 69 "github.com/cloudwego/hertz/pkg/common/config" 70 errs "github.com/cloudwego/hertz/pkg/common/errors" 71 "github.com/cloudwego/hertz/pkg/common/test/assert" 72 "github.com/cloudwego/hertz/pkg/network" 73 "github.com/cloudwego/hertz/pkg/network/dialer" 74 "github.com/cloudwego/hertz/pkg/network/standard" 75 "github.com/cloudwego/hertz/pkg/protocol" 76 "github.com/cloudwego/hertz/pkg/protocol/consts" 77 "github.com/cloudwego/hertz/pkg/protocol/http1" 78 "github.com/cloudwego/hertz/pkg/protocol/http1/req" 79 "github.com/cloudwego/hertz/pkg/protocol/http1/resp" 80 "github.com/cloudwego/hertz/pkg/route" 81 ) 82 83 var errTooManyRedirects = errors.New("too many redirects detected when doing the request") 84 85 func TestCloseIdleConnections(t *testing.T) { 86 opt := config.NewOptions([]config.Option{}) 87 opt.Addr = "unix-test-10000" 88 opt.Network = "unix" 89 engine := route.NewEngine(opt) 90 91 go engine.Run() 92 defer func() { 93 engine.Close() 94 }() 95 for { 96 time.Sleep(1 * time.Second) 97 if engine.IsRunning() { 98 break 99 } 100 } 101 102 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) 103 104 if _, _, err := c.Get(context.Background(), nil, "http://google.com"); err != nil { 105 t.Fatal(err) 106 } 107 108 connsLen := func() int { 109 c.mLock.Lock() 110 defer c.mLock.Unlock() 111 112 if _, ok := c.m["google.com"]; !ok { 113 
return 0 114 } 115 return c.m["google.com"].ConnectionCount() 116 } 117 118 if conns := connsLen(); conns > 1 { 119 t.Errorf("expected 1 conns got %d", conns) 120 } 121 122 c.CloseIdleConnections() 123 124 if conns := connsLen(); conns > 0 { 125 t.Errorf("expected 0 conns got %d", conns) 126 } 127 } 128 129 func TestClientInvalidURI(t *testing.T) { 130 t.Parallel() 131 132 opt := config.NewOptions([]config.Option{}) 133 opt.Addr = "unix-test-10001" 134 opt.Network = "unix" 135 requests := int64(0) 136 engine := route.NewEngine(opt) 137 engine.GET("/", func(c context.Context, ctx *app.RequestContext) { 138 atomic.AddInt64(&requests, 1) 139 }) 140 go engine.Run() 141 defer func() { 142 engine.Close() 143 }() 144 for { 145 time.Sleep(1 * time.Second) 146 if engine.IsRunning() { 147 break 148 } 149 } 150 151 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) 152 req, res := protocol.AcquireRequest(), protocol.AcquireResponse() 153 defer func() { 154 protocol.ReleaseRequest(req) 155 protocol.ReleaseResponse(res) 156 }() 157 req.Header.SetMethod(consts.MethodGet) 158 req.SetRequestURI("http://example.com\r\n\r\nGET /\r\n\r\n") 159 err := c.Do(context.Background(), req, res) 160 if err == nil { 161 t.Fatal("expected error (missing required Host header in request)") 162 } 163 if n := atomic.LoadInt64(&requests); n != 0 { 164 t.Fatalf("0 requests expected, got %d", n) 165 } 166 } 167 168 func TestClientGetWithBody(t *testing.T) { 169 t.Parallel() 170 171 opt := config.NewOptions([]config.Option{}) 172 opt.Addr = "unix-test-10002" 173 opt.Network = "unix" 174 engine := route.NewEngine(opt) 175 engine.GET("/", func(c context.Context, ctx *app.RequestContext) { 176 body := ctx.Request.Body() 177 ctx.Write(body) //nolint:errcheck 178 }) 179 go engine.Run() 180 defer func() { 181 engine.Close() 182 }() 183 for { 184 time.Sleep(1 * time.Second) 185 if engine.IsRunning() { 186 break 187 } 188 } 189 190 c, _ := 
NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) 191 req, res := protocol.AcquireRequest(), protocol.AcquireResponse() 192 defer func() { 193 protocol.ReleaseRequest(req) 194 protocol.ReleaseResponse(res) 195 }() 196 req.Header.SetMethod(consts.MethodGet) 197 req.SetRequestURI("http://example.com") 198 req.SetBodyString("test") 199 err := c.Do(context.Background(), req, res) 200 if err != nil { 201 t.Fatal(err) 202 } 203 if len(res.Body()) == 0 { 204 t.Fatal("missing request body") 205 } 206 } 207 208 func TestClientPostBodyStream(t *testing.T) { 209 t.Parallel() 210 211 opt := config.NewOptions([]config.Option{}) 212 opt.Addr = "unix-test-10102" 213 opt.Network = "unix" 214 engine := route.NewEngine(opt) 215 engine.POST("/", func(c context.Context, ctx *app.RequestContext) { 216 body := ctx.Request.Body() 217 ctx.Write(body) //nolint:errcheck 218 }) 219 go engine.Run() 220 defer func() { 221 engine.Close() 222 }() 223 for { 224 time.Sleep(1 * time.Second) 225 if engine.IsRunning() { 226 break 227 } 228 } 229 230 cStream, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil)), WithResponseBodyStream(true)) 231 args := &protocol.Args{} 232 // There is some data in databuf and others is in bodystream, so we need 233 // to let the data exceed the max bodysize of bodystream 234 v := "" 235 for i := 0; i < 10240; i++ { 236 v += "b" 237 } 238 args.Add("a", v) 239 _, body, err := cStream.Post(context.Background(), nil, "http://example.com", args) 240 if err != nil { 241 t.Fatal(err) 242 } 243 assert.DeepEqual(t, "a="+v, string(body)) 244 } 245 246 func TestClientURLAuth(t *testing.T) { 247 t.Parallel() 248 249 cases := map[string]string{ 250 "foo:bar@": "Basic Zm9vOmJhcg==", 251 "foo:@": "Basic Zm9vOg==", 252 ":@": "", 253 "@": "", 254 "": "", 255 } 256 ch := make(chan string, 1) 257 258 opt := config.NewOptions([]config.Option{}) 259 opt.Addr = "unix-test-10003" 260 opt.Network = 
"unix" 261 engine := route.NewEngine(opt) 262 engine.GET("/foo/bar", func(c context.Context, ctx *app.RequestContext) { 263 ch <- string(ctx.Request.Header.Peek(consts.HeaderAuthorization)) 264 }) 265 go engine.Run() 266 defer func() { 267 engine.Close() 268 }() 269 for { 270 time.Sleep(1 * time.Second) 271 if engine.IsRunning() { 272 break 273 } 274 } 275 276 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) 277 for up, expected := range cases { 278 req := protocol.AcquireRequest() 279 req.Header.SetMethod(consts.MethodGet) 280 req.SetRequestURI("http://" + up + "example.com/foo/bar") 281 282 if err := c.Do(context.Background(), req, nil); err != nil { 283 t.Fatal(err) 284 } 285 286 val := <-ch 287 288 if val != expected { 289 t.Fatalf("wrong %s header: %s expected %s", consts.HeaderAuthorization, val, expected) 290 } 291 } 292 } 293 294 func TestClientNilResp(t *testing.T) { 295 opt := config.NewOptions([]config.Option{}) 296 opt.Addr = "unix-test-10004" 297 opt.Network = "unix" 298 engine := route.NewEngine(opt) 299 300 engine.GET("/", func(c context.Context, ctx *app.RequestContext) { 301 }) 302 go engine.Run() 303 defer func() { 304 engine.Close() 305 }() 306 for { 307 time.Sleep(1 * time.Second) 308 if engine.IsRunning() { 309 break 310 } 311 } 312 313 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) 314 315 req := protocol.AcquireRequest() 316 req.Header.SetMethod(consts.MethodGet) 317 req.SetRequestURI("http://example.com") 318 if err := c.Do(context.Background(), req, nil); err != nil { 319 t.Fatal(err) 320 } 321 if err := c.DoTimeout(context.Background(), req, nil, time.Second); err != nil { 322 t.Fatal(err) 323 } 324 } 325 326 func TestClientParseConn(t *testing.T) { 327 t.Parallel() 328 opt := config.NewOptions([]config.Option{}) 329 opt.Addr = "127.0.0.1:10005" 330 engine := route.NewEngine(opt) 331 engine.GET("/", func(c context.Context, ctx 
*app.RequestContext) { 332 }) 333 go engine.Run() 334 defer func() { 335 engine.Close() 336 }() 337 for { 338 time.Sleep(1 * time.Second) 339 if engine.IsRunning() { 340 break 341 } 342 } 343 344 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) 345 req, res := protocol.AcquireRequest(), protocol.AcquireResponse() 346 defer func() { 347 protocol.ReleaseRequest(req) 348 protocol.ReleaseResponse(res) 349 }() 350 req.SetRequestURI("http://" + opt.Addr + "") 351 if err := c.Do(context.Background(), req, res); err != nil { 352 t.Fatal(err) 353 } 354 355 if res.RemoteAddr().Network() != opt.Network { 356 t.Fatalf("req RemoteAddr parse network fail: %s, hope: %s", res.RemoteAddr().Network(), opt.Network) 357 } 358 if opt.Addr != res.RemoteAddr().String() { 359 t.Fatalf("req RemoteAddr parse addr fail: %s, hope: %s", res.RemoteAddr().String(), opt.Addr) 360 } 361 362 if !regexp.MustCompile(`^127\.0\.0\.1:[0-9]{4,5}$`).MatchString(res.LocalAddr().String()) { 363 t.Fatalf("res LocalAddr addr match fail: %s, hope match: %s", res.LocalAddr().String(), "^127.0.0.1:[0-9]{4,5}$") 364 } 365 } 366 367 func TestClientPostArgs(t *testing.T) { 368 t.Parallel() 369 opt := config.NewOptions([]config.Option{}) 370 opt.Addr = "unix-test-10006" 371 opt.Network = "unix" 372 engine := route.NewEngine(opt) 373 engine.POST("/", func(c context.Context, ctx *app.RequestContext) { 374 body := ctx.Request.Body() 375 if len(body) == 0 { 376 return 377 } 378 ctx.Write(body) //nolint:errcheck 379 }) 380 go engine.Run() 381 defer func() { 382 engine.Close() 383 }() 384 for { 385 time.Sleep(1 * time.Second) 386 if engine.IsRunning() { 387 break 388 } 389 } 390 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) 391 req, res := protocol.AcquireRequest(), protocol.AcquireResponse() 392 defer func() { 393 protocol.ReleaseRequest(req) 394 protocol.ReleaseResponse(res) 395 }() 396 args := req.PostArgs() 397 
args.Add("addhttp2", "support") 398 args.Add("fast", "http") 399 req.Header.SetMethod(consts.MethodPost) 400 req.SetRequestURI("http://make.hertz.great?again") 401 err := c.Do(context.Background(), req, res) 402 if err != nil { 403 t.Fatal(err) 404 } 405 if len(res.Body()) == 0 { 406 t.Fatal("cannot set args as body") 407 } 408 } 409 410 func TestClientHeaderCase(t *testing.T) { 411 t.Parallel() 412 413 opt := config.NewOptions([]config.Option{}) 414 opt.Addr = "unix-test-10007" 415 opt.Network = "unix" 416 engine := route.NewEngine(opt) 417 engine.GET("/", func(c context.Context, ctx *app.RequestContext) { 418 zw := ctx.GetWriter() 419 zw.WriteBinary([]byte("HTTP/1.1 200 OK\r\n" + //nolint:errcheck 420 "content-type: text/plain\r\n" + 421 "transfer-encoding: chunked\r\n\r\n" + 422 "24\r\nThis is the data in the first chunk \r\n" + 423 "1B\r\nand this is the second one \r\n" + 424 "0\r\n\r\n", 425 )) 426 }) 427 go engine.Run() 428 defer func() { 429 engine.Close() 430 }() 431 time.Sleep(time.Second) 432 433 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithDisableHeaderNamesNormalizing(true)) 434 code, body, err := c.Get(context.Background(), nil, "http://example.com") 435 if err != nil { 436 t.Error(err) 437 } else if code != 200 { 438 t.Errorf("expected status code 200 got %d", code) 439 } else if string(body) != "This is the data in the first chunk and this is the second one " { 440 t.Errorf("wrong body: %q", body) 441 } 442 } 443 444 func TestClientReadTimeout(t *testing.T) { 445 if testing.Short() { 446 t.Skip("skipping test in short mode") 447 } 448 449 opt := config.NewOptions([]config.Option{}) 450 opt.Addr = "localhost:10024" 451 engine := route.NewEngine(opt) 452 453 engine.GET("/normal", func(c context.Context, ctx *app.RequestContext) { 454 ctx.String(201, "ok") 455 }) 456 engine.GET("/timeout", func(c context.Context, ctx *app.RequestContext) { 457 time.Sleep(time.Second * 60) 458 ctx.String(202, 
"timeout ok") 459 }) 460 go engine.Run() 461 defer func() { 462 engine.Close() 463 }() 464 time.Sleep(time.Second * 1) 465 466 c := &http1.HostClient{ 467 ClientOptions: &http1.ClientOptions{ 468 ReadTimeout: time.Second * 4, 469 RetryConfig: &retry.Config{MaxAttemptTimes: 1}, 470 Dialer: standard.NewDialer(), 471 }, 472 Addr: opt.Addr, 473 } 474 475 req := protocol.AcquireRequest() 476 res := protocol.AcquireResponse() 477 478 req.SetRequestURI("http://" + opt.Addr + "/normal") 479 req.Header.SetMethod(consts.MethodGet) 480 481 // Setting Connection: Close will make the connection be returned to the pool. 482 req.SetConnectionClose() 483 484 if err := c.Do(context.Background(), req, res); err != nil { 485 t.Fatal(err) 486 } 487 488 protocol.ReleaseRequest(req) 489 protocol.ReleaseResponse(res) 490 491 done := make(chan struct{}) 492 go func() { 493 req := protocol.AcquireRequest() 494 res := protocol.AcquireResponse() 495 496 req.SetRequestURI("http://" + opt.Addr + "/timeout") 497 req.Header.SetMethod(consts.MethodGet) 498 req.SetConnectionClose() 499 500 if err := c.Do(context.Background(), req, res); !errors.Is(err, errs.ErrTimeout) { 501 if err == nil { 502 t.Errorf("expected ErrTimeout got nil, req url: %s, read resp body: %s, status: %d", string(req.URI().FullURI()), string(res.Body()), res.StatusCode()) 503 } else { 504 if !strings.Contains(err.Error(), "timeout") { 505 t.Errorf("expected ErrTimeout got %#v", err) 506 } 507 } 508 } 509 510 protocol.ReleaseRequest(req) 511 protocol.ReleaseResponse(res) 512 close(done) 513 }() 514 515 select { 516 case <-done: 517 // It is abnormal when waiting time exceeds the value of readTimeout times the number of retries. 518 // Give it extra 2 seconds just to be sure. 
519 case <-time.After(c.ReadTimeout*time.Duration(c.RetryConfig.MaxAttemptTimes) + time.Second*2): 520 t.Fatal("Client.ReadTimeout didn't work") 521 } 522 } 523 524 func TestClientDefaultUserAgent(t *testing.T) { 525 t.Parallel() 526 527 opt := config.NewOptions([]config.Option{}) 528 529 opt.Addr = "unix-test-10009" 530 opt.Network = "unix" 531 engine := route.NewEngine(opt) 532 533 engine.GET("/", func(c context.Context, ctx *app.RequestContext) { 534 ctx.Data(consts.StatusOK, "text/plain; charset=utf-8", ctx.UserAgent()) 535 }) 536 go engine.Run() 537 defer func() { 538 engine.Close() 539 }() 540 for { 541 time.Sleep(1 * time.Second) 542 if engine.IsRunning() { 543 break 544 } 545 } 546 547 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) 548 req := protocol.AcquireRequest() 549 res := protocol.AcquireResponse() 550 551 req.SetRequestURI("http://example.com") 552 req.Header.SetMethod(consts.MethodGet) 553 554 err := c.Do(context.Background(), req, res) 555 if err != nil { 556 t.Fatal(err) 557 } 558 if string(res.Body()) != string(bytestr.DefaultUserAgent) { 559 t.Fatalf("User-Agent defers %q != %q", string(res.Body()), bytestr.DefaultUserAgent) 560 } 561 } 562 563 func TestClientSetUserAgent(t *testing.T) { 564 t.Parallel() 565 566 opt := config.NewOptions([]config.Option{}) 567 568 opt.Addr = "unix-test-10010" 569 opt.Network = "unix" 570 engine := route.NewEngine(opt) 571 572 engine.GET("/", func(c context.Context, ctx *app.RequestContext) { 573 ctx.Data(consts.StatusOK, "text/plain; charset=utf-8", ctx.UserAgent()) 574 }) 575 go engine.Run() 576 defer func() { 577 engine.Close() 578 }() 579 for { 580 time.Sleep(1 * time.Second) 581 if engine.IsRunning() { 582 break 583 } 584 } 585 586 userAgent := "I'm not hertz" 587 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithName(userAgent)) 588 req := protocol.AcquireRequest() 589 res := 
protocol.AcquireResponse() 590 591 req.SetRequestURI("http://example.com") 592 593 err := c.Do(context.Background(), req, res) 594 if err != nil { 595 t.Fatal(err) 596 } 597 if string(res.Body()) != userAgent { 598 t.Fatalf("User-Agent defers %q != %q", string(res.Body()), userAgent) 599 } 600 } 601 602 func TestClientNoUserAgent(t *testing.T) { 603 opt := config.NewOptions([]config.Option{}) 604 opt.Addr = "unix-test-10011" 605 opt.Network = "unix" 606 engine := route.NewEngine(opt) 607 608 engine.GET("/", func(c context.Context, ctx *app.RequestContext) { 609 ctx.Data(consts.StatusOK, "text/plain; charset=utf-8", ctx.UserAgent()) 610 }) 611 go engine.Run() 612 defer func() { 613 engine.Close() 614 }() 615 for { 616 time.Sleep(1 * time.Second) 617 if engine.IsRunning() { 618 break 619 } 620 } 621 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithDialTimeout(1*time.Second), WithNoDefaultUserAgentHeader(true)) 622 623 req := protocol.AcquireRequest() 624 res := protocol.AcquireResponse() 625 626 req.SetRequestURI("http://example.com") 627 628 err := c.Do(context.Background(), req, res) 629 if err != nil { 630 t.Fatal(err) 631 } 632 if string(res.Body()) != "" { 633 t.Fatalf("User-Agent wrong %q != %q", string(res.Body()), "") 634 } 635 } 636 637 func TestClientDoWithCustomHeaders(t *testing.T) { 638 t.Parallel() 639 640 ch := make(chan error) 641 uri := "/foo/bar/baz?a=b&cd=12" 642 headers := map[string]string{ 643 "Foo": "bar", 644 "Host": "xxx.com", 645 "Content-Type": "asdfsdf", 646 "a-b-c-d-f": "", 647 } 648 body := "request body" 649 opt := config.NewOptions([]config.Option{}) 650 651 opt.Addr = "unix-test-10012" 652 opt.Network = "unix" 653 engine := route.NewEngine(opt) 654 655 engine.POST("/foo/bar/baz", func(c context.Context, ctx *app.RequestContext) { 656 zw := ctx.GetWriter() 657 658 if string(ctx.Request.Header.Method()) != consts.MethodPost { 659 ch <- fmt.Errorf("unexpected request method: %q. 
Expecting %q", ctx.Request.Header.Method(), consts.MethodPost) 660 return 661 } 662 reqURI := ctx.Request.RequestURI() 663 if string(reqURI) != uri { 664 ch <- fmt.Errorf("unexpected request uri: %q. Expecting %q", reqURI, uri) 665 return 666 } 667 for k, v := range headers { 668 hv := ctx.Request.Header.Peek(k) 669 if string(hv) != v { 670 ch <- fmt.Errorf("unexpected value for header %q: %q. Expecting %q", k, hv, v) 671 return 672 } 673 } 674 cl := ctx.Request.Header.ContentLength() 675 if cl != len(body) { 676 ch <- fmt.Errorf("unexpected content-length %d. Expecting %d", cl, len(body)) 677 return 678 } 679 reqBody := ctx.Request.Body() 680 if string(reqBody) != body { 681 ch <- fmt.Errorf("unexpected request body: %q. Expecting %q", reqBody, body) 682 return 683 } 684 685 var r protocol.Response 686 if err := resp.Write(&r, zw); err != nil { 687 ch <- fmt.Errorf("cannot send response: %s", err) 688 return 689 } 690 if err := zw.Flush(); err != nil { 691 ch <- fmt.Errorf("cannot flush response: %s", err) 692 return 693 } 694 695 ch <- nil 696 }) 697 go engine.Run() 698 defer func() { 699 engine.Close() 700 }() 701 for { 702 time.Sleep(1 * time.Second) 703 if engine.IsRunning() { 704 break 705 } 706 } 707 708 // make sure that the client sends all the request headers and body. 
709 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) 710 711 var req protocol.Request 712 req.Header.SetMethod(consts.MethodPost) 713 req.SetRequestURI(uri) 714 for k, v := range headers { 715 req.Header.Set(k, v) 716 } 717 req.SetBodyString(body) 718 719 var resp protocol.Response 720 721 err := c.DoTimeout(context.Background(), &req, &resp, time.Second) 722 if err != nil { 723 t.Fatalf("error when doing request: %s", err) 724 } 725 726 select { 727 case <-ch: 728 case <-time.After(5 * time.Second): 729 t.Fatalf("timeout") 730 } 731 } 732 733 func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) { 734 t.Parallel() 735 opt := config.NewOptions([]config.Option{}) 736 737 opt.Addr = "unix-test-10013" 738 opt.Network = "unix" 739 engine := route.NewEngine(opt) 740 741 engine.Use(func(c context.Context, ctx *app.RequestContext) { 742 uri := ctx.URI() 743 uri.DisablePathNormalizing = true 744 ctx.Response.Header.Set("received-uri", string(uri.FullURI())) 745 }) 746 747 go engine.Run() 748 defer func() { 749 engine.Close() 750 }() 751 for { 752 time.Sleep(1 * time.Second) 753 if engine.IsRunning() { 754 break 755 } 756 } 757 c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithDisablePathNormalizing(true)) 758 759 urlWithEncodedPath := "http://example.com/encoded/Y%2BY%2FY%3D/stuff" 760 761 var req protocol.Request 762 req.SetRequestURI(urlWithEncodedPath) 763 var resp protocol.Response 764 for i := 0; i < 5; i++ { 765 if err := c.DoTimeout(context.Background(), &req, &resp, time.Second); err != nil { 766 t.Fatalf("unexpected error: %s", err) 767 } 768 hv := resp.Header.Peek("received-uri") 769 if string(hv) != urlWithEncodedPath { 770 t.Fatalf("request uri was normalized: %q. 
Expecting %q", hv, urlWithEncodedPath) 771 } 772 } 773 } 774 775 func TestHostClientPendingRequests(t *testing.T) { 776 t.Parallel() 777 778 const concurrency = 10 779 doneCh := make(chan struct{}) 780 readyCh := make(chan struct{}, concurrency) 781 opt := config.NewOptions([]config.Option{}) 782 783 opt.Addr = "unix-test-10014" 784 opt.Network = "unix" 785 engine := route.NewEngine(opt) 786 787 engine.GET("/baz", func(c context.Context, ctx *app.RequestContext) { 788 readyCh <- struct{}{} 789 <-doneCh 790 }) 791 go engine.Run() 792 defer func() { 793 engine.Close() 794 }() 795 time.Sleep(time.Second) 796 797 c := &http1.HostClient{ 798 ClientOptions: &http1.ClientOptions{ 799 Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil), 800 }, 801 Addr: "foobar", 802 } 803 804 pendingRequests := c.PendingRequests() 805 if pendingRequests != 0 { 806 t.Fatalf("non-zero pendingRequests: %d", pendingRequests) 807 } 808 809 resultCh := make(chan error, concurrency) 810 for i := 0; i < concurrency; i++ { 811 go func() { 812 req := protocol.AcquireRequest() 813 req.SetRequestURI("http://foobar/baz") 814 req.Header.SetMethod(consts.MethodGet) 815 resp := protocol.AcquireResponse() 816 817 if err := c.DoTimeout(context.Background(), req, resp, 10*time.Second); err != nil { 818 resultCh <- fmt.Errorf("unexpected error: %s", err) 819 return 820 } 821 822 if resp.StatusCode() != consts.StatusOK { 823 resultCh <- fmt.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), consts.StatusOK) 824 return 825 } 826 resultCh <- nil 827 }() 828 } 829 830 // wait until all the requests reach server 831 for i := 0; i < concurrency; i++ { 832 select { 833 case <-readyCh: 834 case <-time.After(time.Second): 835 t.Fatalf("timeout") 836 } 837 } 838 839 pendingRequests = c.PendingRequests() 840 if pendingRequests != concurrency { 841 t.Fatalf("unexpected pendingRequests: %d. 
Expecting %d", pendingRequests, concurrency) 842 } 843 844 // unblock request handlers on the server and wait until all the requests are finished. 845 close(doneCh) 846 for i := 0; i < concurrency; i++ { 847 select { 848 case err := <-resultCh: 849 if err != nil { 850 t.Fatalf("unexpected error: %s", err) 851 } 852 case <-time.After(time.Second): 853 t.Fatalf("timeout") 854 } 855 } 856 857 pendingRequests = c.PendingRequests() 858 if pendingRequests != 0 { 859 t.Fatalf("non-zero pendingRequests: %d", pendingRequests) 860 } 861 } 862 863 func TestHostClientMaxConnsWithDeadline(t *testing.T) { 864 var ( 865 emptyBodyCount uint8 866 timeout = 50 * time.Millisecond 867 wg sync.WaitGroup 868 ) 869 opt := config.NewOptions([]config.Option{}) 870 871 opt.Addr = "unix-test-10015" 872 opt.Network = "unix" 873 engine := route.NewEngine(opt) 874 875 engine.POST("/baz", func(c context.Context, ctx *app.RequestContext) { 876 if len(ctx.Request.Body()) == 0 { 877 emptyBodyCount++ 878 } 879 880 ctx.WriteString("foo") //nolint:errcheck 881 }) 882 go engine.Run() 883 defer func() { 884 engine.Close() 885 }() 886 for { 887 time.Sleep(1 * time.Second) 888 if engine.IsRunning() { 889 break 890 } 891 } 892 893 c := &http1.HostClient{ 894 ClientOptions: &http1.ClientOptions{ 895 Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil), 896 MaxConns: 1, 897 }, 898 Addr: "foobar", 899 } 900 901 for i := 0; i < 5; i++ { 902 wg.Add(1) 903 go func() { 904 defer wg.Done() 905 906 req := protocol.AcquireRequest() 907 req.SetRequestURI("http://foobar/baz") 908 req.Header.SetMethod(consts.MethodPost) 909 req.SetBodyString("bar") 910 resp := protocol.AcquireResponse() 911 912 for { 913 if err := c.DoDeadline(context.Background(), req, resp, time.Now().Add(timeout)); err != nil { 914 if err.Error() == errs.ErrNoFreeConns.Error() { 915 time.Sleep(time.Millisecond * 500) 916 continue 917 } 918 t.Errorf("unexpected error: %s", err) 919 } 920 break 921 } 922 923 if 
resp.StatusCode() != consts.StatusOK { 924 t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), consts.StatusOK) 925 } 926 927 body := resp.Body() 928 if string(body) != "foo" { 929 t.Errorf("unexpected body %q. Expecting %q", body, "abcd") 930 } 931 }() 932 } 933 wg.Wait() 934 935 if emptyBodyCount > 0 { 936 t.Fatalf("at least one request body was empty") 937 } 938 } 939 940 func TestHostClientMaxConnDuration(t *testing.T) { 941 t.Parallel() 942 943 connectionCloseCount := uint32(0) 944 opt := config.NewOptions([]config.Option{}) 945 946 opt.Addr = "unix-test-10016" 947 opt.Network = "unix" 948 engine := route.NewEngine(opt) 949 950 engine.GET("/bbb/cc", func(c context.Context, ctx *app.RequestContext) { 951 ctx.WriteString("abcd") //nolint:errcheck 952 if ctx.Request.ConnectionClose() { 953 atomic.AddUint32(&connectionCloseCount, 1) 954 } 955 }) 956 go engine.Run() 957 defer func() { 958 engine.Close() 959 }() 960 for { 961 time.Sleep(1 * time.Second) 962 if engine.IsRunning() { 963 break 964 } 965 } 966 967 c := &http1.HostClient{ 968 ClientOptions: &http1.ClientOptions{ 969 Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil), 970 MaxConnDuration: 10 * time.Millisecond, 971 }, 972 Addr: "foobar", 973 } 974 975 for i := 0; i < 5; i++ { 976 statusCode, body, err := c.Get(context.Background(), nil, "http://aaaa.com/bbb/cc") 977 if err != nil { 978 t.Fatalf("unexpected error: %s", err) 979 } 980 if statusCode != consts.StatusOK { 981 t.Fatalf("unexpected status code %d. Expecting %d", statusCode, consts.StatusOK) 982 } 983 if string(body) != "abcd" { 984 t.Fatalf("unexpected body %q. 
Expecting %q", body, "abcd") 985 } 986 time.Sleep(c.MaxConnDuration) 987 } 988 989 if atomic.LoadUint32(&connectionCloseCount) == 0 { 990 t.Fatalf("expecting at least one 'Connection: close' request header") 991 } 992 }

// TestHostClientMultipleAddrs checks that a HostClient configured with the
// comma-separated Addr "foo,bar,baz" dials every listed address: after nine
// requests each address must have been dialed exactly three times. The handler
// closes the connection after every response, so each request needs a new dial.
func TestHostClientMultipleAddrs(t *testing.T) {
	t.Parallel()
	opt := config.NewOptions([]config.Option{})

	opt.Addr = "unix-test-10017"
	opt.Network = "unix"
	engine := route.NewEngine(opt)

	engine.GET("/baz/aaa", func(c context.Context, ctx *app.RequestContext) {
		ctx.Write(ctx.Host()) //nolint:errcheck
		// Force a fresh dial for the next request.
		ctx.SetConnectionClose()
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	for {
		time.Sleep(1 * time.Second)
		if engine.IsRunning() {
			break
		}
	}

	// dialsCount records the address the HostClient asked to dial. The mock
	// dialer (newMockDialerWithCustomFunc, defined elsewhere in this file) is
	// presumably routing every dial to opt.Network/opt.Addr. The map is only
	// touched from the dial path of the sequential requests below.
	dialsCount := make(map[string]int)
	c := &http1.HostClient{
		ClientOptions: &http1.ClientOptions{
			Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, func(network, addr string, timeout time.Duration, tlsConfig *tls.Config) {
				dialsCount[addr]++
			}),
		},
		Addr: "foo,bar,baz",
	}

	for i := 0; i < 9; i++ {
		statusCode, body, err := c.Get(context.Background(), nil, "http://foobar/baz/aaa?bbb=ddd")
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		if statusCode != consts.StatusOK {
			t.Fatalf("unexpected status code %d. Expecting %d", statusCode, consts.StatusOK)
		}
		if string(body) != "foobar" {
			t.Fatalf("unexpected body %q. Expecting %q", body, "foobar")
		}
	}

	if len(dialsCount) != 3 {
		t.Fatalf("unexpected dialsCount size %d. Expecting 3", len(dialsCount))
	}
	for _, k := range []string{"foo", "bar", "baz"} {
		if dialsCount[k] != 3 {
			t.Fatalf("unexpected dialsCount for %q. Expecting 3", k)
		}
	}
}

// TestClientFollowRedirects exercises redirect handling on a HostClient.
// "/foo" redirects to "/xy?z=wer", which redirects to "/bar"; every other
// registered path echoes its own path. Get/GetTimeout must follow the chain,
// DoRedirects with a budget of 16 must reach "/bar", and DoRedirects with a
// budget of 0 must fail with errTooManyRedirects.
func TestClientFollowRedirects(t *testing.T) {
	t.Parallel()
	opt := config.NewOptions([]config.Option{})

	opt.Addr = "unix-test-10018"
	opt.Network = "unix"
	engine := route.NewEngine(opt)

	handler := func(c context.Context, ctx *app.RequestContext) {
		switch string(ctx.Path()) {
		case "/foo":
			u := ctx.URI()
			u.Update("/xy?z=wer")
			ctx.Redirect(consts.StatusFound, u.FullURI())
		case "/xy":
			u := ctx.URI()
			u.Update("/bar")
			ctx.Redirect(consts.StatusFound, u.FullURI())
		default:
			ctx.SetContentType(consts.MIMETextPlain)
			ctx.Response.SetBody(ctx.Path())
		}
	}
	for _, p := range []string{"/foo", "/xy", "/bar", "/aaab/sss"} {
		engine.GET(p, handler)
	}

	go engine.Run()
	defer func() {
		engine.Close()
	}()
	// Poll for readiness like the other unix-socket tests instead of relying
	// on a fixed sleep being long enough.
	for {
		time.Sleep(1 * time.Second)
		if engine.IsRunning() {
			break
		}
	}

	c := &http1.HostClient{
		ClientOptions: &http1.ClientOptions{
			Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil),
		},
		Addr: "xxx",
	}

	for i := 0; i < 10; i++ {
		statusCode, body, err := c.GetTimeout(context.Background(), nil, "http://xxx/foo", time.Second)
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		if statusCode != consts.StatusOK {
			t.Fatalf("unexpected status code: %d", statusCode)
		}
		if string(body) != "/bar" {
			t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
		}
	}

	for i := 0; i < 10; i++ {
		statusCode, body, err := c.Get(context.Background(), nil, "http://xxx/aaab/sss")
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		if statusCode != consts.StatusOK {
			t.Fatalf("unexpected status code: %d", statusCode)
		}
		if string(body) != "/aaab/sss" {
			t.Fatalf("unexpected response %q. Expecting %q", body, "/aaab/sss")
		}
	}

	for i := 0; i < 10; i++ {
		request := protocol.AcquireRequest()
		response := protocol.AcquireResponse()

		request.SetRequestURI("http://xxx/foo")

		if err := c.DoRedirects(context.Background(), request, response, 16); err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
		if statusCode := response.StatusCode(); statusCode != consts.StatusOK {
			t.Fatalf("unexpected status code: %d", statusCode)
		}
		if body := string(response.Body()); body != "/bar" {
			t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
		}

		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}

	request := protocol.AcquireRequest()
	response := protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}()

	request.SetRequestURI("http://xxx/foo")

	// With a zero redirect budget the first 302 must surface as an error.
	// Messages are compared because the client builds its own error value;
	// a nil error is a failure, not a panic.
	err := c.DoRedirects(context.Background(), request, response, 0)
	if err == nil || err.Error() != errTooManyRedirects.Error() {
		t.Fatalf("want error: %v, have %v", errTooManyRedirects, err)
	}
}

// TestHostClientMaxConnWaitTimeoutSuccess sends five concurrent POSTs through
// a HostClient limited to a single connection. MaxConnWaitTimeout (200ms) is
// far above the handler's 5ms latency, so every request must eventually get
// the connection and succeed, and no connection waiters may remain afterwards.
func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) {
	var (
		// emptyBodyCount is written by server handler goroutines and read by
		// the test goroutine, so it is only accessed atomically.
		emptyBodyCount uint32
		wg             sync.WaitGroup
	)
	opt := config.NewOptions([]config.Option{})

	opt.Addr = "unix-test-10019"
	opt.Network = "unix"
	engine := route.NewEngine(opt)

	engine.POST("/baz", func(c context.Context, ctx *app.RequestContext) {
		if len(ctx.Request.Body()) == 0 {
			atomic.AddUint32(&emptyBodyCount, 1)
		}
		time.Sleep(5 * time.Millisecond)
		ctx.WriteString("foo") //nolint:errcheck
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	for {
		time.Sleep(1 * time.Second)
		if engine.IsRunning() {
			break
		}
	}

	c := &http1.HostClient{
		ClientOptions: &http1.ClientOptions{
			Dialer:             newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil),
			MaxConns:           1,
			MaxConnWaitTimeout: 200 * time.Millisecond,
		},
		Addr: "foobar",
	}

	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			request := protocol.AcquireRequest()
			response := protocol.AcquireResponse()
			defer func() {
				protocol.ReleaseRequest(request)
				protocol.ReleaseResponse(response)
			}()
			request.SetRequestURI("http://foobar/baz")
			request.Header.SetMethod(consts.MethodPost)
			request.SetBodyString("bar")

			if err := c.Do(context.Background(), request, response); err != nil {
				t.Errorf("unexpected error: %s", err)
				return
			}
			if response.StatusCode() != consts.StatusOK {
				t.Errorf("unexpected status code %d. Expecting %d", response.StatusCode(), consts.StatusOK)
			}
			if body := response.Body(); string(body) != "foo" {
				t.Errorf("unexpected body %q. Expecting %q", body, "foo")
			}
		}()
	}
	wg.Wait()

	if waiting := c.WantConnectionCount(); waiting > 0 {
		t.Errorf("connsWait has %v items remaining", waiting)
	}
	if atomic.LoadUint32(&emptyBodyCount) > 0 {
		t.Fatalf("at least one request body was empty")
	}
}

// TestHostClientMaxConnWaitTimeoutError sends five concurrent POSTs through a
// HostClient limited to a single connection with a 10ms MaxConnWaitTimeout.
// The timeout is shorter than the time the connection stays busy, so at least
// one request must fail with ErrNoFreeConns. The others must succeed normally,
// and no connection waiters may remain afterwards.
func TestHostClientMaxConnWaitTimeoutError(t *testing.T) {
	var (
		// emptyBodyCount is written by server handler goroutines and read by
		// the test goroutine, so it is only accessed atomically.
		emptyBodyCount uint32
		wg             sync.WaitGroup
	)
	opt := config.NewOptions([]config.Option{})

	opt.Addr = "unix-test-10020"
	opt.Network = "unix"
	engine := route.NewEngine(opt)

	engine.POST("/baz", func(c context.Context, ctx *app.RequestContext) {
		if len(ctx.Request.Body()) == 0 {
			atomic.AddUint32(&emptyBodyCount, 1)
		}
		time.Sleep(5 * time.Millisecond)
		ctx.WriteString("foo") //nolint:errcheck
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	for {
		time.Sleep(1 * time.Second)
		if engine.IsRunning() {
			break
		}
	}

	c := &http1.HostClient{
		ClientOptions: &http1.ClientOptions{
			Dialer:             newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil),
			MaxConns:           1,
			MaxConnWaitTimeout: 10 * time.Millisecond,
		},
		Addr: "foobar",
	}

	var errNoFreeConnsCount uint32
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			request := protocol.AcquireRequest()
			response := protocol.AcquireResponse()
			defer func() {
				protocol.ReleaseRequest(request)
				protocol.ReleaseResponse(response)
			}()
			request.SetRequestURI("http://foobar/baz")
			request.Header.SetMethod(consts.MethodPost)
			request.SetBodyString("bar")

			if err := c.Do(context.Background(), request, response); err != nil {
				if err.Error() != errs.ErrNoFreeConns.Error() {
					t.Errorf("unexpected error: %s. Expecting %s", err.Error(), errs.ErrNoFreeConns.Error())
				}
				atomic.AddUint32(&errNoFreeConnsCount, 1)
				return
			}
			if response.StatusCode() != consts.StatusOK {
				t.Errorf("unexpected status code %d. Expecting %d", response.StatusCode(), consts.StatusOK)
			}
			if body := response.Body(); string(body) != "foo" {
				t.Errorf("unexpected body %q. Expecting %q", body, "foo")
			}
		}()
	}
	wg.Wait()

	if waiting := c.WantConnectionCount(); waiting > 0 {
		t.Errorf("connsWait has %v items remaining", waiting)
	}
	if failed := atomic.LoadUint32(&errNoFreeConnsCount); failed == 0 {
		t.Errorf("unexpected errorCount: %d. Expecting > 0", failed)
	}
	if atomic.LoadUint32(&emptyBodyCount) > 0 {
		t.Fatalf("at least one request body was empty")
	}
}

// TestNewClient performs a GET against a TCP engine with a client built by
// NewClient using a custom dial timeout, and expects 200 with body "pong".
func TestNewClient(t *testing.T) {
	opt := config.NewOptions([]config.Option{})
	opt.Addr = "127.0.0.1:10022"
	engine := route.NewEngine(opt)
	engine.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
		ctx.SetBodyString("pong")
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	time.Sleep(1 * time.Second)

	client, err := NewClient(WithDialTimeout(2 * time.Second))
	if err != nil {
		t.Fatal(err)
	}
	status, body, err := client.Get(context.Background(), nil, "http://127.0.0.1:10022/ping")
	if err != nil {
		t.Fatal(err)
	}
	if status != consts.StatusOK {
		t.Errorf("return http status=%v", status)
	}
	if string(body) != "pong" {
		t.Errorf("unexpected body %q. Expecting %q", body, "pong")
	}
	t.Logf("resp=%v", string(body))
}

// TestUseShortConnection issues ten concurrent GETs with keep-alive disabled
// and expects the client to hold no pooled connections for the host afterwards.
func TestUseShortConnection(t *testing.T) {
	opt := config.NewOptions([]config.Option{})
	opt.Addr = "127.0.0.1:10023"
	engine := route.NewEngine(opt)
	engine.GET("/", func(c context.Context, ctx *app.RequestContext) {
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	time.Sleep(1 * time.Second)

	c, err := NewClient(WithKeepAlive(false))
	if err != nil {
		t.Fatal(err)
	}
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:10023"); err != nil {
				t.Error(err)
			}
		}()
	}
	wg.Wait()

	// connsLen reports how many connections the per-host client currently
	// tracks, or 0 when no host client exists for the address.
	connsLen := func() int {
		c.mLock.Lock()
		defer c.mLock.Unlock()

		hc, ok := c.m["127.0.0.1:10023"]
		if !ok {
			return 0
		}
		return hc.ConnectionCount()
	}

	if conns := connsLen(); conns > 0 {
		t.Errorf("expected 0 conns got %d", conns)
	}
}

// TestPostWithFormData checks that SetFormData and SetFormDataFromValues
// accumulate into one url-encoded body. The handler echoes every post arg as
// "key=value" joined by "&", and each expected pair must appear in the reply.
func TestPostWithFormData(t *testing.T) {
	opt := config.NewOptions([]config.Option{})
	opt.Addr = "127.0.0.1:10025"
	engine := route.NewEngine(opt)
	engine.POST("/", func(c context.Context, ctx *app.RequestContext) {
		var pairs []string
		ctx.PostArgs().VisitAll(func(key, value []byte) {
			pairs = append(pairs, string(key)+"="+string(value))
		})
		ctx.Data(consts.StatusOK, "text/plain; charset=utf-8", []byte(strings.Join(pairs, "&")))
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()

	time.Sleep(1 * time.Second)
	client, err := NewClient()
	if err != nil {
		t.Fatal(err)
	}
	request := protocol.AcquireRequest()
	response := protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}()
	postParam := map[string][]string{
		"a": {"c", "d", "e"},
		"b": {"c"},
		"c": {"f"},
	}
	request.SetFormData(map[string]string{
		"a": "c",
		"b": "c",
	})
	request.SetFormDataFromValues(url.Values{
		"a": []string{"d", "e"},
		"c": []string{"f"},
	})
	request.SetRequestURI("http://127.0.0.1:10025")
	request.SetMethod(consts.MethodPost)
	if err := client.Do(context.Background(), request, response); err != nil {
		t.Fatal(err)
	}

	body := string(response.Body())
	for key, values := range postParam {
		for _, value := range values {
			if !strings.Contains(body, key+"="+value) {
				t.Errorf("miss %v=%v", key, value)
			}
		}
	}
}

// TestPostWithMultipartField sends multipart form fields set through two
// SetMultipartFormData calls. The handler verifies fields "a" and "b" arrive
// and logs the raw HTTP/1 request.
func TestPostWithMultipartField(t *testing.T) {
	opt := config.NewOptions([]config.Option{})
	opt.Addr = "127.0.0.1:10026"
	engine := route.NewEngine(opt)
	engine.POST("/", func(c context.Context, ctx *app.RequestContext) {
		if string(ctx.FormValue("a")) != "1" {
			t.Errorf("field a want 1, got %v", string(ctx.FormValue("a")))
		}
		if string(ctx.FormValue("b")) != "2" {
			t.Errorf("field b want 2, got %v", string(ctx.FormValue("b")))
		}
		// req here is the http1/req package, not a local request variable.
		t.Log(req.GetHTTP1Request(&ctx.Request).String())
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()

	time.Sleep(1 * time.Second)
	client, err := NewClient()
	if err != nil {
		t.Fatal(err)
	}
	request := protocol.AcquireRequest()
	response := protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}()
	data := map[string]string{
		"a": "1",
		"b": "2",
	}
	request.SetMethod(consts.MethodPost)
	request.SetRequestURI("http://127.0.0.1:10026")
	request.SetMultipartFormData(data)
	request.SetMultipartFormData(map[string]string{
		"c": "3",
	})
	if err := client.DoTimeout(context.Background(), request, response, 1*time.Second); err != nil {
		t.Error(err)
	}
}

// TestSetFiles uploads four files: two under the repeated "files" field via
// SetFile and two under distinct fields via SetFiles. The handler saves every
// upload into the working directory (removed again when the test returns) and
// replies with the number of files it received, which must be 4.
func TestSetFiles(t *testing.T) {
	opt := config.NewOptions([]config.Option{})
	opt.Addr = "127.0.0.1:10027"
	engine := route.NewEngine(opt)
	engine.POST("/", func(c context.Context, ctx *app.RequestContext) {
		form, err := ctx.MultipartForm()
		if err != nil {
			t.Errorf("parse multipart form: %v", err)
			ctx.SetStatusCode(consts.StatusBadRequest)
			return
		}
		files := form.File["files"]
		// Upload the file to specific dst.
		for _, file := range files {
			if err := ctx.SaveUploadedFile(file, filepath.Base(file.Filename)); err != nil {
				t.Errorf("save %q: %v", file.Filename, err)
			}
		}
		for _, field := range []string{"file_1", "file_2"} {
			file, err := ctx.FormFile(field)
			if err != nil {
				// Without this check a missing field would panic on a nil header.
				t.Errorf("form file %q: %v", field, err)
				ctx.SetStatusCode(consts.StatusBadRequest)
				return
			}
			if err := ctx.SaveUploadedFile(file, filepath.Base(file.Filename)); err != nil {
				t.Errorf("save %q: %v", file.Filename, err)
			}
		}
		ctx.String(consts.StatusOK, "%d files uploaded!", len(files)+2)
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()

	time.Sleep(1 * time.Second)
	client, err := NewClient()
	if err != nil {
		t.Fatal(err)
	}
	request := protocol.AcquireRequest()
	response := protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}()
	request.SetMethod(consts.MethodPost)
	request.SetRequestURI("http://127.0.0.1:10027")
	files := []string{"../../common/testdata/test.txt", "../../common/testdata/proto/test.proto", "../../common/testdata/test.png", "../../common/testdata/proto/test.pb.go"}
	defer func() {
		// Best-effort removal of the copies the handler saved.
		for _, file := range files {
			os.Remove(filepath.Base(file))
		}
	}()
	request.SetFile("files", files[0])
	request.SetFile("files", files[1])
	request.SetFiles(map[string]string{
		"file_1": files[2],
		"file_2": files[3],
	})
	if err := client.DoTimeout(context.Background(), request, response, 1*time.Second); err != nil {
		t.Fatal(err)
	}
	if response.StatusCode() != consts.StatusOK {
		t.Errorf("unexpected status code %d. Expecting %d", response.StatusCode(), consts.StatusOK)
	}
	if body, want := string(response.Body()), "4 files uploaded!"; body != want {
		t.Errorf("unexpected body %q. Expecting %q", body, want)
	}
}

// TestSetMultipartFields sends two in-memory JSON "files" via
// SetMultipartFields plus plain fields "a" and "b". The handler checks the
// plain fields, saves both files into the working directory (removed again
// when the test returns) and replies "2 files uploaded!".
func TestSetMultipartFields(t *testing.T) {
	opt := config.NewOptions([]config.Option{})
	opt.Addr = "127.0.0.1:10028"
	engine := route.NewEngine(opt)
	engine.POST("/", func(c context.Context, ctx *app.RequestContext) {
		// req here is the http1/req package, not a local request variable.
		t.Log(req.GetHTTP1Request(&ctx.Request).String())
		if string(ctx.FormValue("a")) != "1" {
			t.Errorf("field a want 1, got %v", string(ctx.FormValue("a")))
		}
		if string(ctx.FormValue("b")) != "2" {
			t.Errorf("field b want 2, got %v", string(ctx.FormValue("b")))
		}
		for _, field := range []string{"file_1", "file_2"} {
			file, err := ctx.FormFile(field)
			if err != nil {
				// Without this check a missing field would panic on a nil header.
				t.Errorf("form file %q: %v", field, err)
				ctx.SetStatusCode(consts.StatusBadRequest)
				return
			}
			if err := ctx.SaveUploadedFile(file, filepath.Base(file.Filename)); err != nil {
				t.Errorf("save %q: %v", file.Filename, err)
			}
		}
		ctx.String(consts.StatusOK, "%d files uploaded!", 2)
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()

	time.Sleep(1 * time.Second)
	client, err := NewClient(WithDialTimeout(50 * time.Millisecond))
	if err != nil {
		t.Fatal(err)
	}
	request := protocol.AcquireRequest()
	response := protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}()
	jsonStr1 := `{"input": {"name": "Uploaded document 1", "_filename" : ["file1.txt"]}}`
	jsonStr2 := `{"input": {"name": "Uploaded document 2", "_filename" : ["file2.txt"]}}`
	files := []string{"upload-file-1.json", "upload-file-2.json"}
	fields := []*protocol.MultipartField{
		{
			Param:       "file_1",
			FileName:    files[0],
			ContentType: consts.MIMEApplicationJSON,
			Reader:      strings.NewReader(jsonStr1),
		},
		{
			Param:       "file_2",
			FileName:    files[1],
			ContentType: consts.MIMEApplicationJSON,
			Reader:      strings.NewReader(jsonStr2),
		},
	}
	defer func() {
		// Best-effort removal of the copies the handler saved.
		for _, file := range files {
			os.Remove(filepath.Base(file))
		}
	}()
	request.SetMultipartFields(fields...)
	request.SetMultipartFormData(map[string]string{"a": "1", "b": "2"})
	request.SetRequestURI("http://127.0.0.1:10028")
	request.SetMethod(consts.MethodPost)
	if err := client.DoTimeout(context.Background(), request, response, 1*time.Second); err != nil {
		t.Fatal(err)
	}
	if response.StatusCode() != consts.StatusOK {
		t.Errorf("unexpected status code %d. Expecting %d", response.StatusCode(), consts.StatusOK)
	}
	if body, want := string(response.Body()), "2 files uploaded!"; body != want {
		t.Errorf("unexpected body %q. Expecting %q", body, want)
	}
}

// TestClientReadResponseBodyStream enables response body streaming, reads the
// first len(part1) bytes of the body from the stream and then expects the
// rest of the stream to be exactly part2.
func TestClientReadResponseBodyStream(t *testing.T) {
	part1 := "abcdef"
	part2 := "ghij"

	opt := config.NewOptions([]config.Option{})
	opt.Addr = "127.0.0.1:10033"
	engine := route.NewEngine(opt)
	engine.POST("/", func(ctx context.Context, c *app.RequestContext) {
		c.String(consts.StatusOK, part1+part2)
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	time.Sleep(1 * time.Second)

	client, err := NewClient(WithResponseBodyStream(true))
	if err != nil {
		t.Fatal(err)
	}
	request, response := protocol.AcquireRequest(), protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}()
	request.SetRequestURI("http://127.0.0.1:10033")
	request.SetMethod(consts.MethodPost)
	if err := client.Do(context.Background(), request, response); err != nil {
		t.Fatalf("client Do error=%v", err.Error())
	}
	bodyStream := response.BodyStream()
	if bodyStream == nil {
		t.Fatalf("bodystream is nil")
	}

	// Read part1 body bytes. io.ReadFull is used because a single Read may
	// legally return fewer bytes than requested.
	p := make([]byte, len(part1))
	n, err := io.ReadFull(bodyStream, p)
	if err != nil {
		t.Errorf("read from bodystream error=%v", err.Error())
	}
	if string(p) != part1 {
		t.Errorf("read len=%v, read content=%v; want len=%v, want content=%v", n, string(p), len(part1), part1)
	}
	left, err := ioutil.ReadAll(bodyStream)
	if err != nil {
		t.Errorf("read rest of bodystream error=%v", err)
	}
	if string(left) != part2 {
		t.Errorf("left len=%v, left content=%v; want len=%v, want content=%v", len(left), string(left), len(part2), part2)
	}
}

// TestWithBasicAuth checks that Request.SetBasicAuth produces an Authorization
// header the server can decode back to "myuser:basicauth". A request without
// the header must be rejected with 401.
func TestWithBasicAuth(t *testing.T) {
	opt := config.NewOptions([]config.Option{})
	opt.Addr = "127.0.0.1:10034"
	engine := route.NewEngine(opt)
	engine.GET("/", func(c context.Context, ctx *app.RequestContext) {
		const scheme = "Basic "
		auth := string(ctx.GetHeader(consts.HeaderAuthorization))
		// Check the scheme explicitly instead of blindly skipping 6 bytes.
		if !strings.HasPrefix(auth, scheme) {
			ctx.SetStatusCode(consts.StatusUnauthorized)
			return
		}
		credentials, err := base64.StdEncoding.DecodeString(auth[len(scheme):])
		if err != nil || string(credentials) != "myuser:basicauth" {
			ctx.SetStatusCode(consts.StatusUnauthorized)
			return
		}
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	time.Sleep(1 * time.Second)
	client, err := NewClient()
	if err != nil {
		t.Fatal(err)
	}
	request := protocol.AcquireRequest()
	response := protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}()

	// Success
	request.SetBasicAuth("myuser", "basicauth")
	request.SetRequestURI("http://127.0.0.1:10034")
	request.SetMethod(consts.MethodGet)
	if err := client.Do(context.Background(), request, response); err != nil {
		t.Fatal(err)
	}
	if response.StatusCode() == consts.StatusUnauthorized {
		t.Error("unexpected status code=401")
	}

	// Fail
	request.Reset()
	response.Reset()
	request.SetRequestURI("http://127.0.0.1:10034")
	request.SetMethod(consts.MethodGet)
	if err := client.Do(context.Background(), request, response); err != nil {
		t.Fatal(err)
	}
	if response.StatusCode() != consts.StatusUnauthorized {
		t.Errorf("unexpected status code: %v, expected 401", response.StatusCode())
	}
}

// TestClientProxyWithStandardDialer sends a HEAD request through an HTTP or
// HTTPS proxy (standard dialer) to an HTTP or HTTPS site, for all four
// combinations:
//   - HTTPS site: the proxy must first receive a CONNECT for the site's host,
//     and the HEAD for "/" must then reach the site through the tunnel.
//   - HTTP site: the proxy must receive the HEAD with the absolute site URL.
func TestClientProxyWithStandardDialer(t *testing.T) {
	testCases := []struct{ httpsSite, httpsProxy bool }{
		{false, false},
		{false, true},
		{true, false},
		{true, true},
	}
	for _, testCase := range testCases {
		httpsSite := testCase.httpsSite
		httpsProxy := testCase.httpsProxy
		t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) {
			siteCh := make(chan *http.Request, 1)
			siteHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				siteCh <- r
			})
			proxyCh := make(chan *http.Request, 1)
			proxyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				proxyCh <- r
				if r.Method != "CONNECT" {
					return
				}
				// Minimal CONNECT tunnel: take over the client connection,
				// dial the requested host, acknowledge with 200 OK and copy
				// bytes in both directions.
				hijacker, ok := w.(http.Hijacker)
				if !ok {
					t.Errorf("hijack not allowed")
					return
				}
				clientConn, _, err := hijacker.Hijack()
				if err != nil {
					t.Errorf("hijacking failed")
					return
				}
				targetConn, err := net.Dial("tcp", r.URL.Host)
				if err != nil {
					clientConn.Close()
					t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
					return
				}
				res := &http.Response{
					StatusCode: http.StatusOK,
					Proto:      "HTTP/1.1",
					ProtoMajor: 1,
					ProtoMinor: 1,
					Header:     make(http.Header),
				}
				if err := res.Write(clientConn); err != nil {
					clientConn.Close()
					targetConn.Close()
					t.Errorf("Writing 200 OK failed: %v", err)
					return
				}
				go io.Copy(targetConn, clientConn) //nolint:errcheck
				go func() {
					io.Copy(clientConn, targetConn) //nolint:errcheck
					targetConn.Close()
				}()
			})

			var site *httptest.Server
			if httpsSite {
				site = httptest.NewTLSServer(siteHandler)
			} else {
				site = httptest.NewServer(siteHandler)
			}
			defer site.Close()
			var proxyServer *httptest.Server
			if httpsProxy {
				proxyServer = httptest.NewTLSServer(proxyHandler)
			} else {
				proxyServer = httptest.NewServer(proxyHandler)
			}
			defer proxyServer.Close()

			// NOTE(review): dialer.SetDialer appears to replace a package-wide
			// default dialer, which may leak into tests that run afterwards.
			dialer.SetDialer(standard.NewDialer())

			// If only one server uses TLS, the client's TLS config must come
			// from that server so it trusts that server's test certificate.
			// If both do, either server's config works.
			var opts []config.ClientOption
			switch {
			case httpsProxy:
				opts = append(opts, WithTLSConfig(proxyServer.Client().Transport.(*http.Transport).TLSClientConfig))
			case httpsSite:
				opts = append(opts, WithTLSConfig(site.Client().Transport.(*http.Transport).TLSClientConfig))
			}
			c, err := NewClient(opts...)
			if err != nil {
				t.Fatal(err)
			}
			c.SetProxy(protocol.ProxyURI(protocol.ParseURI(proxyServer.URL)))

			request, response := protocol.AcquireRequest(), protocol.AcquireResponse()
			defer func() {
				protocol.ReleaseRequest(request)
				protocol.ReleaseResponse(response)
			}()
			request.SetRequestURI(site.URL)
			request.SetMethod(consts.MethodHead)
			if err := c.Do(context.Background(), request, response); err != nil {
				t.Error(err)
			}

			var got *http.Request
			select {
			case got = <-proxyCh:
			case <-time.After(5 * time.Second):
				t.Fatal("timeout connecting to http proxy")
			}

			if !httpsSite {
				if got.Method != "HEAD" {
					t.Errorf("Wrong method for destination: %q", got.Method)
				}
				gotURL := got.URL.String()
				wantURL := site.URL + "/"
				if gotURL != wantURL {
					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
				}
				return
			}

			// First message should be a CONNECT to ask for a socket to the real server.
			if got.Method != "CONNECT" {
				t.Errorf("Wrong method for secure proxying: %q", got.Method)
			}
			siteURL, err := url.Parse(site.URL)
			if err != nil {
				t.Fatal("Invalid site URL")
			}
			if gotHost, wantHost := got.URL.Host, siteURL.Host; gotHost != wantHost {
				t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
			}

			// The next message on the channel should be from the site's server.
			var next *http.Request
			select {
			case next = <-siteCh:
			case <-time.After(5 * time.Second):
				t.Fatal("timeout waiting for the request to reach the site")
			}
			if next.Method != "HEAD" {
				t.Errorf("Wrong method at destination: %s", next.Method)
			}
			if nextURL := next.URL.String(); nextURL != "/" {
				t.Errorf("Wrong URL at destination: %s", nextURL)
			}
		})
	}
}

// TestClientProxyWithNetpollDialer sends a HEAD request through a proxy using
// the client's default dialer. That dialer is presumably netpoll on supported
// platforms; confirm it is not overridden by an earlier dialer.SetDialer call.
// The proxy handler here only records requests and never tunnels CONNECT, so
// errors are tolerated whenever either leg uses TLS. The plain-HTTP
// combination must succeed, and the proxy must see the HEAD with the absolute
// site URL.
func TestClientProxyWithNetpollDialer(t *testing.T) {
	testCases := []struct{ httpsSite, httpsProxy bool }{
		{false, false},
		{true, false},
		{false, true},
		{true, true},
	}
	for _, testCase := range testCases {
		httpsSite := testCase.httpsSite
		httpsProxy := testCase.httpsProxy
		t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) {
			siteCh := make(chan *http.Request, 1)
			siteHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				siteCh <- r
			})
			proxyCh := make(chan *http.Request, 1)
			proxyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				proxyCh <- r
			})

			var site *httptest.Server
			if httpsSite {
				site = httptest.NewTLSServer(siteHandler)
			} else {
				site = httptest.NewServer(siteHandler)
			}
			// Deferred so the servers are also shut down on the early-return
			// error path below.
			defer site.Close()
			var proxyServer *httptest.Server
			if httpsProxy {
				proxyServer = httptest.NewTLSServer(proxyHandler)
			} else {
				proxyServer = httptest.NewServer(proxyHandler)
			}
			defer proxyServer.Close()

			c, err := NewClient()
			if err != nil {
				t.Fatal(err)
			}
			c.SetProxy(protocol.ProxyURI(protocol.ParseURI(proxyServer.URL)))
			request, response := protocol.AcquireRequest(), protocol.AcquireResponse()
			defer func() {
				protocol.ReleaseRequest(request)
				protocol.ReleaseResponse(response)
			}()
			request.SetRequestURI(site.URL)
			request.SetMethod(consts.MethodHead)
			if err := c.Do(context.Background(), request, response); err != nil {
				t.Log(err)
				if !httpsSite && !httpsProxy {
					t.Fatal(err)
				}
				return
			}

			var got *http.Request
			select {
			case got = <-proxyCh:
			case <-time.After(5 * time.Second):
				t.Fatal("timeout connecting to http proxy")
			}

			if got.Method != "HEAD" {
				t.Errorf("Wrong method for destination: %q", got.Method)
			}
			gotURL := got.URL.String()
			wantURL := site.URL + "/"
			if gotURL != wantURL {
				t.Errorf("Got URL %q, want %q", gotURL, wantURL)
			}
		})
	}
}

// TestClientMiddleware checks that middlewares registered with Use run in
// registration order, each seeing the request URI set by the previous one.
// The innermost middleware returns without calling next, so no network
// request is made; the test also asserts that it was actually reached.
func TestClientMiddleware(t *testing.T) {
	client, err := NewClient()
	if err != nil {
		t.Fatal(err)
	}
	innermostCalled := false
	mw0 := func(next Endpoint) Endpoint {
		return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) {
			req.SetRequestURI("middleware0")
			return next(ctx, req, resp)
		}
	}
	mw1 := func(next Endpoint) Endpoint {
		return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) {
			if string(req.RequestURI()) != "middleware0" {
				t.Errorf("Wrong request URI: %s, expected %v", req.RequestURI(), "middleware0")
			}
			req.SetRequestURI("middleware1")
			return next(ctx, req, resp)
		}
	}
	mw2 := func(next Endpoint) Endpoint {
		return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) {
			innermostCalled = true
			if string(req.RequestURI()) != "middleware1" {
				t.Errorf("Wrong request URI: %s, expected %v", req.RequestURI(), "middleware1")
			}
			return nil
		}
	}
	client.Use(mw0)
	client.Use(mw1)
	client.Use(mw2)

	request, response := protocol.AcquireRequest(), protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}()
	if err := client.Do(context.Background(), request, response); err != nil {
		t.Errorf("unexpected error: %s", err.Error())
	}
	if !innermostCalled {
		t.Errorf("innermost middleware was not invoked")
	}
}

// TestClientLastMiddleware checks UseAsLast. Only one last middleware may be
// registered; a second call returns errorLastMiddlewareExist. The last
// middleware runs after all Use middlewares and sees their context values,
// with mw4 overriding mw1's "final0". TakeOutLastMiddleware returns it exactly
// once and nil afterwards.
func TestClientLastMiddleware(t *testing.T) {
	client, err := NewClient()
	if err != nil {
		t.Fatal(err)
	}
	lastCalled := false
	mw0 := func(next Endpoint) Endpoint {
		return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) {
			lastCalled = true
			finalValue0 := ctx.Value("final0")
			assert.DeepEqual(t, "final3", finalValue0)
			finalValue1 := ctx.Value("final1")
			assert.DeepEqual(t, "final1", finalValue1)
			finalValue2 := ctx.Value("final2")
			assert.DeepEqual(t, "final2", finalValue2)
			return nil
		}
	}
	mw1 := func(next Endpoint) Endpoint {
		return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) {
			ctx = context.WithValue(ctx, "final0", "final0")
			return next(ctx, req, resp)
		}
	}
	mw2 := func(next Endpoint) Endpoint {
		return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) {
			ctx = context.WithValue(ctx, "final1", "final1")
			return next(ctx, req, resp)
		}
	}
	mw3 := func(next Endpoint) Endpoint {
		return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) {
			ctx = context.WithValue(ctx, "final2", "final2")
			return next(ctx, req, resp)
		}
	}
	mw4 := func(next Endpoint) Endpoint {
		return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) {
			ctx = context.WithValue(ctx, "final0", "final3")
			return next(ctx, req, resp)
		}
	}
	err = client.UseAsLast(mw0)
	assert.Nil(t, err)
	err = client.UseAsLast(func(endpoint Endpoint) Endpoint {
		return nil
	})
	assert.DeepEqual(t, errorLastMiddlewareExist, err)
	client.Use(mw1)
	client.Use(mw2)
	client.Use(mw3)
	client.Use(mw4)

	request, response := protocol.AcquireRequest(), protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}()
	if err := client.Do(context.Background(), request, response); err != nil {
		t.Errorf("unexpected error: %s", err.Error())
	}
	if !lastCalled {
		t.Errorf("last middleware was not invoked")
	}

	last := client.TakeOutLastMiddleware()
	assert.DeepEqual(t, reflect.ValueOf(last).Pointer(), reflect.ValueOf(mw0).Pointer())
	last = client.TakeOutLastMiddleware()
	assert.Nil(t, last)
}

// TestClientReadResponseBodyStreamWithDoubleRequest reads only part of a
// streamed response body and then issues a second request on the same
// client. The second body must arrive intact, and the remainder of the first
// stream must still read back as part2.
func TestClientReadResponseBodyStreamWithDoubleRequest(t *testing.T) {
	part1 := strings.Repeat("a", 8192)
	part2 := "ghij"

	opt := config.NewOptions([]config.Option{})
	opt.Addr = "127.0.0.1:10035"
	engine := route.NewEngine(opt)
	engine.POST("/", func(ctx context.Context, c *app.RequestContext) {
		c.String(consts.StatusOK, part1+part2)
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	time.Sleep(1 * time.Second)

	client, err := NewClient(WithResponseBodyStream(true))
	if err != nil {
		t.Fatal(err)
	}
	request, response := protocol.AcquireRequest(), protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}()
	request.SetRequestURI("http://127.0.0.1:10035")
	request.SetMethod(consts.MethodPost)
	if err := client.Do(context.Background(), request, response); err != nil {
		t.Fatalf("client Do error=%v", err.Error())
	}
	bodyStream := response.BodyStream()
	if bodyStream == nil {
		t.Fatalf("bodystream is nil")
	}

	// Read part1 body bytes. io.ReadFull is used because a single Read may
	// legally return fewer bytes than requested.
	p := make([]byte, len(part1))
	n, err := io.ReadFull(bodyStream, p)
	if err != nil {
		t.Errorf("read from bodystream error=%v", err.Error())
	}
	if string(p) != part1 {
		t.Errorf("read len=%v, read content=%v; want len=%v, want content=%v", n, string(p), len(part1), part1)
	}

	// send another request and read all bodystream
	request1, response1 := protocol.AcquireRequest(), protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request1)
		protocol.ReleaseResponse(response1)
	}()
	request1.SetRequestURI("http://127.0.0.1:10035")
	request1.SetMethod(consts.MethodPost)
	if err := client.Do(context.Background(), request1, response1); err != nil {
		t.Fatalf("client Do error=%v", err.Error())
	}
	bodyStream1 := response1.BodyStream()
	if bodyStream1 == nil {
		t.Fatalf("bodystream1 is nil")
	}
	data, _ := ioutil.ReadAll(bodyStream1)
	if string(data) != part1+part2 {
		t.Errorf("read len=%v, read content=%v; want len=%v, want content=%v", len(data), string(data), len(part1+part2), part1+part2)
	}

	// read left bodystream
	left, _ := ioutil.ReadAll(bodyStream)
	if string(left) != part2 {
		t.Errorf("left len=%v, left content=%v; want len=%v, want content=%v", len(left), string(left), len(part2), part2)
	}
}

// TestClientReadResponseBodyStreamWithConnectionClose sends two consecutive
// "Connection: close" requests with body streaming enabled. Each response
// body must equal the full 8 KiB payload.
func TestClientReadResponseBodyStreamWithConnectionClose(t *testing.T) {
	part1 := strings.Repeat("a", 8192)

	opt := config.NewOptions([]config.Option{})
	opt.Addr = "127.0.0.1:10036"
	engine := route.NewEngine(opt)
	engine.POST("/", func(ctx context.Context, c *app.RequestContext) {
		c.String(consts.StatusOK, part1)
	})
	go engine.Run()
	defer func() {
		engine.Close()
	}()
	time.Sleep(1 * time.Second)

	client, err := NewClient(WithResponseBodyStream(true))
	if err != nil {
		t.Fatal(err)
	}

	// first req
	request, response := protocol.AcquireRequest(), protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request)
		protocol.ReleaseResponse(response)
	}()
	request.SetConnectionClose()
	request.SetMethod(consts.MethodPost)
	request.SetRequestURI("http://127.0.0.1:10036")

	if err := client.Do(context.Background(), request, response); err != nil {
		t.Fatalf("client Do error=%v", err.Error())
	}
	assert.DeepEqual(t, part1, string(response.Body()))

	// second req
	request1, response1 := protocol.AcquireRequest(), protocol.AcquireResponse()
	defer func() {
		protocol.ReleaseRequest(request1)
		protocol.ReleaseResponse(response1)
	}()
	request1.SetConnectionClose()
	request1.SetMethod(consts.MethodPost)
	request1.SetRequestURI("http://127.0.0.1:10036")

	if err := client.Do(context.Background(), request1, response1); err != nil {
		t.Fatalf("client Do error=%v", err.Error())
	}
	assert.DeepEqual(t, part1, string(response1.Body()))
}

// mockDialer wraps a real network.Dialer and sends every dial to the fixed
// network/address/timeout stored in the struct. When customDialerFunc is set,
// it is first called with the arguments the caller originally requested.
type mockDialer struct {
	network.Dialer
	customDialerFunc func(network, address string, timeout time.Duration, tlsConfig *tls.Config)
	network          string
	address          string
	timeout          time.Duration
}

// DialConnection calls customDialerFunc (when non-nil) with the requested
// network, address, timeout and TLS config. It then dials m.network/m.address
// with m.timeout through the embedded Dialer, ignoring the requested target
// and timeout but forwarding tlsConfig unchanged.
func (m *mockDialer) DialConnection(requestedNetwork, requestedAddr string, requestedTimeout time.Duration, tlsConfig *tls.Config) (network.Conn, error) {
	if hook := m.customDialerFunc; hook != nil {
		hook(requestedNetwork, requestedAddr, requestedTimeout, tlsConfig)
	}
	return m.Dialer.DialConnection(m.network, m.address, m.timeout, tlsConfig)
}

2152 func TestClientRetry(t *testing.T) { 2153 t.Parallel() 2154 client, err := NewClient( 2155 // Default dial function performs different in different os. So unit the performance of dial function.
2156 WithDialFunc(func(addr string) (network.Conn, error) { 2157 return nil, fmt.Errorf("dial tcp %s: i/o timeout", addr) 2158 }), 2159 WithRetryConfig( 2160 retry.WithMaxAttemptTimes(3), 2161 retry.WithInitDelay(100*time.Millisecond), 2162 retry.WithMaxDelay(10*time.Second), 2163 retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy)), 2164 ), 2165 ) 2166 client.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { 2167 return err != nil 2168 }) 2169 if err != nil { 2170 t.Fatal(err) 2171 return 2172 } 2173 startTime := time.Now().UnixNano() 2174 _, resp, err := client.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") 2175 if err != nil { 2176 // first delay 100+200ms , second delay 100+400ms 2177 if time.Duration(time.Now().UnixNano()-startTime) > 800*time.Millisecond && time.Duration(time.Now().UnixNano()-startTime) < 2*time.Second { 2178 t.Logf("Retry triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) 2179 } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry 2180 t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) 2181 } else { 2182 t.Fatal(err) 2183 } 2184 } 2185 2186 client2, err := NewClient( 2187 WithDialFunc(func(addr string) (network.Conn, error) { 2188 return nil, fmt.Errorf("dial tcp %s: i/o timeout", addr) 2189 }), 2190 WithRetryConfig( 2191 retry.WithMaxAttemptTimes(2), 2192 retry.WithInitDelay(500*time.Millisecond), 2193 retry.WithMaxJitter(1*time.Second), 2194 retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy)), 2195 ), 2196 ) 2197 if err != nil { 2198 t.Fatal(err) 2199 return 2200 } 2201 client2.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err 
error) bool { 2202 return err != nil 2203 }) 2204 startTime = time.Now().UnixNano() 2205 _, resp, err = client2.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") 2206 if err != nil { 2207 // delay max{500ms+rand([0,1))s,100ms}. Because if the MaxDelay is not set, we will use the default MaxDelay of 100ms 2208 if time.Duration(time.Now().UnixNano()-startTime) > 100*time.Millisecond && time.Duration(time.Now().UnixNano()-startTime) < 1100*time.Millisecond { 2209 t.Logf("Retry triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) 2210 } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry 2211 t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) 2212 } else { 2213 t.Fatal(err) 2214 } 2215 } 2216 2217 client3, err := NewClient( 2218 WithDialFunc(func(addr string) (network.Conn, error) { 2219 return nil, fmt.Errorf("dial tcp %s: i/o timeout", addr) 2220 }), 2221 WithRetryConfig( 2222 retry.WithMaxAttemptTimes(2), 2223 retry.WithInitDelay(100*time.Millisecond), 2224 retry.WithMaxDelay(5*time.Second), 2225 retry.WithMaxJitter(1*time.Second), 2226 retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy, retry.RandomDelayPolicy)), 2227 ), 2228 ) 2229 if err != nil { 2230 t.Fatal(err) 2231 return 2232 } 2233 client3.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { 2234 return err != nil 2235 }) 2236 startTime = time.Now().UnixNano() 2237 _, resp, err = client3.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") 2238 if err != nil { 2239 // delay 100ms+200ms+rand([0,1))s 2240 if time.Duration(time.Now().UnixNano()-startTime) > 300*time.Millisecond && time.Duration(time.Now().UnixNano()-startTime) < 2300*time.Millisecond { 2241 
t.Logf("Retry triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) 2242 } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry 2243 t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) 2244 } else { 2245 t.Fatal(err) 2246 } 2247 } 2248 2249 client4, err := NewClient( 2250 WithDialFunc(func(addr string) (network.Conn, error) { 2251 return nil, fmt.Errorf("dial tcp %s: i/o timeout", addr) 2252 }), 2253 WithRetryConfig( 2254 retry.WithMaxAttemptTimes(2), 2255 retry.WithInitDelay(1*time.Second), 2256 retry.WithMaxDelay(10*time.Second), 2257 retry.WithMaxJitter(5*time.Second), 2258 retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy, retry.RandomDelayPolicy)), 2259 ), 2260 ) 2261 if err != nil { 2262 t.Fatal(err) 2263 return 2264 } 2265 /* If the retryIfFunc is not set , idempotent logic is used by default */ 2266 //client4.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { 2267 // return err != nil 2268 //}) 2269 startTime = time.Now().UnixNano() 2270 _, resp, err = client4.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") 2271 if err != nil { 2272 if time.Duration(time.Now().UnixNano()-startTime) > 1*time.Second && time.Duration(time.Now().UnixNano()-startTime) < 9*time.Second { 2273 t.Logf("Retry triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) 2274 } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry 2275 t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) 2276 
} else { 2277 t.Fatal(err) 2278 } 2279 return 2280 } 2281 } 2282 2283 func TestClientHostClientConfigHookError(t *testing.T) { 2284 client, _ := NewClient(WithHostClientConfigHook(func(hc interface{}) error { 2285 hct, ok := hc.(*http1.HostClient) 2286 assert.True(t, ok) 2287 assert.DeepEqual(t, "foo.bar:80", hct.Addr) 2288 return errors.New("hook return") 2289 })) 2290 2291 req := protocol.AcquireRequest() 2292 req.SetMethod(consts.MethodGet) 2293 req.SetRequestURI("http://foo.bar/") 2294 resp := protocol.AcquireResponse() 2295 err := client.do(context.TODO(), req, resp) 2296 assert.DeepEqual(t, "hook return", err.Error()) 2297 } 2298 2299 func TestClientHostClientConfigHook(t *testing.T) { 2300 client, _ := NewClient(WithHostClientConfigHook(func(hc interface{}) error { 2301 hct, ok := hc.(*http1.HostClient) 2302 assert.True(t, ok) 2303 assert.DeepEqual(t, "foo.bar:80", hct.Addr) 2304 hct.Addr = "FOO.BAR:443" 2305 return nil 2306 })) 2307 2308 req := protocol.AcquireRequest() 2309 req.SetMethod(consts.MethodGet) 2310 req.SetRequestURI("http://foo.bar/") 2311 resp := protocol.AcquireResponse() 2312 client.do(context.Background(), req, resp) 2313 client.mLock.Lock() 2314 hc := client.m["foo.bar"] 2315 client.mLock.Unlock() 2316 hcr, ok := hc.(*http1.HostClient) 2317 assert.True(t, ok) 2318 assert.DeepEqual(t, "FOO.BAR:443", hcr.Addr) 2319 } 2320 2321 func TestClientDialerName(t *testing.T) { 2322 client, _ := NewClient() 2323 dName, err := client.GetDialerName() 2324 if err != nil { 2325 t.Fatalf("unexpected error: %v", err) 2326 } 2327 // Depending on the operating system, 2328 // the default dialer has a different network library, either "netpoll" or "standard" 2329 if !(dName == "netpoll" || dName == "standard") { 2330 t.Errorf("expected 'netpoll', but get %s", dName) 2331 } 2332 2333 client, _ = NewClient(WithDialer(&mockDialer{})) 2334 dName, err = client.GetDialerName() 2335 if err != nil { 2336 t.Fatalf("unexpected error: %v", err) 2337 } 2338 if dName != 
"client" { 2339 t.Errorf("expected 'standard', but get %s", dName) 2340 } 2341 2342 client, _ = NewClient(WithDialer(standard.NewDialer())) 2343 dName, err = client.GetDialerName() 2344 if err != nil { 2345 t.Fatalf("unexpected error: %v", err) 2346 } 2347 if dName != "standard" { 2348 t.Errorf("expected 'standard', but get %s", dName) 2349 } 2350 2351 client, _ = NewClient(WithDialer(&mockDialer{})) 2352 dName, err = client.GetDialerName() 2353 if err != nil { 2354 t.Fatalf("unexpected error: %v", err) 2355 } 2356 if dName != "client" { 2357 t.Errorf("expected 'client', but get %s", dName) 2358 } 2359 2360 client.options.Dialer = nil 2361 dName, err = client.GetDialerName() 2362 if err == nil { 2363 t.Errorf("expected an err for abnormal process") 2364 } 2365 if dName != "" { 2366 t.Errorf("expected 'empty string', but get %s", dName) 2367 } 2368 } 2369 2370 func TestClientDoWithDialFunc(t *testing.T) { 2371 t.Parallel() 2372 2373 ch := make(chan error, 1) 2374 uri := "/foo/bar/baz" 2375 body := "request body" 2376 opt := config.NewOptions([]config.Option{}) 2377 2378 opt.Addr = "unix-test-10021" 2379 opt.Network = "unix" 2380 engine := route.NewEngine(opt) 2381 2382 engine.POST("/foo/bar/baz", func(c context.Context, ctx *app.RequestContext) { 2383 if string(ctx.Request.Header.Method()) != consts.MethodPost { 2384 ch <- fmt.Errorf("unexpected request method: %q. Expecting %q", ctx.Request.Header.Method(), consts.MethodPost) 2385 return 2386 } 2387 reqURI := ctx.Request.RequestURI() 2388 if string(reqURI) != uri { 2389 ch <- fmt.Errorf("unexpected request uri: %q. Expecting %q", reqURI, uri) 2390 return 2391 } 2392 cl := ctx.Request.Header.ContentLength() 2393 if cl != len(body) { 2394 ch <- fmt.Errorf("unexpected content-length %d. Expecting %d", cl, len(body)) 2395 return 2396 } 2397 reqBody := ctx.Request.Body() 2398 if string(reqBody) != body { 2399 ch <- fmt.Errorf("unexpected request body: %q. 
Expecting %q", reqBody, body) 2400 return 2401 } 2402 ch <- nil 2403 }) 2404 go engine.Run() 2405 defer func() { 2406 engine.Close() 2407 }() 2408 time.Sleep(1 * time.Second) 2409 2410 c, _ := NewClient(WithDialFunc(func(addr string) (network.Conn, error) { 2411 return dialer.DialConnection(opt.Network, opt.Addr, time.Second, nil) 2412 })) 2413 2414 var req protocol.Request 2415 req.Header.SetMethod(consts.MethodPost) 2416 req.SetRequestURI(uri) 2417 req.SetHost("xxx.com") 2418 req.SetBodyString(body) 2419 2420 var resp protocol.Response 2421 2422 err := c.Do(context.Background(), &req, &resp) 2423 if err != nil { 2424 t.Fatalf("error when doing request: %s", err) 2425 } 2426 2427 select { 2428 case err = <-ch: 2429 if err != nil { 2430 t.Fatalf("err = %s", err.Error()) 2431 } 2432 case <-time.After(5 * time.Second): 2433 t.Fatalf("timeout") 2434 } 2435 } 2436 2437 func TestClientState(t *testing.T) { 2438 opt := config.NewOptions([]config.Option{}) 2439 opt.Addr = "127.0.0.1:10037" 2440 engine := route.NewEngine(opt) 2441 go engine.Run() 2442 defer func() { 2443 engine.Close() 2444 }() 2445 2446 time.Sleep(1 * time.Second) 2447 2448 state := int32(0) 2449 client, _ := NewClient( 2450 WithConnStateObserve(func(hcs config.HostClientState) { 2451 switch atomic.LoadInt32(&state) { 2452 case int32(0): 2453 assert.DeepEqual(t, 1, hcs.ConnPoolState().TotalConnNum) 2454 assert.DeepEqual(t, 1, hcs.ConnPoolState().PoolConnNum) 2455 assert.DeepEqual(t, "127.0.0.1:10037", hcs.ConnPoolState().Addr) 2456 atomic.StoreInt32(&state, int32(1)) 2457 case int32(1): 2458 assert.DeepEqual(t, 0, hcs.ConnPoolState().TotalConnNum) 2459 assert.DeepEqual(t, 0, hcs.ConnPoolState().PoolConnNum) 2460 assert.DeepEqual(t, "127.0.0.1:10037", hcs.ConnPoolState().Addr) 2461 atomic.StoreInt32(&state, int32(2)) 2462 return 2463 case int32(2): 2464 t.Fatal("It shouldn't go to here") 2465 } 2466 }, time.Second*9)) 2467 2468 client.Get(context.Background(), nil, "http://127.0.0.1:10037") 2469 
time.Sleep(time.Second * 22) 2470 } 2471 2472 func TestClientRetryErr(t *testing.T) { 2473 t.Run("200", func(t *testing.T) { 2474 opt := config.NewOptions([]config.Option{}) 2475 opt.Addr = "127.0.0.1:10136" 2476 engine := route.NewEngine(opt) 2477 var l sync.Mutex 2478 retryNum := 0 2479 engine.GET("/ping", func(c context.Context, ctx *app.RequestContext) { 2480 l.Lock() 2481 defer l.Unlock() 2482 retryNum += 1 2483 ctx.SetStatusCode(200) 2484 }) 2485 go engine.Run() 2486 defer func() { 2487 engine.Close() 2488 }() 2489 time.Sleep(1 * time.Second) 2490 c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3))) 2491 _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:10136/ping") 2492 assert.Nil(t, err) 2493 l.Lock() 2494 assert.DeepEqual(t, 1, retryNum) 2495 l.Unlock() 2496 }) 2497 2498 t.Run("502", func(t *testing.T) { 2499 opt := config.NewOptions([]config.Option{}) 2500 opt.Addr = "127.0.0.1:10137" 2501 engine := route.NewEngine(opt) 2502 var l sync.Mutex 2503 retryNum := 0 2504 engine.GET("/ping", func(c context.Context, ctx *app.RequestContext) { 2505 l.Lock() 2506 defer l.Unlock() 2507 retryNum += 1 2508 ctx.SetStatusCode(502) 2509 }) 2510 go engine.Run() 2511 defer func() { 2512 engine.Close() 2513 }() 2514 time.Sleep(1 * time.Second) 2515 c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3))) 2516 c.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { 2517 return resp.StatusCode() == 502 2518 }) 2519 _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:10137/ping") 2520 assert.Nil(t, err) 2521 l.Lock() 2522 assert.DeepEqual(t, 3, retryNum) 2523 l.Unlock() 2524 }) 2525 }