// RequestModifier represents an object which will do an inplace
// modification of an HTTP request.
type RequestModifier interface {
	ModifyRequest(*http.Request) error
}

// headerModifier is a RequestModifier that appends a fixed set of
// headers to every request passed to it.
type headerModifier http.Header

// NewHeaderRequestModifier returns a new RequestModifier which will
// add the given headers to a request.
func NewHeaderRequestModifier(header http.Header) RequestModifier {
	return headerModifier(header)
}

// ModifyRequest appends every value of every header in h to the
// corresponding entry of req.Header, keeping any values already there.
// It always returns nil.
func (h headerModifier) ModifyRequest(req *http.Request) error {
	if req.Header == nil {
		// Assigning into a nil map panics; a hand-built http.Request
		// may legitimately carry a nil Header.
		req.Header = make(http.Header, len(h))
	}
	for k, s := range h {
		req.Header[k] = append(req.Header[k], s...)
	}

	return nil
}

// NewTransport creates a new transport which will apply modifiers to
// the request on a RoundTrip call. If base is nil, http.DefaultTransport
// is used to perform the request.
func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper {
	return &transport{
		Modifiers: modifiers,
		Base:      base,
	}
}

// transport is an http.RoundTripper that makes HTTP requests after
// copying and modifying the request.
type transport struct {
	// Modifiers are applied, in order, to the copy of each request.
	Modifiers []RequestModifier
	// Base performs the actual round trip; nil means http.DefaultTransport.
	Base http.RoundTripper

	mu     sync.Mutex                      // guards modReq
	modReq map[*http.Request]*http.Request // original -> modified
}

// RoundTrip clones req, runs every modifier against the clone and sends
// the clone through the base transport. The caller's request is never
// modified.
//
// If a modifier fails, its error is returned and the base transport is
// not called. While a request is in flight the original -> modified
// mapping is remembered so that CancelRequest can find the request that
// was actually sent; the mapping is dropped when the base transport
// fails or when the response body reaches EOF or is closed.
func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
	req2 := cloneRequest(req)
	for _, modifier := range t.Modifiers {
		if err := modifier.ModifyRequest(req2); err != nil {
			// http.RoundTripper requires the request body to be closed
			// even when RoundTrip fails. The base transport never sees
			// this request, so it has to be done here. The Close error
			// is dropped in favour of the modifier error.
			if req.Body != nil {
				req.Body.Close()
			}
			return nil, err
		}
	}

	t.setModReq(req, req2)
	res, err := t.base().RoundTrip(req2)
	if err != nil {
		t.setModReq(req, nil)
		return nil, err
	}
	res.Body = &onEOFReader{
		rc: res.Body,
		fn: func() { t.setModReq(req, nil) },
	}
	return res, nil
}

// CancelRequest forwards cancellation of an in-flight request to the
// base transport, if the base transport has a CancelRequest method.
// req must be the original request handed to RoundTrip; it is translated
// to the modified copy that was actually sent. Requests that are not
// currently tracked are ignored.
func (t *transport) CancelRequest(req *http.Request) {
	type canceler interface {
		CancelRequest(*http.Request)
	}
	cr, ok := t.base().(canceler)
	if !ok {
		return
	}

	t.mu.Lock()
	modReq := t.modReq[req]
	delete(t.modReq, req)
	t.mu.Unlock()

	if modReq == nil {
		// Unknown or already finished: there is nothing to cancel, and
		// the base transport should not be handed a nil request.
		return
	}
	cr.CancelRequest(modReq)
}

// base returns the configured base RoundTripper, falling back to
// http.DefaultTransport when none was set.
func (t *transport) base() http.RoundTripper {
	if t.Base != nil {
		return t.Base
	}
	return http.DefaultTransport
}

// setModReq records mod as the request sent on behalf of orig. A nil mod
// removes the entry for orig. The map is created on first use.
func (t *transport) setModReq(orig, mod *http.Request) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.modReq == nil {
		t.modReq = make(map[*http.Request]*http.Request)
	}
	if mod == nil {
		delete(t.modReq, orig)
	} else {
		t.modReq[orig] = mod
	}
}

// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct with a deep copy of its
// Header map, so header changes on the clone do not affect r.
func cloneRequest(r *http.Request) *http.Request {
	// shallow copy of the struct
	r2 := new(http.Request)
	*r2 = *r
	// deep copy of the Header
	r2.Header = make(http.Header, len(r.Header))
	for k, s := range r.Header {
		r2.Header[k] = append([]string(nil), s...)
	}

	return r2
}

// onEOFReader wraps an io.ReadCloser and runs fn once, the first time a
// Read returns io.EOF or Close is called, whichever happens first.
type onEOFReader struct {
	rc io.ReadCloser
	fn func()
}

// Read reads from the wrapped reader and triggers fn when it reports
// io.EOF.
func (r *onEOFReader) Read(p []byte) (n int, err error) {
	n, err = r.rc.Read(p)
	if err == io.EOF {
		r.runFunc()
	}
	return
}

// Close closes the wrapped reader, triggers fn and returns the error
// from the underlying Close.
func (r *onEOFReader) Close() error {
	err := r.rc.Close()
	r.runFunc()
	return err
}

// runFunc invokes fn if it is still set and then clears it so that later
// calls do nothing.
func (r *onEOFReader) runFunc() {
	if fn := r.fn; fn != nil {
		fn()
		r.fn = nil
	}
}