github.com/qiniu/x@v1.11.9/mockhttp/mockhttp.go (about)

     1  package mockhttp
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strconv"
    10  
    11  	"github.com/qiniu/x/log"
    12  )
    13  
    14  var (
    15  	ErrServerNotFound = errors.New("server not found")
    16  )
    17  
    18  // --------------------------------------------------------------------
    19  
    20  type mockServerRequestBody struct {
    21  	reader      io.Reader
    22  	closeSignal bool
    23  }
    24  
    25  func (r *mockServerRequestBody) Read(p []byte) (int, error) {
    26  	if r.closeSignal || r.reader == nil {
    27  		return 0, io.EOF
    28  	}
    29  	return r.reader.Read(p)
    30  }
    31  
    32  func (r *mockServerRequestBody) Close() error {
    33  	r.closeSignal = true
    34  	if c, ok := r.reader.(io.Closer); ok {
    35  		return c.Close()
    36  	}
    37  	return nil
    38  }
    39  
    40  // --------------------------------------------------------------------
    41  // type Transport
    42  
// Transport is an http.RoundTripper that dispatches each request to an
// in-process http.Handler registered for the request's URL host, instead of
// sending it over the network. Create one with NewTransport.
type Transport struct {
	route      map[string]http.Handler // registered handlers, keyed by req.URL.Host
	remoteAddr string                  // value handlers see as Request.RemoteAddr
}
    47  
    48  func NewTransport() *Transport {
    49  	return &Transport{
    50  		route:      make(map[string]http.Handler),
    51  		remoteAddr: "127.0.0.1:13579",
    52  	}
    53  }
    54  
    55  func (p *Transport) SetRemoteAddr(remoteAddr string) *Transport {
    56  	p.remoteAddr = remoteAddr
    57  	return p
    58  }
    59  
    60  func (p *Transport) ListenAndServe(host string, h http.Handler) {
    61  	if h == nil {
    62  		h = http.DefaultServeMux
    63  	}
    64  	p.route[host] = h
    65  }
    66  
    67  func (p *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
    68  	h := p.route[req.URL.Host]
    69  	if h == nil {
    70  		log.Warn("Server not found:", req.Host, "-", req.URL.Host)
    71  		return nil, ErrServerNotFound
    72  	}
    73  
    74  	cp := *req
    75  	cp.RemoteAddr = p.remoteAddr
    76  	cp.Body = &mockServerRequestBody{req.Body, false}
    77  	req = &cp
    78  
    79  	rw := httptest.NewRecorder()
    80  	h.ServeHTTP(rw, req)
    81  
    82  	req.Body.Close()
    83  
    84  	ctlen := int64(-1)
    85  	if v := rw.Header().Get("Content-Length"); v != "" {
    86  		ctlen, _ = strconv.ParseInt(v, 10, 64)
    87  	}
    88  
    89  	return &http.Response{
    90  		Status:           "",
    91  		StatusCode:       rw.Code,
    92  		Header:           rw.Header(),
    93  		Body:             ioutil.NopCloser(rw.Body),
    94  		ContentLength:    ctlen,
    95  		TransferEncoding: nil,
    96  		Close:            false,
    97  		Trailer:          nil,
    98  		Request:          req,
    99  	}, nil
   100  }
   101  
   102  // --------------------------------------------------------------------
   103  
   104  var DefaultTransport = NewTransport()
   105  var DefaultClient = &http.Client{Transport: DefaultTransport}
   106  
   107  func ListenAndServe(host string, h http.Handler) {
   108  	DefaultTransport.ListenAndServe(host, h)
   109  }
   110  
   111  // --------------------------------------------------------------------