cuelabs.dev/go/oci/ociregistry@v0.0.0-20240906074133-82eb438dd565/ociauth/auth.go (about)

     1  package ociauth
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/url"
    12  	"slices"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"cuelabs.dev/go/oci/ociregistry"
    18  )
    19  
    20  // TODO decide on a good value for this.
    21  const oauthClientID = "cuelabs-ociauth"
    22  
    23  var ErrNoAuth = fmt.Errorf("no authorization token available to add to request")
    24  
    25  // stdTransport implements [http.RoundTripper] by acquiring authorization tokens
    26  // using the flows implemented
    27  // by the usual docker clients. Note that this is _not_ documented as
    28  // part of any official OCI spec.
    29  //
    30  // See https://distribution.github.io/distribution/spec/auth/token/ for an overview.
    31  type stdTransport struct {
    32  	config     Config
    33  	transport  http.RoundTripper
    34  	mu         sync.Mutex
    35  	registries map[string]*registry
    36  }
    37  
    38  type StdTransportParams struct {
    39  	// Config represents the underlying configuration file information.
    40  	// It is consulted for authorization information on the hosts
    41  	// to which the HTTP requests are made.
    42  	Config Config
    43  
    44  	// HTTPClient is used to make the underlying HTTP requests.
    45  	// If it's nil, [http.DefaultTransport] will be used.
    46  	Transport http.RoundTripper
    47  }
    48  
    49  // NewStdTransport returns an [http.RoundTripper] implementation that
    50  // acquires authorization tokens using the flows implemented by the
    51  // usual docker clients. Note that this is _not_ documented as part of
    52  // any official OCI spec.
    53  //
    54  // See https://distribution.github.io/distribution/spec/auth/token/ for an overview.
    55  //
    56  // The RoundTrip method acquires authorization before invoking the
    57  // request. request. It may invoke the request more than once, and can
    58  // use [http.Request.GetBody] to reset the request body if it gets
    59  // consumed.
    60  //
    61  // It ensures that the authorization token used will have at least the
    62  // capability to execute operations in the required scope associated
    63  // with the request context (see [ContextWithRequestInfo]). Any other
    64  // auth scope inside the context (see [ContextWithScope]) may also be
    65  // taken into account when acquiring new tokens.
    66  func NewStdTransport(p StdTransportParams) http.RoundTripper {
    67  	if p.Config == nil {
    68  		p.Config = emptyConfig{}
    69  	}
    70  	if p.Transport == nil {
    71  		p.Transport = http.DefaultTransport
    72  	}
    73  	return &stdTransport{
    74  		config:     p.Config,
    75  		transport:  p.Transport,
    76  		registries: make(map[string]*registry),
    77  	}
    78  }
    79  
    80  // registry holds currently known auth information for a registry.
    81  type registry struct {
    82  	host      string
    83  	transport http.RoundTripper
    84  	config    Config
    85  	initOnce  sync.Once
    86  	initErr   error
    87  
    88  	// mu guards the fields that follow it.
    89  	mu sync.Mutex
    90  
    91  	// wwwAuthenticate holds the Www-Authenticate header from
    92  	// the most recent 401 response. If there was a 401 response
    93  	// that didn't hold such a header, this will still be non-nil
    94  	// but hold a zero authHeader.
    95  	wwwAuthenticate *authHeader
    96  
    97  	accessTokens []*scopedToken
    98  	refreshToken string
    99  	basic        *userPass
   100  }
   101  
   102  type scopedToken struct {
   103  	// scope holds the scope that the token is good for.
   104  	scope Scope
   105  	// token holds the actual access token.
   106  	token string
   107  	// expires holds when the token expires.
   108  	expires time.Time
   109  }
   110  
   111  type userPass struct {
   112  	username string
   113  	password string
   114  }
   115  
   116  var forever = time.Date(99999, time.January, 1, 0, 0, 0, 0, time.UTC)
   117  
   118  // RoundTrip implements [http.RoundTripper.RoundTrip].
   119  func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   120  	// From the [http.RoundTripper] docs:
   121  	//	RoundTrip should not modify the request, except for
   122  	//	consuming and closing the Request's Body.
   123  	req = req.Clone(req.Context())
   124  
   125  	// From the [http.RoundTripper] docs:
   126  	//	RoundTrip must always close the body, including on errors, [...]
   127  	needBodyClose := true
   128  	defer func() {
   129  		if needBodyClose && req.Body != nil {
   130  			req.Body.Close()
   131  		}
   132  	}()
   133  
   134  	a.mu.Lock()
   135  	r := a.registries[req.URL.Host]
   136  	if r == nil {
   137  		r = &registry{
   138  			host:      req.URL.Host,
   139  			config:    a.config,
   140  			transport: a.transport,
   141  		}
   142  		a.registries[r.host] = r
   143  	}
   144  	a.mu.Unlock()
   145  	if err := r.init(); err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	ctx := req.Context()
   150  	requiredScope := RequestInfoFromContext(ctx).RequiredScope
   151  	wantScope := ScopeFromContext(ctx)
   152  
   153  	if err := r.setAuthorization(ctx, req, requiredScope, wantScope); err != nil {
   154  		return nil, err
   155  	}
   156  	resp, err := r.transport.RoundTrip(req)
   157  
   158  	// The underlying transport should now have closed the request body
   159  	// so we don't have to.
   160  	needBodyClose = false
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  	if resp.StatusCode != http.StatusUnauthorized {
   165  		return resp, nil
   166  	}
   167  	challenge := challengeFromResponse(resp)
   168  	if challenge == nil {
   169  		return resp, nil
   170  	}
   171  	authAdded, tokenAcquired, err := r.setAuthorizationFromChallenge(ctx, req, challenge, requiredScope, wantScope)
   172  	if err != nil {
   173  		resp.Body.Close()
   174  		return nil, err
   175  	}
   176  	if !authAdded {
   177  		// Couldn't acquire any more authorization than we had initially.
   178  		return resp, nil
   179  	}
   180  	resp.Body.Close()
   181  	// rewind request body if needed and possible.
   182  	if req.GetBody != nil {
   183  		req.Body, err = req.GetBody()
   184  		if err != nil {
   185  			return nil, err
   186  		}
   187  	}
   188  	resp, err = r.transport.RoundTrip(req)
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  	if resp.StatusCode != http.StatusUnauthorized || !tokenAcquired {
   193  		return resp, nil
   194  	}
   195  	// The server has responded with Unauthorized (401) even though we've just
   196  	// provided a token that it gave us. Treat it as Forbidden (403) instead.
   197  	// TODO include the original body/error as part of the message or message detail?
   198  	resp.Body.Close()
   199  	data, err := json.Marshal(&ociregistry.WireErrors{
   200  		Errors: []ociregistry.WireError{{
   201  			Code_:   ociregistry.ErrDenied.Code(),
   202  			Message: "unauthorized response with freshly acquired auth token",
   203  		}},
   204  	})
   205  	if err != nil {
   206  		return nil, fmt.Errorf("cannot marshal response body: %v", err)
   207  	}
   208  	resp.Header.Set("Content-Type", "application/json")
   209  	resp.ContentLength = int64(len(data))
   210  	resp.Body = io.NopCloser(bytes.NewReader(data))
   211  	resp.StatusCode = http.StatusForbidden
   212  	resp.Status = http.StatusText(resp.StatusCode)
   213  	return resp, nil
   214  }
   215  
   216  // setAuthorization sets up authorization on the given request using any
   217  // auth information currently available.
   218  func (r *registry) setAuthorization(ctx context.Context, req *http.Request, requiredScope, wantScope Scope) error {
   219  	r.mu.Lock()
   220  	defer r.mu.Unlock()
   221  	// Remove tokens that have expired or will expire soon so that
   222  	// the caller doesn't start using a token only for it to expire while it's
   223  	// making the request.
   224  	r.deleteExpiredTokens(time.Now().UTC().Add(time.Second))
   225  
   226  	if accessToken := r.accessTokenForScope(requiredScope); accessToken != nil {
   227  		// We have a potentially valid access token. Use it.
   228  		req.Header.Set("Authorization", "Bearer "+accessToken.token)
   229  		return nil
   230  	}
   231  	if r.wwwAuthenticate == nil {
   232  		// We haven't seen a 401 response yet. Avoid putting any
   233  		// basic authorization in the request, because that can mean that
   234  		// the server sends a 401 response without a Www-Authenticate
   235  		// header.
   236  		return nil
   237  	}
   238  	if r.refreshToken != "" && r.wwwAuthenticate.scheme == "bearer" {
   239  		// We've got a refresh token that we can use to try to
   240  		// acquire an access token and we've seen a Www-Authenticate response
   241  		// that tells us how we can use it.
   242  
   243  		// TODO we're holding the lock (r.mu) here, which is precluding
   244  		// acquiring several tokens concurrently. We should relax the lock
   245  		// to allow that.
   246  
   247  		accessToken, err := r.acquireAccessToken(ctx, requiredScope, wantScope)
   248  		if err != nil {
   249  			// Avoid using %w to wrap the error because we don't want the
   250  			// caller of RoundTrip (usually ociclient) to assume that the
   251  			// error applies to the target server rather than the token server.
   252  			return fmt.Errorf("cannot acquire access token: %v", err)
   253  		}
   254  		req.Header.Set("Authorization", "Bearer "+accessToken)
   255  		return nil
   256  	}
   257  	if r.wwwAuthenticate.scheme != "bearer" && r.basic != nil {
   258  		req.SetBasicAuth(r.basic.username, r.basic.password)
   259  		return nil
   260  	}
   261  	return nil
   262  }
   263  
   264  func (r *registry) setAuthorizationFromChallenge(ctx context.Context, req *http.Request, challenge *authHeader, requiredScope, wantScope Scope) (authAdded, tokenAcquired bool, _ error) {
   265  	r.mu.Lock()
   266  	defer r.mu.Unlock()
   267  	r.wwwAuthenticate = challenge
   268  
   269  	switch {
   270  	case r.wwwAuthenticate.scheme == "bearer":
   271  		scope := ParseScope(r.wwwAuthenticate.params["scope"])
   272  		accessToken, err := r.acquireAccessToken(ctx, scope, wantScope.Union(requiredScope))
   273  		if err != nil {
   274  			return false, false, err
   275  		}
   276  		req.Header.Set("Authorization", "Bearer "+accessToken)
   277  		return true, true, nil
   278  	case r.basic != nil:
   279  		req.SetBasicAuth(r.basic.username, r.basic.password)
   280  		return true, false, nil
   281  	}
   282  	return false, false, nil
   283  }
   284  
   285  // init initializes the registry instance by acquiring auth information from
   286  // the Config, if available. As this might be slow (invoking EntryForRegistry
   287  // can end up invoking slow external commands), we ensure that it's only
   288  // done once.
   289  // TODO it's possible that this could take a very long time, during which
   290  // the outer context is cancelled, but we'll ignore that. We probably shouldn't.
   291  func (r *registry) init() error {
   292  	inner := func() error {
   293  		info, err := r.config.EntryForRegistry(r.host)
   294  		if err != nil {
   295  			return fmt.Errorf("cannot acquire auth info for registry %q: %v", r.host, err)
   296  		}
   297  		r.refreshToken = info.RefreshToken
   298  		if info.AccessToken != "" {
   299  			r.accessTokens = append(r.accessTokens, &scopedToken{
   300  				scope:   UnlimitedScope(),
   301  				token:   info.AccessToken,
   302  				expires: forever,
   303  			})
   304  		}
   305  		if info.Username != "" && info.Password != "" {
   306  			r.basic = &userPass{
   307  				username: info.Username,
   308  				password: info.Password,
   309  			}
   310  		}
   311  		return nil
   312  	}
   313  	r.initOnce.Do(func() {
   314  		r.initErr = inner()
   315  	})
   316  	return r.initErr
   317  }
   318  
   319  // acquireAccessToken tries to acquire an access token for authorizing a request.
   320  // The requiredScopeStr parameter indicates the scope that's definitely
   321  // required. This is a string because apparently some servers are picky
   322  // about getting exactly the same scope in the auth request that was
   323  // returned in the challenge. The wantScope parameter indicates
   324  // what scope might be required in the future.
   325  //
   326  // This method assumes that there has been a previous 401 response with
   327  // a Www-Authenticate: Bearer... header.
   328  func (r *registry) acquireAccessToken(ctx context.Context, requiredScope, wantScope Scope) (string, error) {
   329  	scope := requiredScope.Union(wantScope)
   330  	tok, err := r.acquireToken(ctx, scope)
   331  	if err != nil {
   332  		var herr ociregistry.HTTPError
   333  		if !errors.As(err, &herr) || herr.StatusCode() != http.StatusUnauthorized {
   334  			return "", err
   335  		}
   336  		// The documentation says this:
   337  		//
   338  		//	If the client only has a subset of the requested
   339  		// 	access it _must not be considered an error_ as it is
   340  		//	not the responsibility of the token server to
   341  		//	indicate authorization errors as part of this
   342  		//	workflow.
   343  		//
   344  		// However it's apparently not uncommon for servers to reject
   345  		// such requests anyway, so if we've got an unauthorized error
   346  		// and wantScope goes beyond requiredScope, it may be because
   347  		// the server is rejecting the request.
   348  		scope = requiredScope
   349  		tok, err = r.acquireToken(ctx, scope)
   350  		if err != nil {
   351  			return "", err
   352  		}
   353  		// TODO mark the registry as picky about tokens so we don't
   354  		// attempt twice every time?
   355  	}
   356  	if tok.RefreshToken != "" {
   357  		r.refreshToken = tok.RefreshToken
   358  	}
   359  	accessToken := tok.Token
   360  	if accessToken == "" {
   361  		accessToken = tok.AccessToken
   362  	}
   363  	if accessToken == "" {
   364  		return "", fmt.Errorf("no access token found in auth server response")
   365  	}
   366  	var expires time.Time
   367  	now := time.Now().UTC()
   368  	if tok.ExpiresIn == 0 {
   369  		expires = now.Add(60 * time.Second) // TODO link to where this is mentioned
   370  	} else {
   371  		expires = now.Add(time.Duration(tok.ExpiresIn) * time.Second)
   372  	}
   373  	r.accessTokens = append(r.accessTokens, &scopedToken{
   374  		scope:   scope,
   375  		token:   accessToken,
   376  		expires: expires,
   377  	})
   378  	// TODO persist the access token to save round trips when doing
   379  	// the authorization flow in a newly run executable.
   380  	return accessToken, nil
   381  }
   382  
   383  func (r *registry) acquireToken(ctx context.Context, scope Scope) (*wireToken, error) {
   384  	realm := r.wwwAuthenticate.params["realm"]
   385  	if realm == "" {
   386  		return nil, fmt.Errorf("malformed Www-Authenticate header (missing realm)")
   387  	}
   388  	if r.refreshToken != "" {
   389  		v := url.Values{}
   390  		v.Set("scope", scope.String())
   391  		if service := r.wwwAuthenticate.params["service"]; service != "" {
   392  			v.Set("service", service)
   393  		}
   394  		v.Set("client_id", oauthClientID)
   395  		v.Set("grant_type", "refresh_token")
   396  		v.Set("refresh_token", r.refreshToken)
   397  		req, err := http.NewRequestWithContext(ctx, "POST", realm, strings.NewReader(v.Encode()))
   398  		if err != nil {
   399  			return nil, fmt.Errorf("cannot form HTTP request to %q: %v", realm, err)
   400  		}
   401  		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   402  		tok, err := r.doTokenRequest(req)
   403  		if err == nil {
   404  			return tok, nil
   405  		}
   406  		var herr ociregistry.HTTPError
   407  		if !errors.As(err, &herr) || herr.StatusCode() != http.StatusNotFound {
   408  			return tok, err
   409  		}
   410  		// The request to the endpoint returned 404 from the POST request,
   411  		// Note: Not all token servers implement oauth2, so fall
   412  		// back to using a GET with basic auth.
   413  		// See the Token documentation for the HTTP GET method supported by all token servers.
   414  		// TODO where in that documentation is this documented?
   415  	}
   416  	u, err := url.Parse(realm)
   417  	if err != nil {
   418  		return nil, fmt.Errorf("malformed Www-Authenticate header (malformed realm %q): %v", realm, err)
   419  	}
   420  	v := u.Query()
   421  	// TODO where is it documented that we should send multiple scope
   422  	// attributes rather than a single space-separated attribute as
   423  	// the POST method does?
   424  	v["scope"] = strings.Split(scope.String(), " ")
   425  	if service := r.wwwAuthenticate.params["service"]; service != "" {
   426  		// TODO the containerregistry code sets this even if it's empty.
   427  		// Is that better?
   428  		v.Set("service", service)
   429  	}
   430  	u.RawQuery = v.Encode()
   431  	req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
   432  	if err != nil {
   433  		return nil, err
   434  	}
   435  	// TODO if there's an unlimited-scope access token, the original code
   436  	// will use it as Bearer authorization at this point. If
   437  	// that's valid, why are we even acquiring another token?
   438  	if r.basic != nil {
   439  		req.SetBasicAuth(r.basic.username, r.basic.password)
   440  	}
   441  	return r.doTokenRequest(req)
   442  }
   443  
   444  // wireToken describes the JSON encoding used in the response to a token
   445  // acquisition method. The comments are taken from the [token docs]
   446  // and made available here for ease of reference.
   447  //
   448  // [token docs]: https://distribution.github.io/distribution/spec/auth/token/#token-response-fields
   449  type wireToken struct {
   450  	// Token holds an opaque Bearer token that clients should supply
   451  	// to subsequent requests in the Authorization header.
   452  	// AccessToken is provided for compatibility with OAuth 2.0: it's equivalent to Token.
   453  	// At least one of these fields must be specified, but both may also appear (for compatibility with older clients).
   454  	// When both are specified, they should be equivalent; if they differ the client's choice is undefined.
   455  	Token       string `json:"token"`
   456  	AccessToken string `json:"access_token,omitempty"`
   457  
   458  	// Refresh token optionally holds a token which can be used to
   459  	// get additional access tokens for the same subject with different scopes.
   460  	// This token should be kept secure by the client and only sent
   461  	// to the authorization server which issues bearer tokens. This
   462  	// field will only be set when `offline_token=true` is provided
   463  	// in the request.
   464  	RefreshToken string `json:"refresh_token"`
   465  
   466  	// ExpiresIn holds the duration in seconds since the token was
   467  	// issued that it will remain valid. When omitted, this defaults
   468  	// to 60 seconds. For compatibility with older clients, a token
   469  	// should never be returned with less than 60 seconds to live.
   470  	ExpiresIn int `json:"expires_in"`
   471  }
   472  
   473  func (r *registry) doTokenRequest(req *http.Request) (*wireToken, error) {
   474  	client := &http.Client{
   475  		Transport: r.transport,
   476  	}
   477  	resp, err := client.Do(req)
   478  	if err != nil {
   479  		return nil, err
   480  	}
   481  	defer resp.Body.Close()
   482  	data, bodyErr := io.ReadAll(resp.Body)
   483  	if resp.StatusCode != http.StatusOK {
   484  		return nil, ociregistry.NewHTTPError(nil, resp.StatusCode, resp, data)
   485  	}
   486  	if bodyErr != nil {
   487  		return nil, fmt.Errorf("error reading response body: %v", err)
   488  	}
   489  	var tok wireToken
   490  	if err := json.Unmarshal(data, &tok); err != nil {
   491  		return nil, fmt.Errorf("malformed JSON token in response: %v", err)
   492  	}
   493  	return &tok, nil
   494  }
   495  
   496  // deleteExpiredTokens removes all tokens from r that expire after the given
   497  // time.
   498  // TODO ask the store to remove expired tokens?
   499  func (r *registry) deleteExpiredTokens(now time.Time) {
   500  	r.accessTokens = slices.DeleteFunc(r.accessTokens, func(tok *scopedToken) bool {
   501  		return now.After(tok.expires)
   502  	})
   503  }
   504  
   505  func (r *registry) accessTokenForScope(scope Scope) *scopedToken {
   506  	for _, tok := range r.accessTokens {
   507  		if tok.scope.Contains(scope) {
   508  			// TODO prefer tokens with less scope?
   509  			return tok
   510  		}
   511  	}
   512  	return nil
   513  }
   514  
   515  type emptyConfig struct{}
   516  
   517  func (emptyConfig) EntryForRegistry(host string) (ConfigEntry, error) {
   518  	return ConfigEntry{}, nil
   519  }