github.com/google/martian/v3@v3.3.3/martiantest/transport.go (about)

     1  package martiantest
     2  
     3  import (
     4  	"net/http"
     5  
     6  	"github.com/google/martian/v3/proxyutil"
     7  )
     8  
     9  // Transport is an http.RoundTripper for testing.
    10  type Transport struct {
    11  	rtfunc func(*http.Request) (*http.Response, error)
    12  }
    13  
    14  // NewTransport builds a new transport that will respond with a 200 OK
    15  // response.
    16  func NewTransport() *Transport {
    17  	tr := &Transport{}
    18  	tr.Respond(200)
    19  
    20  	return tr
    21  }
    22  
    23  // Respond sets the transport to respond with response with statusCode.
    24  func (tr *Transport) Respond(statusCode int) {
    25  	tr.rtfunc = func(req *http.Request) (*http.Response, error) {
    26  		// Force CONNECT requests to 200 to test CONNECT with downstream proxy.
    27  		if req.Method == "CONNECT" {
    28  			statusCode = 200
    29  		}
    30  
    31  		res := proxyutil.NewResponse(statusCode, nil, req)
    32  
    33  		return res, nil
    34  	}
    35  }
    36  
    37  // RespondError sets the transport to respond with an error on round trip.
    38  func (tr *Transport) RespondError(err error) {
    39  	tr.rtfunc = func(*http.Request) (*http.Response, error) {
    40  		return nil, err
    41  	}
    42  }
    43  
    44  // CopyHeaders sets the transport to respond with a 200 OK response with
    45  // headers copied from the request to the response verbatim.
    46  func (tr *Transport) CopyHeaders(names ...string) {
    47  	tr.rtfunc = func(req *http.Request) (*http.Response, error) {
    48  		res := proxyutil.NewResponse(200, nil, req)
    49  
    50  		for _, n := range names {
    51  			res.Header.Set(n, req.Header.Get(n))
    52  		}
    53  
    54  		return res, nil
    55  	}
    56  }
    57  
    58  // Func sets the transport to use the rtfunc.
    59  func (tr *Transport) Func(rtfunc func(*http.Request) (*http.Response, error)) {
    60  	tr.rtfunc = rtfunc
    61  }
    62  
    63  // RoundTrip runs the stored round trip func and returns the response.
    64  func (tr *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
    65  	return tr.rtfunc(req)
    66  }