cuelabs.dev/go/oci/ociregistry@v0.0.0-20240906074133-82eb438dd565/ociauth/auth.go (about) 1 package ociauth 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "io" 10 "net/http" 11 "net/url" 12 "slices" 13 "strings" 14 "sync" 15 "time" 16 17 "cuelabs.dev/go/oci/ociregistry" 18 ) 19 20 // TODO decide on a good value for this. 21 const oauthClientID = "cuelabs-ociauth" 22 23 var ErrNoAuth = fmt.Errorf("no authorization token available to add to request") 24 25 // stdTransport implements [http.RoundTripper] by acquiring authorization tokens 26 // using the flows implemented 27 // by the usual docker clients. Note that this is _not_ documented as 28 // part of any official OCI spec. 29 // 30 // See https://distribution.github.io/distribution/spec/auth/token/ for an overview. 31 type stdTransport struct { 32 config Config 33 transport http.RoundTripper 34 mu sync.Mutex 35 registries map[string]*registry 36 } 37 38 type StdTransportParams struct { 39 // Config represents the underlying configuration file information. 40 // It is consulted for authorization information on the hosts 41 // to which the HTTP requests are made. 42 Config Config 43 44 // HTTPClient is used to make the underlying HTTP requests. 45 // If it's nil, [http.DefaultTransport] will be used. 46 Transport http.RoundTripper 47 } 48 49 // NewStdTransport returns an [http.RoundTripper] implementation that 50 // acquires authorization tokens using the flows implemented by the 51 // usual docker clients. Note that this is _not_ documented as part of 52 // any official OCI spec. 53 // 54 // See https://distribution.github.io/distribution/spec/auth/token/ for an overview. 55 // 56 // The RoundTrip method acquires authorization before invoking the 57 // request. request. It may invoke the request more than once, and can 58 // use [http.Request.GetBody] to reset the request body if it gets 59 // consumed. 60 // 61 // It ensures that the authorization token used will have at least the 62 // capability to execute operations in the required scope associated 63 // with the request context (see [ContextWithRequestInfo]). Any other 64 // auth scope inside the context (see [ContextWithScope]) may also be 65 // taken into account when acquiring new tokens. 66 func NewStdTransport(p StdTransportParams) http.RoundTripper { 67 if p.Config == nil { 68 p.Config = emptyConfig{} 69 } 70 if p.Transport == nil { 71 p.Transport = http.DefaultTransport 72 } 73 return &stdTransport{ 74 config: p.Config, 75 transport: p.Transport, 76 registries: make(map[string]*registry), 77 } 78 } 79 80 // registry holds currently known auth information for a registry. 81 type registry struct { 82 host string 83 transport http.RoundTripper 84 config Config 85 initOnce sync.Once 86 initErr error 87 88 // mu guards the fields that follow it. 89 mu sync.Mutex 90 91 // wwwAuthenticate holds the Www-Authenticate header from 92 // the most recent 401 response. If there was a 401 response 93 // that didn't hold such a header, this will still be non-nil 94 // but hold a zero authHeader. 95 wwwAuthenticate *authHeader 96 97 accessTokens []*scopedToken 98 refreshToken string 99 basic *userPass 100 } 101 102 type scopedToken struct { 103 // scope holds the scope that the token is good for. 104 scope Scope 105 // token holds the actual access token. 106 token string 107 // expires holds when the token expires. 108 expires time.Time 109 } 110 111 type userPass struct { 112 username string 113 password string 114 } 115 116 var forever = time.Date(99999, time.January, 1, 0, 0, 0, 0, time.UTC) 117 118 // RoundTrip implements [http.RoundTripper.RoundTrip]. 119 func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) { 120 // From the [http.RoundTripper] docs: 121 // RoundTrip should not modify the request, except for 122 // consuming and closing the Request's Body. 123 req = req.Clone(req.Context()) 124 125 // From the [http.RoundTripper] docs: 126 // RoundTrip must always close the body, including on errors, [...] 127 needBodyClose := true 128 defer func() { 129 if needBodyClose && req.Body != nil { 130 req.Body.Close() 131 } 132 }() 133 134 a.mu.Lock() 135 r := a.registries[req.URL.Host] 136 if r == nil { 137 r = ®istry{ 138 host: req.URL.Host, 139 config: a.config, 140 transport: a.transport, 141 } 142 a.registries[r.host] = r 143 } 144 a.mu.Unlock() 145 if err := r.init(); err != nil { 146 return nil, err 147 } 148 149 ctx := req.Context() 150 requiredScope := RequestInfoFromContext(ctx).RequiredScope 151 wantScope := ScopeFromContext(ctx) 152 153 if err := r.setAuthorization(ctx, req, requiredScope, wantScope); err != nil { 154 return nil, err 155 } 156 resp, err := r.transport.RoundTrip(req) 157 158 // The underlying transport should now have closed the request body 159 // so we don't have to. 160 needBodyClose = false 161 if err != nil { 162 return nil, err 163 } 164 if resp.StatusCode != http.StatusUnauthorized { 165 return resp, nil 166 } 167 challenge := challengeFromResponse(resp) 168 if challenge == nil { 169 return resp, nil 170 } 171 authAdded, tokenAcquired, err := r.setAuthorizationFromChallenge(ctx, req, challenge, requiredScope, wantScope) 172 if err != nil { 173 resp.Body.Close() 174 return nil, err 175 } 176 if !authAdded { 177 // Couldn't acquire any more authorization than we had initially. 178 return resp, nil 179 } 180 resp.Body.Close() 181 // rewind request body if needed and possible. 182 if req.GetBody != nil { 183 req.Body, err = req.GetBody() 184 if err != nil { 185 return nil, err 186 } 187 } 188 resp, err = r.transport.RoundTrip(req) 189 if err != nil { 190 return nil, err 191 } 192 if resp.StatusCode != http.StatusUnauthorized || !tokenAcquired { 193 return resp, nil 194 } 195 // The server has responded with Unauthorized (401) even though we've just 196 // provided a token that it gave us. Treat it as Forbidden (403) instead. 197 // TODO include the original body/error as part of the message or message detail? 198 resp.Body.Close() 199 data, err := json.Marshal(&ociregistry.WireErrors{ 200 Errors: []ociregistry.WireError{{ 201 Code_: ociregistry.ErrDenied.Code(), 202 Message: "unauthorized response with freshly acquired auth token", 203 }}, 204 }) 205 if err != nil { 206 return nil, fmt.Errorf("cannot marshal response body: %v", err) 207 } 208 resp.Header.Set("Content-Type", "application/json") 209 resp.ContentLength = int64(len(data)) 210 resp.Body = io.NopCloser(bytes.NewReader(data)) 211 resp.StatusCode = http.StatusForbidden 212 resp.Status = http.StatusText(resp.StatusCode) 213 return resp, nil 214 } 215 216 // setAuthorization sets up authorization on the given request using any 217 // auth information currently available. 218 func (r *registry) setAuthorization(ctx context.Context, req *http.Request, requiredScope, wantScope Scope) error { 219 r.mu.Lock() 220 defer r.mu.Unlock() 221 // Remove tokens that have expired or will expire soon so that 222 // the caller doesn't start using a token only for it to expire while it's 223 // making the request. 224 r.deleteExpiredTokens(time.Now().UTC().Add(time.Second)) 225 226 if accessToken := r.accessTokenForScope(requiredScope); accessToken != nil { 227 // We have a potentially valid access token. Use it. 228 req.Header.Set("Authorization", "Bearer "+accessToken.token) 229 return nil 230 } 231 if r.wwwAuthenticate == nil { 232 // We haven't seen a 401 response yet. Avoid putting any 233 // basic authorization in the request, because that can mean that 234 // the server sends a 401 response without a Www-Authenticate 235 // header. 236 return nil 237 } 238 if r.refreshToken != "" && r.wwwAuthenticate.scheme == "bearer" { 239 // We've got a refresh token that we can use to try to 240 // acquire an access token and we've seen a Www-Authenticate response 241 // that tells us how we can use it. 242 243 // TODO we're holding the lock (r.mu) here, which is precluding 244 // acquiring several tokens concurrently. We should relax the lock 245 // to allow that. 246 247 accessToken, err := r.acquireAccessToken(ctx, requiredScope, wantScope) 248 if err != nil { 249 // Avoid using %w to wrap the error because we don't want the 250 // caller of RoundTrip (usually ociclient) to assume that the 251 // error applies to the target server rather than the token server. 252 return fmt.Errorf("cannot acquire access token: %v", err) 253 } 254 req.Header.Set("Authorization", "Bearer "+accessToken) 255 return nil 256 } 257 if r.wwwAuthenticate.scheme != "bearer" && r.basic != nil { 258 req.SetBasicAuth(r.basic.username, r.basic.password) 259 return nil 260 } 261 return nil 262 } 263 264 func (r *registry) setAuthorizationFromChallenge(ctx context.Context, req *http.Request, challenge *authHeader, requiredScope, wantScope Scope) (authAdded, tokenAcquired bool, _ error) { 265 r.mu.Lock() 266 defer r.mu.Unlock() 267 r.wwwAuthenticate = challenge 268 269 switch { 270 case r.wwwAuthenticate.scheme == "bearer": 271 scope := ParseScope(r.wwwAuthenticate.params["scope"]) 272 accessToken, err := r.acquireAccessToken(ctx, scope, wantScope.Union(requiredScope)) 273 if err != nil { 274 return false, false, err 275 } 276 req.Header.Set("Authorization", "Bearer "+accessToken) 277 return true, true, nil 278 case r.basic != nil: 279 req.SetBasicAuth(r.basic.username, r.basic.password) 280 return true, false, nil 281 } 282 return false, false, nil 283 } 284 285 // init initializes the registry instance by acquiring auth information from 286 // the Config, if available. As this might be slow (invoking EntryForRegistry 287 // can end up invoking slow external commands), we ensure that it's only 288 // done once. 289 // TODO it's possible that this could take a very long time, during which 290 // the outer context is cancelled, but we'll ignore that. We probably shouldn't. 291 func (r *registry) init() error { 292 inner := func() error { 293 info, err := r.config.EntryForRegistry(r.host) 294 if err != nil { 295 return fmt.Errorf("cannot acquire auth info for registry %q: %v", r.host, err) 296 } 297 r.refreshToken = info.RefreshToken 298 if info.AccessToken != "" { 299 r.accessTokens = append(r.accessTokens, &scopedToken{ 300 scope: UnlimitedScope(), 301 token: info.AccessToken, 302 expires: forever, 303 }) 304 } 305 if info.Username != "" && info.Password != "" { 306 r.basic = &userPass{ 307 username: info.Username, 308 password: info.Password, 309 } 310 } 311 return nil 312 } 313 r.initOnce.Do(func() { 314 r.initErr = inner() 315 }) 316 return r.initErr 317 } 318 319 // acquireAccessToken tries to acquire an access token for authorizing a request. 320 // The requiredScopeStr parameter indicates the scope that's definitely 321 // required. This is a string because apparently some servers are picky 322 // about getting exactly the same scope in the auth request that was 323 // returned in the challenge. The wantScope parameter indicates 324 // what scope might be required in the future. 325 // 326 // This method assumes that there has been a previous 401 response with 327 // a Www-Authenticate: Bearer... header. 328 func (r *registry) acquireAccessToken(ctx context.Context, requiredScope, wantScope Scope) (string, error) { 329 scope := requiredScope.Union(wantScope) 330 tok, err := r.acquireToken(ctx, scope) 331 if err != nil { 332 var herr ociregistry.HTTPError 333 if !errors.As(err, &herr) || herr.StatusCode() != http.StatusUnauthorized { 334 return "", err 335 } 336 // The documentation says this: 337 // 338 // If the client only has a subset of the requested 339 // access it _must not be considered an error_ as it is 340 // not the responsibility of the token server to 341 // indicate authorization errors as part of this 342 // workflow. 343 // 344 // However it's apparently not uncommon for servers to reject 345 // such requests anyway, so if we've got an unauthorized error 346 // and wantScope goes beyond requiredScope, it may be because 347 // the server is rejecting the request. 348 scope = requiredScope 349 tok, err = r.acquireToken(ctx, scope) 350 if err != nil { 351 return "", err 352 } 353 // TODO mark the registry as picky about tokens so we don't 354 // attempt twice every time? 355 } 356 if tok.RefreshToken != "" { 357 r.refreshToken = tok.RefreshToken 358 } 359 accessToken := tok.Token 360 if accessToken == "" { 361 accessToken = tok.AccessToken 362 } 363 if accessToken == "" { 364 return "", fmt.Errorf("no access token found in auth server response") 365 } 366 var expires time.Time 367 now := time.Now().UTC() 368 if tok.ExpiresIn == 0 { 369 expires = now.Add(60 * time.Second) // TODO link to where this is mentioned 370 } else { 371 expires = now.Add(time.Duration(tok.ExpiresIn) * time.Second) 372 } 373 r.accessTokens = append(r.accessTokens, &scopedToken{ 374 scope: scope, 375 token: accessToken, 376 expires: expires, 377 }) 378 // TODO persist the access token to save round trips when doing 379 // the authorization flow in a newly run executable. 380 return accessToken, nil 381 } 382 383 func (r *registry) acquireToken(ctx context.Context, scope Scope) (*wireToken, error) { 384 realm := r.wwwAuthenticate.params["realm"] 385 if realm == "" { 386 return nil, fmt.Errorf("malformed Www-Authenticate header (missing realm)") 387 } 388 if r.refreshToken != "" { 389 v := url.Values{} 390 v.Set("scope", scope.String()) 391 if service := r.wwwAuthenticate.params["service"]; service != "" { 392 v.Set("service", service) 393 } 394 v.Set("client_id", oauthClientID) 395 v.Set("grant_type", "refresh_token") 396 v.Set("refresh_token", r.refreshToken) 397 req, err := http.NewRequestWithContext(ctx, "POST", realm, strings.NewReader(v.Encode())) 398 if err != nil { 399 return nil, fmt.Errorf("cannot form HTTP request to %q: %v", realm, err) 400 } 401 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 402 tok, err := r.doTokenRequest(req) 403 if err == nil { 404 return tok, nil 405 } 406 var herr ociregistry.HTTPError 407 if !errors.As(err, &herr) || herr.StatusCode() != http.StatusNotFound { 408 return tok, err 409 } 410 // The request to the endpoint returned 404 from the POST request, 411 // Note: Not all token servers implement oauth2, so fall 412 // back to using a GET with basic auth. 413 // See the Token documentation for the HTTP GET method supported by all token servers. 414 // TODO where in that documentation is this documented? 415 } 416 u, err := url.Parse(realm) 417 if err != nil { 418 return nil, fmt.Errorf("malformed Www-Authenticate header (malformed realm %q): %v", realm, err) 419 } 420 v := u.Query() 421 // TODO where is it documented that we should send multiple scope 422 // attributes rather than a single space-separated attribute as 423 // the POST method does? 424 v["scope"] = strings.Split(scope.String(), " ") 425 if service := r.wwwAuthenticate.params["service"]; service != "" { 426 // TODO the containerregistry code sets this even if it's empty. 427 // Is that better? 428 v.Set("service", service) 429 } 430 u.RawQuery = v.Encode() 431 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) 432 if err != nil { 433 return nil, err 434 } 435 // TODO if there's an unlimited-scope access token, the original code 436 // will use it as Bearer authorization at this point. If 437 // that's valid, why are we even acquiring another token? 438 if r.basic != nil { 439 req.SetBasicAuth(r.basic.username, r.basic.password) 440 } 441 return r.doTokenRequest(req) 442 } 443 444 // wireToken describes the JSON encoding used in the response to a token 445 // acquisition method. The comments are taken from the [token docs] 446 // and made available here for ease of reference. 447 // 448 // [token docs]: https://distribution.github.io/distribution/spec/auth/token/#token-response-fields 449 type wireToken struct { 450 // Token holds an opaque Bearer token that clients should supply 451 // to subsequent requests in the Authorization header. 452 // AccessToken is provided for compatibility with OAuth 2.0: it's equivalent to Token. 453 // At least one of these fields must be specified, but both may also appear (for compatibility with older clients). 454 // When both are specified, they should be equivalent; if they differ the client's choice is undefined. 455 Token string `json:"token"` 456 AccessToken string `json:"access_token,omitempty"` 457 458 // Refresh token optionally holds a token which can be used to 459 // get additional access tokens for the same subject with different scopes. 460 // This token should be kept secure by the client and only sent 461 // to the authorization server which issues bearer tokens. This 462 // field will only be set when `offline_token=true` is provided 463 // in the request. 464 RefreshToken string `json:"refresh_token"` 465 466 // ExpiresIn holds the duration in seconds since the token was 467 // issued that it will remain valid. When omitted, this defaults 468 // to 60 seconds. For compatibility with older clients, a token 469 // should never be returned with less than 60 seconds to live. 470 ExpiresIn int `json:"expires_in"` 471 } 472 473 func (r *registry) doTokenRequest(req *http.Request) (*wireToken, error) { 474 client := &http.Client{ 475 Transport: r.transport, 476 } 477 resp, err := client.Do(req) 478 if err != nil { 479 return nil, err 480 } 481 defer resp.Body.Close() 482 data, bodyErr := io.ReadAll(resp.Body) 483 if resp.StatusCode != http.StatusOK { 484 return nil, ociregistry.NewHTTPError(nil, resp.StatusCode, resp, data) 485 } 486 if bodyErr != nil { 487 return nil, fmt.Errorf("error reading response body: %v", err) 488 } 489 var tok wireToken 490 if err := json.Unmarshal(data, &tok); err != nil { 491 return nil, fmt.Errorf("malformed JSON token in response: %v", err) 492 } 493 return &tok, nil 494 } 495 496 // deleteExpiredTokens removes all tokens from r that expire after the given 497 // time. 498 // TODO ask the store to remove expired tokens? 499 func (r *registry) deleteExpiredTokens(now time.Time) { 500 r.accessTokens = slices.DeleteFunc(r.accessTokens, func(tok *scopedToken) bool { 501 return now.After(tok.expires) 502 }) 503 } 504 505 func (r *registry) accessTokenForScope(scope Scope) *scopedToken { 506 for _, tok := range r.accessTokens { 507 if tok.scope.Contains(scope) { 508 // TODO prefer tokens with less scope? 509 return tok 510 } 511 } 512 return nil 513 } 514 515 type emptyConfig struct{} 516 517 func (emptyConfig) EntryForRegistry(host string) (ConfigEntry, error) { 518 return ConfigEntry{}, nil 519 }