github.com/dtroyer-salad/og2/v2@v2.0.0-20240412154159-c47231610877/registry/remote/auth/client.go (about)

     1  /*
     2  Copyright The ORAS Authors.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  // Package auth provides authentication for a client to a remote registry.
    17  package auth
    18  
    19  import (
    20  	"context"
    21  	"encoding/base64"
    22  	"encoding/json"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"net/url"
    28  	"strings"
    29  
    30  	"oras.land/oras-go/v2/registry/remote/internal/errutil"
    31  	"oras.land/oras-go/v2/registry/remote/retry"
    32  )
    33  
    34  // ErrBasicCredentialNotFound is returned  when the credential is not found for
    35  // basic auth.
    36  var ErrBasicCredentialNotFound = errors.New("basic credential not found")
    37  
    38  // DefaultClient is the default auth-decorated client.
    39  var DefaultClient = &Client{
    40  	Client: retry.DefaultClient,
    41  	Header: http.Header{
    42  		"User-Agent": {"oras-go"},
    43  	},
    44  	Cache: DefaultCache,
    45  }
    46  
    47  // maxResponseBytes specifies the default limit on how many response bytes are
    48  // allowed in the server's response from authorization service servers.
    49  // A typical response message from authorization service servers is around 1 to
    50  // 4 KiB. Since the size of a token must be smaller than the HTTP header size
    51  // limit, which is usually 16 KiB. As specified by the distribution, the
    52  // response may contain 2 identical tokens, that is, 16 x 2 = 32 KiB.
    53  // Hence, 128 KiB should be sufficient.
    54  // References: https://docs.docker.com/registry/spec/auth/token/
    55  var maxResponseBytes int64 = 128 * 1024 // 128 KiB
    56  
    57  // defaultClientID specifies the default client ID used in OAuth2.
    58  // See also ClientID.
    59  var defaultClientID = "oras-go"
    60  
    61  // CredentialFunc represents a function that resolves the credential for the
    62  // given registry (i.e. host:port).
    63  //
    64  // [EmptyCredential] is a valid return value and should not be considered as
    65  // an error.
    66  type CredentialFunc func(ctx context.Context, hostport string) (Credential, error)
    67  
    68  // StaticCredential specifies static credentials for the given host.
    69  func StaticCredential(registry string, cred Credential) CredentialFunc {
    70  	if registry == "docker.io" {
    71  		// it is expected that traffic targeting "docker.io" will be redirected
    72  		// to "registry-1.docker.io"
    73  		// reference: https://github.com/moby/moby/blob/v24.0.0-beta.2/registry/config.go#L25-L48
    74  		registry = "registry-1.docker.io"
    75  	}
    76  	return func(_ context.Context, hostport string) (Credential, error) {
    77  		if hostport == registry {
    78  			return cred, nil
    79  		}
    80  		return EmptyCredential, nil
    81  	}
    82  }
    83  
    84  // Client is an auth-decorated HTTP client.
    85  // Its zero value is a usable client that uses http.DefaultClient with no cache.
    86  type Client struct {
    87  	// Client is the underlying HTTP client used to access the remote
    88  	// server.
    89  	// If nil, http.DefaultClient is used.
    90  	// It is possible to use the default retry client from the package
    91  	// `oras.land/oras-go/v2/registry/remote/retry`. That client is already available
    92  	// in the DefaultClient.
    93  	// It is also possible to use a custom client. For example, github.com/hashicorp/go-retryablehttp
    94  	// is a popular HTTP client that supports retries.
    95  	Client *http.Client
    96  
    97  	// Header contains the custom headers to be added to each request.
    98  	Header http.Header
    99  
   100  	// Credential specifies the function for resolving the credential for the
   101  	// given registry (i.e. host:port).
   102  	// EmptyCredential is a valid return value and should not be considered as
   103  	// an error.
   104  	// If nil, the credential is always resolved to EmptyCredential.
   105  	Credential CredentialFunc
   106  
   107  	// Cache caches credentials for direct accessing the remote registry.
   108  	// If nil, no cache is used.
   109  	Cache Cache
   110  
   111  	// ClientID used in fetching OAuth2 token as a required field.
   112  	// If empty, a default client ID is used.
   113  	// Reference: https://docs.docker.com/registry/spec/auth/oauth/#getting-a-token
   114  	ClientID string
   115  
   116  	// ForceAttemptOAuth2 controls whether to follow OAuth2 with password grant
   117  	// instead the distribution spec when authenticating using username and
   118  	// password.
   119  	// References:
   120  	// - https://docs.docker.com/registry/spec/auth/jwt/
   121  	// - https://docs.docker.com/registry/spec/auth/oauth/
   122  	ForceAttemptOAuth2 bool
   123  }
   124  
   125  // client returns an HTTP client used to access the remote registry.
   126  // http.DefaultClient is return if the client is not configured.
   127  func (c *Client) client() *http.Client {
   128  	if c.Client == nil {
   129  		return http.DefaultClient
   130  	}
   131  	return c.Client
   132  }
   133  
   134  // send adds headers to the request and sends the request to the remote server.
   135  func (c *Client) send(req *http.Request) (*http.Response, error) {
   136  	for key, values := range c.Header {
   137  		req.Header[key] = append(req.Header[key], values...)
   138  	}
   139  	return c.client().Do(req)
   140  }
   141  
   142  // credential resolves the credential for the given registry.
   143  func (c *Client) credential(ctx context.Context, reg string) (Credential, error) {
   144  	if c.Credential == nil {
   145  		return EmptyCredential, nil
   146  	}
   147  	return c.Credential(ctx, reg)
   148  }
   149  
   150  // cache resolves the cache.
   151  // noCache is return if the cache is not configured.
   152  func (c *Client) cache() Cache {
   153  	if c.Cache == nil {
   154  		return noCache{}
   155  	}
   156  	return c.Cache
   157  }
   158  
   159  // SetUserAgent sets the user agent for all out-going requests.
   160  func (c *Client) SetUserAgent(userAgent string) {
   161  	if c.Header == nil {
   162  		c.Header = http.Header{}
   163  	}
   164  	c.Header.Set("User-Agent", userAgent)
   165  }
   166  
   167  // Do sends the request to the remote server, attempting to resolve
   168  // authentication if 'Authorization' header is not set.
   169  //
   170  // On authentication failure due to bad credential,
   171  //   - Do returns error if it fails to fetch token for bearer auth.
   172  //   - Do returns the registry response without error for basic auth.
   173  func (c *Client) Do(originalReq *http.Request) (*http.Response, error) {
   174  	if auth := originalReq.Header.Get("Authorization"); auth != "" {
   175  		return c.send(originalReq)
   176  	}
   177  
   178  	ctx := originalReq.Context()
   179  	req := originalReq.Clone(ctx)
   180  
   181  	// attempt cached auth token
   182  	var attemptedKey string
   183  	cache := c.cache()
   184  	host := originalReq.Host
   185  	scheme, err := cache.GetScheme(ctx, host)
   186  	if err == nil {
   187  		switch scheme {
   188  		case SchemeBasic:
   189  			token, err := cache.GetToken(ctx, host, SchemeBasic, "")
   190  			if err == nil {
   191  				req.Header.Set("Authorization", "Basic "+token)
   192  			}
   193  		case SchemeBearer:
   194  			scopes := GetAllScopesForHost(ctx, host)
   195  			attemptedKey = strings.Join(scopes, " ")
   196  			token, err := cache.GetToken(ctx, host, SchemeBearer, attemptedKey)
   197  			if err == nil {
   198  				req.Header.Set("Authorization", "Bearer "+token)
   199  			}
   200  		}
   201  	}
   202  
   203  	resp, err := c.send(req)
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  	if resp.StatusCode != http.StatusUnauthorized {
   208  		return resp, nil
   209  	}
   210  
   211  	// attempt again with credentials for recognized schemes
   212  	challenge := resp.Header.Get("Www-Authenticate")
   213  	scheme, params := parseChallenge(challenge)
   214  	switch scheme {
   215  	case SchemeBasic:
   216  		resp.Body.Close()
   217  
   218  		token, err := cache.Set(ctx, host, SchemeBasic, "", func(ctx context.Context) (string, error) {
   219  			return c.fetchBasicAuth(ctx, host)
   220  		})
   221  		if err != nil {
   222  			return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err)
   223  		}
   224  
   225  		req = originalReq.Clone(ctx)
   226  		req.Header.Set("Authorization", "Basic "+token)
   227  	case SchemeBearer:
   228  		resp.Body.Close()
   229  
   230  		scopes := GetAllScopesForHost(ctx, host)
   231  		if paramScope := params["scope"]; paramScope != "" {
   232  			// merge hinted scopes with challenged scopes
   233  			scopes = append(scopes, strings.Split(paramScope, " ")...)
   234  			scopes = CleanScopes(scopes)
   235  		}
   236  		key := strings.Join(scopes, " ")
   237  
   238  		// attempt the cache again if there is a scope change
   239  		if key != attemptedKey {
   240  			if token, err := cache.GetToken(ctx, host, SchemeBearer, key); err == nil {
   241  				req = originalReq.Clone(ctx)
   242  				req.Header.Set("Authorization", "Bearer "+token)
   243  				if err := rewindRequestBody(req); err != nil {
   244  					return nil, err
   245  				}
   246  
   247  				resp, err := c.send(req)
   248  				if err != nil {
   249  					return nil, err
   250  				}
   251  				if resp.StatusCode != http.StatusUnauthorized {
   252  					return resp, nil
   253  				}
   254  				resp.Body.Close()
   255  			}
   256  		}
   257  
   258  		// attempt with credentials
   259  		realm := params["realm"]
   260  		service := params["service"]
   261  		token, err := cache.Set(ctx, host, SchemeBearer, key, func(ctx context.Context) (string, error) {
   262  			return c.fetchBearerToken(ctx, host, realm, service, scopes)
   263  		})
   264  		if err != nil {
   265  			return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err)
   266  		}
   267  
   268  		req = originalReq.Clone(ctx)
   269  		req.Header.Set("Authorization", "Bearer "+token)
   270  	default:
   271  		return resp, nil
   272  	}
   273  	if err := rewindRequestBody(req); err != nil {
   274  		return nil, err
   275  	}
   276  
   277  	return c.send(req)
   278  }
   279  
   280  // fetchBasicAuth fetches a basic auth token for the basic challenge.
   281  func (c *Client) fetchBasicAuth(ctx context.Context, registry string) (string, error) {
   282  	cred, err := c.credential(ctx, registry)
   283  	if err != nil {
   284  		return "", fmt.Errorf("failed to resolve credential: %w", err)
   285  	}
   286  	if cred == EmptyCredential {
   287  		return "", ErrBasicCredentialNotFound
   288  	}
   289  	if cred.Username == "" || cred.Password == "" {
   290  		return "", errors.New("missing username or password for basic auth")
   291  	}
   292  	auth := cred.Username + ":" + cred.Password
   293  	return base64.StdEncoding.EncodeToString([]byte(auth)), nil
   294  }
   295  
   296  // fetchBearerToken fetches an access token for the bearer challenge.
   297  func (c *Client) fetchBearerToken(ctx context.Context, registry, realm, service string, scopes []string) (string, error) {
   298  	cred, err := c.credential(ctx, registry)
   299  	if err != nil {
   300  		return "", err
   301  	}
   302  	if cred.AccessToken != "" {
   303  		return cred.AccessToken, nil
   304  	}
   305  	if cred == EmptyCredential || (cred.RefreshToken == "" && !c.ForceAttemptOAuth2) {
   306  		return c.fetchDistributionToken(ctx, realm, service, scopes, cred.Username, cred.Password)
   307  	}
   308  	return c.fetchOAuth2Token(ctx, realm, service, scopes, cred)
   309  }
   310  
   311  // fetchDistributionToken fetches an access token as defined by the distribution
   312  // specification.
   313  // It fetches anonymous tokens if no credential is provided.
   314  // References:
   315  // - https://docs.docker.com/registry/spec/auth/jwt/
   316  // - https://docs.docker.com/registry/spec/auth/token/
   317  func (c *Client) fetchDistributionToken(ctx context.Context, realm, service string, scopes []string, username, password string) (string, error) {
   318  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm, nil)
   319  	if err != nil {
   320  		return "", err
   321  	}
   322  	if username != "" || password != "" {
   323  		req.SetBasicAuth(username, password)
   324  	}
   325  	q := req.URL.Query()
   326  	if service != "" {
   327  		q.Set("service", service)
   328  	}
   329  	for _, scope := range scopes {
   330  		q.Add("scope", scope)
   331  	}
   332  	req.URL.RawQuery = q.Encode()
   333  
   334  	resp, err := c.send(req)
   335  	if err != nil {
   336  		return "", err
   337  	}
   338  	defer resp.Body.Close()
   339  	if resp.StatusCode != http.StatusOK {
   340  		return "", errutil.ParseErrorResponse(resp)
   341  	}
   342  
   343  	// As specified in https://docs.docker.com/registry/spec/auth/token/ section
   344  	// "Token Response Fields", the token is either in `token` or
   345  	// `access_token`. If both present, they are identical.
   346  	var result struct {
   347  		Token       string `json:"token"`
   348  		AccessToken string `json:"access_token"`
   349  	}
   350  	lr := io.LimitReader(resp.Body, maxResponseBytes)
   351  	if err := json.NewDecoder(lr).Decode(&result); err != nil {
   352  		return "", fmt.Errorf("%s %q: failed to decode response: %w", resp.Request.Method, resp.Request.URL, err)
   353  	}
   354  	if result.AccessToken != "" {
   355  		return result.AccessToken, nil
   356  	}
   357  	if result.Token != "" {
   358  		return result.Token, nil
   359  	}
   360  	return "", fmt.Errorf("%s %q: empty token returned", resp.Request.Method, resp.Request.URL)
   361  }
   362  
   363  // fetchOAuth2Token fetches an OAuth2 access token.
   364  // Reference: https://docs.docker.com/registry/spec/auth/oauth/
   365  func (c *Client) fetchOAuth2Token(ctx context.Context, realm, service string, scopes []string, cred Credential) (string, error) {
   366  	form := url.Values{}
   367  	if cred.RefreshToken != "" {
   368  		form.Set("grant_type", "refresh_token")
   369  		form.Set("refresh_token", cred.RefreshToken)
   370  	} else if cred.Username != "" && cred.Password != "" {
   371  		form.Set("grant_type", "password")
   372  		form.Set("username", cred.Username)
   373  		form.Set("password", cred.Password)
   374  	} else {
   375  		return "", errors.New("missing username or password for bearer auth")
   376  	}
   377  	form.Set("service", service)
   378  	clientID := c.ClientID
   379  	if clientID == "" {
   380  		clientID = defaultClientID
   381  	}
   382  	form.Set("client_id", clientID)
   383  	if len(scopes) != 0 {
   384  		form.Set("scope", strings.Join(scopes, " "))
   385  	}
   386  	body := strings.NewReader(form.Encode())
   387  
   388  	req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm, body)
   389  	if err != nil {
   390  		return "", err
   391  	}
   392  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   393  
   394  	resp, err := c.send(req)
   395  	if err != nil {
   396  		return "", err
   397  	}
   398  	defer resp.Body.Close()
   399  	if resp.StatusCode != http.StatusOK {
   400  		return "", errutil.ParseErrorResponse(resp)
   401  	}
   402  
   403  	var result struct {
   404  		AccessToken string `json:"access_token"`
   405  	}
   406  	lr := io.LimitReader(resp.Body, maxResponseBytes)
   407  	if err := json.NewDecoder(lr).Decode(&result); err != nil {
   408  		return "", fmt.Errorf("%s %q: failed to decode response: %w", resp.Request.Method, resp.Request.URL, err)
   409  	}
   410  	if result.AccessToken != "" {
   411  		return result.AccessToken, nil
   412  	}
   413  	return "", fmt.Errorf("%s %q: empty token returned", resp.Request.Method, resp.Request.URL)
   414  }
   415  
   416  // rewindRequestBody tries to rewind the request body if exists.
   417  func rewindRequestBody(req *http.Request) error {
   418  	if req.Body == nil || req.Body == http.NoBody {
   419  		return nil
   420  	}
   421  	if req.GetBody == nil {
   422  		return fmt.Errorf("%s %q: request body is not rewindable", req.Method, req.URL)
   423  	}
   424  	body, err := req.GetBody()
   425  	if err != nil {
   426  		return fmt.Errorf("%s %q: failed to get request body: %w", req.Method, req.URL, err)
   427  	}
   428  	req.Body = body
   429  	return nil
   430  }