github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/docker/registry/internal/transports.go (about) 1 // Copyright 2021 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package internal 5 6 import ( 7 "encoding/json" 8 "fmt" 9 "io" 10 "net/http" 11 "net/url" 12 "strings" 13 "time" 14 15 "github.com/docker/distribution/registry/client/auth/challenge" 16 "github.com/juju/errors" 17 ) 18 19 type dynamicTransportFunc func() (http.RoundTripper, error) 20 21 // RoundTrip executes a single HTTP transaction, returning a Response for the provided Request. 22 func (f dynamicTransportFunc) RoundTrip(req *http.Request) (*http.Response, error) { 23 transport, err := f() 24 if err != nil { 25 return nil, err 26 } 27 return transport.RoundTrip(req) 28 } 29 30 type challengeTransport struct { 31 baseTransport http.RoundTripper 32 currentTransport http.RoundTripper 33 34 username string 35 password string 36 authToken string 37 } 38 39 func newChallengeTransport( 40 transport http.RoundTripper, username string, password string, authToken string, 41 ) http.RoundTripper { 42 return &challengeTransport{ 43 baseTransport: transport, 44 username: username, 45 password: password, 46 authToken: authToken, 47 } 48 } 49 50 func (t *challengeTransport) RoundTrip(req *http.Request) (*http.Response, error) { 51 transport := t.baseTransport 52 if t.currentTransport != nil { 53 transport = t.currentTransport 54 } 55 resp, err := transport.RoundTrip(req) 56 if err != nil { 57 return nil, errors.Trace(err) 58 } 59 originalResp := resp 60 if !isUnauthorizedResponse(originalResp) { 61 return resp, nil 62 } 63 for _, c := range challenge.ResponseChallenges(originalResp) { 64 if err != nil { 65 logger.Warningf("authentication failed: %s", err.Error()) 66 err = nil 67 } 68 switch strings.ToLower(c.Scheme) { 69 case "bearer": 70 tokenTransport := &tokenTransport{ 71 transport: t.baseTransport, 72 username: t.password, 73 password: t.password, 74 authToken: t.authToken, 75 } 76 err = tokenTransport.refreshOAuthToken(originalResp) 77 if err != nil { 78 continue 79 } 80 transport = tokenTransport 81 case "basic": 82 transport = newBasicTransport(t.baseTransport, t.username, t.password, t.authToken) 83 default: 84 err = fmt.Errorf("unknown WWW-Authenticate challenge scheme: %s", c.Scheme) 85 continue 86 } 87 resp, err = transport.RoundTrip(req) 88 if err == nil && !isUnauthorizedResponse(resp) { 89 t.currentTransport = transport 90 return resp, nil 91 } 92 } 93 if err != nil { 94 return nil, errors.Trace(err) 95 } 96 if t.password == "" && t.authToken == "" { 97 return nil, errors.NewUnauthorized(err, "authorization is required for a private registry") 98 } 99 return resp, nil 100 } 101 102 type basicTransport struct { 103 transport http.RoundTripper 104 username string 105 password string 106 authToken string 107 } 108 109 func newBasicTransport( 110 transport http.RoundTripper, username string, password string, authToken string, 111 ) http.RoundTripper { 112 return &basicTransport{ 113 transport: transport, 114 username: username, 115 password: password, 116 authToken: authToken, 117 } 118 } 119 120 func (basicTransport) scheme() string { 121 return "Basic" 122 } 123 124 func (t basicTransport) authorizeRequest(req *http.Request) error { 125 if t.authToken != "" { 126 req.Header.Set("Authorization", fmt.Sprintf("%s %s", t.scheme(), t.authToken)) 127 return nil 128 } 129 if t.username != "" || t.password != "" { 130 req.SetBasicAuth(t.username, t.password) 131 } 132 return nil 133 } 134 135 // RoundTrip executes a single HTTP transaction, returning a Response for the provided Request. 136 func (t basicTransport) RoundTrip(req *http.Request) (*http.Response, error) { 137 if err := t.authorizeRequest(req); err != nil { 138 return nil, errors.Trace(err) 139 } 140 resp, err := t.transport.RoundTrip(req) 141 logger.Tracef("basicTransport %q, resp.Header => %#v, %q", req.URL, resp.Header, resp.Status) 142 return resp, errors.Trace(err) 143 } 144 145 type tokenTransport struct { 146 transport http.RoundTripper 147 username string 148 password string 149 authToken string 150 oauthToken string 151 reuseOAuthToken bool 152 } 153 154 func newTokenTransport( 155 transport http.RoundTripper, username, password, authToken, oauthToken string, reuseOAuthToken bool, 156 ) http.RoundTripper { 157 return &tokenTransport{ 158 transport: transport, 159 username: username, 160 password: password, 161 authToken: authToken, 162 oauthToken: oauthToken, 163 reuseOAuthToken: reuseOAuthToken, 164 } 165 } 166 167 func (tokenTransport) scheme() string { 168 return "Bearer" 169 } 170 171 func getChallengeParameters(scheme string, resp *http.Response) map[string]string { 172 logger.Tracef( 173 "getting chanllenge parametter for %q with scheme %q from %q", 174 resp.Request.URL.String(), 175 scheme, resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")], 176 ) 177 for _, c := range challenge.ResponseChallenges(resp) { 178 if strings.EqualFold(c.Scheme, scheme) { 179 return c.Parameters 180 } 181 } 182 logger.Tracef("failed to get challenge parameters for %q schema -> %v", scheme, resp.Header) 183 return nil 184 } 185 186 type tokenResponse struct { 187 Token string `json:"token"` 188 AccessToken string `json:"access_token"` 189 RefreshToken string `json:"refresh_token"` 190 ExpiresIn int `json:"expires_in"` 191 IssuedAt time.Time `json:"issued_at"` 192 Scope string `json:"scope"` 193 } 194 195 func (t tokenResponse) token() string { 196 if t.AccessToken != "" { 197 return t.AccessToken 198 } 199 if t.Token != "" { 200 return t.Token 201 } 202 return "" 203 } 204 205 func (t *tokenTransport) refreshOAuthToken(failedResp *http.Response) error { 206 parameters := getChallengeParameters(t.scheme(), failedResp) 207 if len(parameters) == 0 { 208 return errors.NewForbidden(nil, "failed to refresh bearer token") 209 } 210 realm, ok := parameters["realm"] 211 if !ok { 212 return errors.New("no realm specified for token auth challenge") 213 } 214 service, ok := parameters["service"] 215 if !ok { 216 return errors.New("no service specified for token auth challenge") 217 } 218 scope, ok := parameters["scope"] 219 if !ok { 220 logger.Tracef("no scope specified for token auth challenge") 221 } 222 223 url, err := url.Parse(realm) 224 if err != nil { 225 return errors.Trace(err) 226 } 227 q := url.Query() 228 if scope != "" { 229 q.Set("scope", scope) 230 } 231 q.Set("service", service) 232 url.RawQuery = q.Encode() 233 234 request, err := http.NewRequest("GET", url.String(), nil) 235 if err != nil { 236 return errors.Trace(err) 237 } 238 tokenRefreshTransport := newBasicTransport(t.transport, t.username, t.password, t.authToken) 239 resp, err := tokenRefreshTransport.RoundTrip(request) 240 if err != nil { 241 return errors.Trace(err) 242 } 243 if resp.StatusCode != http.StatusOK { 244 _, err = handleErrorResponse(resp) 245 return errors.Trace(err) 246 } 247 248 decoder := json.NewDecoder(resp.Body) 249 var tr tokenResponse 250 if err = decoder.Decode(&tr); err != nil { 251 return fmt.Errorf("unable to decode token response: %s", err) 252 } 253 t.oauthToken = tr.token() 254 return nil 255 } 256 257 func (t *tokenTransport) authorizeRequest(req *http.Request) error { 258 if t.oauthToken != "" { 259 req.Header.Set("Authorization", fmt.Sprintf("%s %s", t.scheme(), t.oauthToken)) 260 } 261 return nil 262 } 263 264 // RoundTrip executes a single HTTP transaction, returning a Response for the provided Request. 265 func (t *tokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { 266 defer func() { 267 if !t.reuseOAuthToken { 268 // We usually do not re-use the OAuth token because each API call might have different scope. 269 // But some of the provider use long life token and there is no need to refresh. 270 t.oauthToken = "" 271 } 272 }() 273 274 if err := t.authorizeRequest(req); err != nil { 275 return nil, errors.Trace(err) 276 } 277 resp, err := t.transport.RoundTrip(req) 278 if err != nil { 279 return nil, errors.Trace(err) 280 } 281 if isUnauthorizedResponse(resp) { 282 // refresh token and retry. 283 return t.retry(req, resp) 284 } 285 return resp, errors.Trace(err) 286 } 287 288 func (t *tokenTransport) retry(req *http.Request, prevResp *http.Response) (*http.Response, error) { 289 logger.Tracef( 290 "retrying req URL %q, previous response header %#v, status %v", 291 req.URL, prevResp.Header, prevResp.Status, 292 ) 293 294 if err := t.refreshOAuthToken(prevResp); err != nil { 295 return nil, errors.Annotatef(err, "refreshing OAuth token") 296 } 297 if err := t.authorizeRequest(req); err != nil { 298 return nil, errors.Trace(err) 299 } 300 resp, err := t.transport.RoundTrip(req) 301 if isUnauthorizedResponse(resp) { 302 if t.password == "" && t.authToken == "" { 303 return nil, errors.NewUnauthorized(err, "authorization is required for a private registry") 304 } 305 } 306 return resp, errors.Trace(err) 307 } 308 309 func isUnauthorizedResponse(resp *http.Response) bool { 310 return resp != nil && resp.StatusCode == http.StatusUnauthorized 311 } 312 313 type errorTransport struct { 314 transport http.RoundTripper 315 } 316 317 func newErrorTransport(transport http.RoundTripper) http.RoundTripper { 318 return &errorTransport{transport: transport} 319 } 320 321 // RoundTrip executes a single HTTP transaction, returning a Response for the provided Request. 322 func (t errorTransport) RoundTrip(request *http.Request) (*http.Response, error) { 323 resp, err := t.transport.RoundTrip(request) 324 if err != nil { 325 return resp, errors.Trace(err) 326 } 327 if resp.StatusCode < 400 { 328 return resp, nil 329 } 330 logger.Tracef("errorTransport %q, err -> %v", request.URL, err) 331 return handleErrorResponse(resp) 332 } 333 334 func handleErrorResponse(resp *http.Response) (*http.Response, error) { 335 if resp.StatusCode < 400 { 336 return resp, nil 337 } 338 defer resp.Body.Close() 339 body, err := io.ReadAll(resp.Body) 340 if err != nil { 341 return nil, errors.Annotatef(err, "reading bad response body with status code %d", resp.StatusCode) 342 } 343 errMsg := fmt.Sprintf("non-successful response status=%d", resp.StatusCode) 344 if logger.IsTraceEnabled() { 345 logger.Tracef("%s, url %q, body=%q", errMsg, resp.Request.URL.String(), body) 346 } 347 errNew := errors.Errorf 348 switch resp.StatusCode { 349 case http.StatusForbidden: 350 errNew = errors.Forbiddenf 351 case http.StatusUnauthorized: 352 errNew = errors.Unauthorizedf 353 case http.StatusNotFound: 354 errNew = errors.NotFoundf 355 } 356 return nil, errNew(errMsg) 357 } 358 359 func unwrapNetError(err error) error { 360 if err == nil { 361 return nil 362 } 363 if neturlErr, ok := err.(*url.Error); ok { 364 return errors.Annotatef(neturlErr.Unwrap(), "%s %q", neturlErr.Op, neturlErr.URL) 365 } 366 return err 367 }