github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/client_test.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   * The MIT License (MIT)
    17   *
    18   * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
    19   *
    20   * Permission is hereby granted, free of charge, to any person obtaining a copy
    21   * of this software and associated documentation files (the "Software"), to deal
    22   * in the Software without restriction, including without limitation the rights
    23   * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    24   * copies of the Software, and to permit persons to whom the Software is
    25   * furnished to do so, subject to the following conditions:
    26   *
    27   * The above copyright notice and this permission notice shall be included in
    28   * all copies or substantial portions of the Software.
    29   *
    30   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    31   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    32   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    33   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    34   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    35   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    36   * THE SOFTWARE.
    37   *
    38   * This file may have been modified by CloudWeGo authors. All CloudWeGo
    39   * Modifications are Copyright 2022 CloudWeGo Authors.
    40   */
    41  
    42  package http1
    43  
    44  import (
    45  	"bytes"
    46  	"context"
    47  	"crypto/tls"
    48  	"errors"
    49  	"fmt"
    50  	"io/ioutil"
    51  	"net"
    52  	"strings"
    53  	"sync"
    54  	"sync/atomic"
    55  	"testing"
    56  	"time"
    57  
    58  	"github.com/cloudwego/hertz/pkg/app/client/retry"
    59  	"github.com/cloudwego/hertz/pkg/common/config"
    60  	errs "github.com/cloudwego/hertz/pkg/common/errors"
    61  	"github.com/cloudwego/hertz/pkg/common/hlog"
    62  	"github.com/cloudwego/hertz/pkg/common/test/assert"
    63  	"github.com/cloudwego/hertz/pkg/common/test/mock"
    64  	"github.com/cloudwego/hertz/pkg/common/utils"
    65  	"github.com/cloudwego/hertz/pkg/network"
    66  	"github.com/cloudwego/hertz/pkg/protocol"
    67  	"github.com/cloudwego/hertz/pkg/protocol/client"
    68  	"github.com/cloudwego/hertz/pkg/protocol/consts"
    69  	"github.com/cloudwego/hertz/pkg/protocol/http1/resp"
    70  	"github.com/cloudwego/netpoll"
    71  )
    72  
    73  var errDialTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "dial timeout")
    74  
    75  func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
    76  	var (
    77  		emptyBodyCount uint8
    78  		wg             sync.WaitGroup
    79  		// make deadline reach earlier than conns wait timeout
    80  		timeout = 10 * time.Millisecond
    81  	)
    82  
    83  	c := &HostClient{
    84  		ClientOptions: &ClientOptions{
    85  			Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
    86  				return mock.SlowReadDialer(addr)
    87  			}),
    88  			MaxConns:           1,
    89  			MaxConnWaitTimeout: 50 * time.Millisecond,
    90  		},
    91  		Addr: "foobar",
    92  	}
    93  
    94  	var errTimeoutCount uint32
    95  	for i := 0; i < 5; i++ {
    96  		wg.Add(1)
    97  		go func() {
    98  			defer wg.Done()
    99  
   100  			req := protocol.AcquireRequest()
   101  			req.SetRequestURI("http://foobar/baz")
   102  			req.Header.SetMethod(consts.MethodPost)
   103  			req.SetBodyString("bar")
   104  			resp := protocol.AcquireResponse()
   105  
   106  			if err := c.DoDeadline(context.Background(), req, resp, time.Now().Add(timeout)); err != nil {
   107  				if !errors.Is(err, errs.ErrTimeout) {
   108  					t.Errorf("unexpected error: %s. Expecting %s", err, errs.ErrTimeout)
   109  				}
   110  				atomic.AddUint32(&errTimeoutCount, 1)
   111  			} else {
   112  				if resp.StatusCode() != consts.StatusOK {
   113  					t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), consts.StatusOK)
   114  				}
   115  
   116  				body := resp.Body()
   117  				if string(body) != "foo" {
   118  					t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
   119  				}
   120  			}
   121  		}()
   122  	}
   123  	wg.Wait()
   124  
   125  	c.connsLock.Lock()
   126  	for {
   127  		w := c.connsWait.popFront()
   128  		if w == nil {
   129  			break
   130  		}
   131  		w.mu.Lock()
   132  		if w.err != nil && !errors.Is(w.err, errs.ErrNoFreeConns) {
   133  			t.Errorf("unexpected error: %s. Expecting %s", w.err, errs.ErrNoFreeConns)
   134  		}
   135  		w.mu.Unlock()
   136  	}
   137  	c.connsLock.Unlock()
   138  	if errTimeoutCount == 0 {
   139  		t.Errorf("unexpected errTimeoutCount: %d. Expecting > 0", errTimeoutCount)
   140  	}
   141  
   142  	if emptyBodyCount > 0 {
   143  		t.Fatalf("at least one request body was empty")
   144  	}
   145  }
   146  
   147  func TestResponseReadBodyStream(t *testing.T) {
   148  	// small body
   149  	genBody := "abcdef4343"
   150  	s := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 5\r\n\r\n"
   151  	testContinueReadResponseBodyStream(t, s, genBody, 10, 5, 0, 5)
   152  	testContinueReadResponseBodyStream(t, s, genBody, 1, 5, 0, 0)
   153  
   154  	// big body (> 8193)
   155  	s1 := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 9216\r\nContent-Type: foo/bar\r\n\r\n"
   156  	genBody = strings.Repeat("1", 9*1024)
   157  	testContinueReadResponseBodyStream(t, s1, genBody, 10*1024, 5*1024, 4*1024, 0)
   158  	testContinueReadResponseBodyStream(t, s1, genBody, 10*1024, 1*1024, 8*1024, 0)
   159  	testContinueReadResponseBodyStream(t, s1, genBody, 10*1024, 9*1024, 0*1024, 0)
   160  
   161  	// normal stream
   162  	testContinueReadResponseBodyStream(t, s1, genBody, 1*1024, 5*1024, 4*1024, 0)
   163  	testContinueReadResponseBodyStream(t, s1, genBody, 1*1024, 1*1024, 8*1024, 0)
   164  	testContinueReadResponseBodyStream(t, s1, genBody, 1*1024, 9*1024, 0*1024, 0)
   165  	testContinueReadResponseBodyStream(t, s1, genBody, 5, 5*1024, 4*1024, 0)
   166  	testContinueReadResponseBodyStream(t, s1, genBody, 5, 1*1024, 8*1024, 0)
   167  	testContinueReadResponseBodyStream(t, s1, genBody, 5, 9*1024, 0, 0)
   168  
   169  	// critical point
   170  	testContinueReadResponseBodyStream(t, s1, genBody, 8*1024+1, 5*1024, 4*1024, 0)
   171  	testContinueReadResponseBodyStream(t, s1, genBody, 8*1024+1, 1*1024, 8*1024, 0)
   172  	testContinueReadResponseBodyStream(t, s1, genBody, 8*1024+1, 9*1024, 0*1024, 0)
   173  
   174  	// chunked body
   175  	s2 := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\ntrail"
   176  	testContinueReadResponseBodyStream(t, s2, "", 10*1024, 3, 5, 5)
   177  	s3 := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\n"
   178  	testContinueReadResponseBodyStream(t, s3, "", 10*1024, 3, 5, 0)
   179  }
   180  
   181  func testContinueReadResponseBodyStream(t *testing.T, header, body string, maxBodySize, firstRead, leftBytes, bytesLeftInReader int) {
   182  	mr := netpoll.NewReader(bytes.NewBufferString(header + body))
   183  	var r protocol.Response
   184  	if err := resp.ReadBodyStream(&r, mr, maxBodySize, nil); err != nil {
   185  		t.Fatalf("error when reading request body stream: %s", err)
   186  	}
   187  	fRead := firstRead
   188  	streamRead := make([]byte, fRead)
   189  	sR, _ := r.BodyStream().Read(streamRead)
   190  
   191  	if sR != firstRead {
   192  		t.Fatalf("should read %d from stream body, but got %d", firstRead, sR)
   193  	}
   194  
   195  	leftB, _ := ioutil.ReadAll(r.BodyStream())
   196  	if len(leftB) != leftBytes {
   197  		t.Fatalf("should left %d bytes from stream body, but left %d", leftBytes, len(leftB))
   198  	}
   199  	if r.Header.ContentLength() > 0 {
   200  		gotBody := append(streamRead, leftB...)
   201  		if !bytes.Equal([]byte(body[:r.Header.ContentLength()]), gotBody) {
   202  			t.Fatalf("body read from stream is not equal to the origin. Got: %s", gotBody)
   203  		}
   204  	}
   205  
   206  	left, _ := mr.Next(mr.Len())
   207  
   208  	if len(left) != bytesLeftInReader {
   209  		fmt.Printf("##########header:%s,body:%s,%d:max,first:%d,left:%d,leftin:%d\n", header, body, maxBodySize, firstRead, leftBytes, bytesLeftInReader)
   210  		fmt.Printf("##########left: %s\n", left)
   211  		t.Fatalf("should left %d bytes in original reader. got %q", bytesLeftInReader, len(left))
   212  	}
   213  }
   214  
   215  func newSlowConnDialer(dialer func(network, addr string, timeout time.Duration) (network.Conn, error)) network.Dialer {
   216  	return &mockDialer{customDialConn: dialer}
   217  }
   218  
   219  type mockDialer struct {
   220  	customDialConn func(network, addr string, timeout time.Duration) (network.Conn, error)
   221  }
   222  
   223  func (m *mockDialer) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) {
   224  	return m.customDialConn(network, address, timeout)
   225  }
   226  
   227  func (m *mockDialer) DialTimeout(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn net.Conn, err error) {
   228  	return nil, nil
   229  }
   230  
   231  func (m *mockDialer) AddTLS(conn network.Conn, tlsConfig *tls.Config) (network.Conn, error) {
   232  	return nil, nil
   233  }
   234  
   235  type slowDialer struct {
   236  	*mockDialer
   237  }
   238  
   239  func (s *slowDialer) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) {
   240  	time.Sleep(timeout)
   241  	return nil, errDialTimeout
   242  }
   243  
   244  func TestReadTimeoutPriority(t *testing.T) {
   245  	c := &HostClient{
   246  		ClientOptions: &ClientOptions{
   247  			Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
   248  				return mock.SlowReadDialer(addr)
   249  			}),
   250  			MaxConns:           1,
   251  			MaxConnWaitTimeout: 50 * time.Millisecond,
   252  			ReadTimeout:        time.Second * 3,
   253  		},
   254  		Addr: "foobar",
   255  	}
   256  
   257  	req := protocol.AcquireRequest()
   258  	req.SetRequestURI("http://foobar/baz")
   259  	req.SetOptions(config.WithReadTimeout(time.Second * 1))
   260  	resp := protocol.AcquireResponse()
   261  
   262  	ch := make(chan error, 1)
   263  	go func() {
   264  		ch <- c.Do(context.Background(), req, resp)
   265  	}()
   266  	select {
   267  	case <-time.After(time.Second * 2):
   268  		t.Fatalf("should use readTimeout in request options")
   269  	case err := <-ch:
   270  		assert.DeepEqual(t, mock.ErrReadTimeout, err)
   271  	}
   272  }
   273  
   274  func TestDoNonNilReqResp(t *testing.T) {
   275  	c := &HostClient{
   276  		ClientOptions: &ClientOptions{
   277  			Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
   278  				return &writeErrConn{
   279  						Conn: mock.NewConn("HTTP/1.1 400 OK\nContent-Length: 6\n\n123456"),
   280  					},
   281  					nil
   282  			}),
   283  		},
   284  	}
   285  	req := protocol.AcquireRequest()
   286  	resp := protocol.AcquireResponse()
   287  	req.SetHost("foobar")
   288  	retry, err := c.doNonNilReqResp(req, resp)
   289  	assert.False(t, retry)
   290  	assert.Nil(t, err)
   291  	assert.DeepEqual(t, resp.StatusCode(), 400)
   292  	assert.DeepEqual(t, resp.Body(), []byte("123456"))
   293  }
   294  
   295  func TestDoNonNilReqResp1(t *testing.T) {
   296  	c := &HostClient{
   297  		ClientOptions: &ClientOptions{
   298  			Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
   299  				return &writeErrConn{
   300  						Conn: mock.NewConn(""),
   301  					},
   302  					nil
   303  			}),
   304  		},
   305  	}
   306  	req := protocol.AcquireRequest()
   307  	resp := protocol.AcquireResponse()
   308  	req.SetHost("foobar")
   309  	retry, err := c.doNonNilReqResp(req, resp)
   310  	assert.True(t, retry)
   311  	assert.NotNil(t, err)
   312  }
   313  
   314  func TestWriteTimeoutPriority(t *testing.T) {
   315  	c := &HostClient{
   316  		ClientOptions: &ClientOptions{
   317  			Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
   318  				return mock.SlowWriteDialer(addr)
   319  			}),
   320  			MaxConns:           1,
   321  			MaxConnWaitTimeout: 50 * time.Millisecond,
   322  			WriteTimeout:       time.Second * 3,
   323  		},
   324  		Addr: "foobar",
   325  	}
   326  
   327  	req := protocol.AcquireRequest()
   328  	req.SetRequestURI("http://foobar/baz")
   329  	req.SetOptions(config.WithWriteTimeout(time.Second * 1))
   330  	resp := protocol.AcquireResponse()
   331  
   332  	ch := make(chan error, 1)
   333  	go func() {
   334  		ch <- c.Do(context.Background(), req, resp)
   335  	}()
   336  	select {
   337  	case <-time.After(time.Second * 2):
   338  		t.Fatalf("should use writeTimeout in request options")
   339  	case err := <-ch:
   340  		assert.DeepEqual(t, mock.ErrWriteTimeout, err)
   341  	}
   342  }
   343  
   344  func TestDialTimeoutPriority(t *testing.T) {
   345  	c := &HostClient{
   346  		ClientOptions: &ClientOptions{
   347  			Dialer:             &slowDialer{},
   348  			MaxConns:           1,
   349  			MaxConnWaitTimeout: 50 * time.Millisecond,
   350  			DialTimeout:        time.Second * 3,
   351  		},
   352  		Addr: "foobar",
   353  	}
   354  
   355  	req := protocol.AcquireRequest()
   356  	req.SetRequestURI("http://foobar/baz")
   357  	req.SetOptions(config.WithDialTimeout(time.Second * 1))
   358  	resp := protocol.AcquireResponse()
   359  
   360  	ch := make(chan error, 1)
   361  	go func() {
   362  		ch <- c.Do(context.Background(), req, resp)
   363  	}()
   364  	select {
   365  	case <-time.After(time.Second * 2):
   366  		t.Fatalf("should use dialTimeout in request options")
   367  	case err := <-ch:
   368  		assert.DeepEqual(t, errDialTimeout, err)
   369  	}
   370  }
   371  
   372  func TestStateObserve(t *testing.T) {
   373  	syncState := struct {
   374  		mu    sync.Mutex
   375  		state config.ConnPoolState
   376  	}{}
   377  	c := &HostClient{
   378  		ClientOptions: &ClientOptions{
   379  			Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
   380  				return mock.SlowReadDialer(addr)
   381  			}),
   382  			StateObserve: func(hcs config.HostClientState) {
   383  				syncState.mu.Lock()
   384  				defer syncState.mu.Unlock()
   385  				syncState.state = hcs.ConnPoolState()
   386  			},
   387  			ObservationInterval: 50 * time.Millisecond,
   388  		},
   389  		Addr:   "foobar",
   390  		closed: make(chan struct{}),
   391  	}
   392  
   393  	c.SetDynamicConfig(&client.DynamicConfig{
   394  		Addr: utils.AddMissingPort(c.Addr, true),
   395  	})
   396  
   397  	time.Sleep(500 * time.Millisecond)
   398  	assert.Nil(t, c.Close())
   399  	syncState.mu.Lock()
   400  	assert.DeepEqual(t, "foobar:443", syncState.state.Addr)
   401  	syncState.mu.Unlock()
   402  }
   403  
   404  func TestCachedTLSConfig(t *testing.T) {
   405  	c := &HostClient{
   406  		ClientOptions: &ClientOptions{
   407  			Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
   408  				return mock.SlowReadDialer(addr)
   409  			}),
   410  			TLSConfig: &tls.Config{
   411  				InsecureSkipVerify: true,
   412  			},
   413  		},
   414  		Addr:  "foobar",
   415  		IsTLS: true,
   416  	}
   417  
   418  	cfg1 := c.cachedTLSConfig("foobar")
   419  	cfg2 := c.cachedTLSConfig("baz")
   420  	assert.NotEqual(t, cfg1, cfg2)
   421  	cfg3 := c.cachedTLSConfig("foobar")
   422  	assert.DeepEqual(t, cfg1, cfg3)
   423  }
   424  
   425  func TestRetry(t *testing.T) {
   426  	var times int32
   427  	c := &HostClient{
   428  		ClientOptions: &ClientOptions{
   429  			Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
   430  				times++
   431  				if times < 3 {
   432  					return &retryConn{
   433  						Conn: mock.NewConn(""),
   434  					}, nil
   435  				}
   436  				return mock.NewConn("HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789"), nil
   437  			}),
   438  			RetryConfig: &retry.Config{
   439  				MaxAttemptTimes: 5,
   440  				Delay:           time.Millisecond * 10,
   441  			},
   442  			RetryIfFunc: func(req *protocol.Request, resp *protocol.Response, err error) bool {
   443  				return resp.Header.ContentLength() != 10
   444  			},
   445  		},
   446  		Addr: "foobar",
   447  	}
   448  
   449  	req := protocol.AcquireRequest()
   450  	req.SetRequestURI("http://foobar/baz")
   451  	req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100))
   452  	resp := protocol.AcquireResponse()
   453  
   454  	ch := make(chan error, 1)
   455  	go func() {
   456  		ch <- c.Do(context.Background(), req, resp)
   457  	}()
   458  	select {
   459  	case <-time.After(time.Second * 2):
   460  		t.Fatalf("should use writeTimeout in request options")
   461  	case err := <-ch:
   462  		assert.Nil(t, err)
   463  		assert.True(t, times == 3)
   464  		assert.DeepEqual(t, resp.StatusCode(), 200)
   465  		assert.DeepEqual(t, resp.Body(), []byte("0123456789"))
   466  	}
   467  }
   468  
   469  // mockConn for getting error when write binary data.
   470  type writeErrConn struct {
   471  	network.Conn
   472  }
   473  
   474  func (w writeErrConn) WriteBinary(b []byte) (n int, err error) {
   475  	return 0, errs.ErrConnectionClosed
   476  }
   477  
   478  type retryConn struct {
   479  	network.Conn
   480  }
   481  
   482  func (w retryConn) SetWriteTimeout(t time.Duration) error {
   483  	return errors.New("should retry")
   484  }
   485  
   486  func TestConnInPoolRetry(t *testing.T) {
   487  	c := &HostClient{
   488  		ClientOptions: &ClientOptions{
   489  			Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
   490  				return mock.NewOneTimeConn("HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789"), nil
   491  			}),
   492  		},
   493  		Addr: "foobar",
   494  	}
   495  
   496  	req := protocol.AcquireRequest()
   497  	req.SetRequestURI("http://foobar/baz")
   498  	req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100))
   499  	resp := protocol.AcquireResponse()
   500  
   501  	logbuf := &bytes.Buffer{}
   502  	hlog.SetOutput(logbuf)
   503  
   504  	err := c.Do(context.Background(), req, resp)
   505  	assert.Nil(t, err)
   506  	assert.DeepEqual(t, resp.StatusCode(), 200)
   507  	assert.DeepEqual(t, string(resp.Body()), "0123456789")
   508  	assert.True(t, logbuf.String() == "")
   509  	protocol.ReleaseResponse(resp)
   510  	resp = protocol.AcquireResponse()
   511  	err = c.Do(context.Background(), req, resp)
   512  	assert.Nil(t, err)
   513  	assert.DeepEqual(t, resp.StatusCode(), 200)
   514  	assert.DeepEqual(t, string(resp.Body()), "0123456789")
   515  	assert.True(t, strings.Contains(logbuf.String(), "Client connection attempt times: 1"))
   516  }
   517  
   518  func TestConnNotRetry(t *testing.T) {
   519  	c := &HostClient{
   520  		ClientOptions: &ClientOptions{
   521  			Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
   522  				return mock.NewBrokenConn(""), nil
   523  			}),
   524  		},
   525  		Addr: "foobar",
   526  	}
   527  
   528  	req := protocol.AcquireRequest()
   529  	req.SetRequestURI("http://foobar/baz")
   530  	req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100))
   531  	resp := protocol.AcquireResponse()
   532  	logbuf := &bytes.Buffer{}
   533  	hlog.SetOutput(logbuf)
   534  	err := c.Do(context.Background(), req, resp)
   535  	assert.DeepEqual(t, errs.ErrConnectionClosed, err)
   536  	assert.True(t, logbuf.String() == "")
   537  	protocol.ReleaseResponse(resp)
   538  }
   539  
   540  type countCloseConn struct {
   541  	network.Conn
   542  	isClose bool
   543  }
   544  
   545  func (c *countCloseConn) Close() error {
   546  	c.isClose = true
   547  	return nil
   548  }
   549  
   550  func newCountCloseConn(s string) *countCloseConn {
   551  	return &countCloseConn{
   552  		Conn: mock.NewConn(s),
   553  	}
   554  }
   555  
   556  func TestStreamNoContent(t *testing.T) {
   557  	conn := newCountCloseConn("HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTrailer: Foo\r\nContent-Encoding: deflate\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo: bar\r\n\r\nHTTP/1.2")
   558  
   559  	c := &HostClient{
   560  		ClientOptions: &ClientOptions{
   561  			Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
   562  				return conn, nil
   563  			}),
   564  		},
   565  		Addr: "foobar",
   566  	}
   567  
   568  	c.ResponseBodyStream = true
   569  
   570  	req := protocol.AcquireRequest()
   571  	req.SetRequestURI("http://foobar/baz")
   572  	req.Header.SetConnectionClose(true)
   573  	resp := protocol.AcquireResponse()
   574  
   575  	c.Do(context.Background(), req, resp)
   576  
   577  	assert.True(t, conn.isClose)
   578  }
   579  
   580  func TestDialTimeout(t *testing.T) {
   581  	c := &HostClient{
   582  		ClientOptions: &ClientOptions{
   583  			DialTimeout: time.Second * 10,
   584  			Dialer: &mockDialer{
   585  				customDialConn: func(network, addr string, timeout time.Duration) (network.Conn, error) {
   586  					assert.DeepEqual(t, time.Second*10, timeout)
   587  					return nil, errors.New("test error")
   588  				},
   589  			},
   590  		},
   591  		Addr: "foobar",
   592  	}
   593  
   594  	req := protocol.AcquireRequest()
   595  	req.SetRequestURI("http://foobar/baz")
   596  	resp := protocol.AcquireResponse()
   597  
   598  	c.Do(context.Background(), req, resp)
   599  }