github.com/jogo/docker@v1.7.0-rc1/pkg/transport/transport.go (about)

     1  package transport
     2  
     3  import (
     4  	"io"
     5  	"net/http"
     6  	"sync"
     7  )
     8  
     9  type RequestModifier interface {
    10  	ModifyRequest(*http.Request) error
    11  }
    12  
    13  type headerModifier http.Header
    14  
    15  // NewHeaderRequestModifier returns a RequestModifier that merges the HTTP headers
    16  // passed as an argument, with the HTTP headers of a request.
    17  //
    18  // If the same key is present in both, the modifying header values for that key,
    19  // are appended to the values for that same key in the request header.
    20  func NewHeaderRequestModifier(header http.Header) RequestModifier {
    21  	return headerModifier(header)
    22  }
    23  
    24  func (h headerModifier) ModifyRequest(req *http.Request) error {
    25  	for k, s := range http.Header(h) {
    26  		req.Header[k] = append(req.Header[k], s...)
    27  	}
    28  
    29  	return nil
    30  }
    31  
    32  // NewTransport returns an http.RoundTripper that modifies requests according to
    33  // the RequestModifiers passed in the arguments, before sending the requests to
    34  // the base http.RoundTripper (which, if nil, defaults to http.DefaultTransport).
    35  func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper {
    36  	return &transport{
    37  		Modifiers: modifiers,
    38  		Base:      base,
    39  	}
    40  }
    41  
    42  // transport is an http.RoundTripper that makes HTTP requests after
    43  // copying and modifying the request
    44  type transport struct {
    45  	Modifiers []RequestModifier
    46  	Base      http.RoundTripper
    47  
    48  	mu     sync.Mutex                      // guards modReq
    49  	modReq map[*http.Request]*http.Request // original -> modified
    50  }
    51  
    52  func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
    53  	req2 := CloneRequest(req)
    54  	for _, modifier := range t.Modifiers {
    55  		if err := modifier.ModifyRequest(req2); err != nil {
    56  			return nil, err
    57  		}
    58  	}
    59  
    60  	t.setModReq(req, req2)
    61  	res, err := t.base().RoundTrip(req2)
    62  	if err != nil {
    63  		t.setModReq(req, nil)
    64  		return nil, err
    65  	}
    66  	res.Body = &OnEOFReader{
    67  		Rc: res.Body,
    68  		Fn: func() { t.setModReq(req, nil) },
    69  	}
    70  	return res, nil
    71  }
    72  
    73  // CancelRequest cancels an in-flight request by closing its connection.
    74  func (t *transport) CancelRequest(req *http.Request) {
    75  	type canceler interface {
    76  		CancelRequest(*http.Request)
    77  	}
    78  	if cr, ok := t.base().(canceler); ok {
    79  		t.mu.Lock()
    80  		modReq := t.modReq[req]
    81  		delete(t.modReq, req)
    82  		t.mu.Unlock()
    83  		cr.CancelRequest(modReq)
    84  	}
    85  }
    86  
    87  func (t *transport) base() http.RoundTripper {
    88  	if t.Base != nil {
    89  		return t.Base
    90  	}
    91  	return http.DefaultTransport
    92  }
    93  
    94  func (t *transport) setModReq(orig, mod *http.Request) {
    95  	t.mu.Lock()
    96  	defer t.mu.Unlock()
    97  	if t.modReq == nil {
    98  		t.modReq = make(map[*http.Request]*http.Request)
    99  	}
   100  	if mod == nil {
   101  		delete(t.modReq, orig)
   102  	} else {
   103  		t.modReq[orig] = mod
   104  	}
   105  }
   106  
   107  // CloneRequest returns a clone of the provided *http.Request.
   108  // The clone is a shallow copy of the struct and its Header map.
   109  func CloneRequest(r *http.Request) *http.Request {
   110  	// shallow copy of the struct
   111  	r2 := new(http.Request)
   112  	*r2 = *r
   113  	// deep copy of the Header
   114  	r2.Header = make(http.Header, len(r.Header))
   115  	for k, s := range r.Header {
   116  		r2.Header[k] = append([]string(nil), s...)
   117  	}
   118  
   119  	return r2
   120  }
   121  
   122  // OnEOFReader ensures a callback function is called
   123  // on Close() and when the underlying Reader returns an io.EOF error
   124  type OnEOFReader struct {
   125  	Rc io.ReadCloser
   126  	Fn func()
   127  }
   128  
   129  func (r *OnEOFReader) Read(p []byte) (n int, err error) {
   130  	n, err = r.Rc.Read(p)
   131  	if err == io.EOF {
   132  		r.runFunc()
   133  	}
   134  	return
   135  }
   136  
   137  func (r *OnEOFReader) Close() error {
   138  	err := r.Rc.Close()
   139  	r.runFunc()
   140  	return err
   141  }
   142  
   143  func (r *OnEOFReader) runFunc() {
   144  	if fn := r.Fn; fn != nil {
   145  		fn()
   146  		r.Fn = nil
   147  	}
   148  }