github.com/opcr-io/oras-go/v2@v2.0.0-20231122155130-eb4260d8a0ae/registry/remote/auth/client.go (about)

     1  /*
     2  Copyright The ORAS Authors.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  // Package auth provides authentication for a client to a remote registry.
    17  package auth
    18  
    19  import (
    20  	"context"
    21  	"encoding/base64"
    22  	"encoding/json"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"net/url"
    28  	"strings"
    29  
    30  	"github.com/opcr-io/oras-go/v2/registry/remote/internal/errutil"
    31  	"github.com/opcr-io/oras-go/v2/registry/remote/retry"
    32  )
    33  
    34  // DefaultClient is the default auth-decorated client.
    35  var DefaultClient = &Client{
    36  	Client: retry.DefaultClient,
    37  	Header: http.Header{
    38  		"User-Agent": {"oras-go"},
    39  	},
    40  	Cache: DefaultCache,
    41  }
    42  
    43  // maxResponseBytes specifies the default limit on how many response bytes are
    44  // allowed in the server's response from authorization service servers.
    45  // A typical response message from authorization service servers is around 1 to
    46  // 4 KiB. Since the size of a token must be smaller than the HTTP header size
    47  // limit, which is usually 16 KiB. As specified by the distribution, the
    48  // response may contain 2 identical tokens, that is, 16 x 2 = 32 KiB.
    49  // Hence, 128 KiB should be sufficient.
    50  // References: https://docs.docker.com/registry/spec/auth/token/
    51  var maxResponseBytes int64 = 128 * 1024 // 128 KiB
    52  
    53  // defaultClientID specifies the default client ID used in OAuth2.
    54  // See also ClientID.
    55  var defaultClientID = "oras-go"
    56  
    57  // StaticCredential specifies static credentials for the given host.
    58  func StaticCredential(registry string, cred Credential) func(context.Context, string) (Credential, error) {
    59  	return func(_ context.Context, target string) (Credential, error) {
    60  		if target == registry {
    61  			return cred, nil
    62  		}
    63  		return EmptyCredential, nil
    64  	}
    65  }
    66  
    67  // Client is an auth-decorated HTTP client.
    68  // Its zero value is a usable client that uses http.DefaultClient with no cache.
    69  type Client struct {
    70  	// Client is the underlying HTTP client used to access the remote
    71  	// server.
    72  	// If nil, http.DefaultClient is used.
    73  	// It is possible to use the default retry client from the package
    74  	// `github.com/opcr-io/oras-go/v2/registry/remote/retry`. That client is already available
    75  	// in the DefaultClient.
    76  	// It is also possible to use a custom client. For example, github.com/hashicorp/go-retryablehttp
    77  	// is a popular HTTP client that supports retries.
    78  	Client *http.Client
    79  
    80  	// Header contains the custom headers to be added to each request.
    81  	Header http.Header
    82  
    83  	// Credential specifies the function for resolving the credential for the
    84  	// given registry (i.e. host:port).
    85  	// `EmptyCredential` is a valid return value and should not be considered as
    86  	// an error.
    87  	// If nil, the credential is always resolved to `EmptyCredential`.
    88  	Credential func(context.Context, string) (Credential, error)
    89  
    90  	// Cache caches credentials for direct accessing the remote registry.
    91  	// If nil, no cache is used.
    92  	Cache Cache
    93  
    94  	// ClientID used in fetching OAuth2 token as a required field.
    95  	// If empty, a default client ID is used.
    96  	// Reference: https://docs.docker.com/registry/spec/auth/oauth/#getting-a-token
    97  	ClientID string
    98  
    99  	// ForceAttemptOAuth2 controls whether to follow OAuth2 with password grant
   100  	// instead the distribution spec when authenticating using username and
   101  	// password.
   102  	// References:
   103  	// - https://docs.docker.com/registry/spec/auth/jwt/
   104  	// - https://docs.docker.com/registry/spec/auth/oauth/
   105  	ForceAttemptOAuth2 bool
   106  }
   107  
   108  // client returns an HTTP client used to access the remote registry.
   109  // http.DefaultClient is return if the client is not configured.
   110  func (c *Client) client() *http.Client {
   111  	if c.Client == nil {
   112  		return http.DefaultClient
   113  	}
   114  	return c.Client
   115  }
   116  
   117  // send adds headers to the request and sends the request to the remote server.
   118  func (c *Client) send(req *http.Request) (*http.Response, error) {
   119  	for key, values := range c.Header {
   120  		req.Header[key] = append(req.Header[key], values...)
   121  	}
   122  	return c.client().Do(req)
   123  }
   124  
   125  // credential resolves the credential for the given registry.
   126  func (c *Client) credential(ctx context.Context, reg string) (Credential, error) {
   127  	if c.Credential == nil {
   128  		return EmptyCredential, nil
   129  	}
   130  	return c.Credential(ctx, reg)
   131  }
   132  
   133  // cache resolves the cache.
   134  // noCache is return if the cache is not configured.
   135  func (c *Client) cache() Cache {
   136  	if c.Cache == nil {
   137  		return noCache{}
   138  	}
   139  	return c.Cache
   140  }
   141  
   142  // SetUserAgent sets the user agent for all out-going requests.
   143  func (c *Client) SetUserAgent(userAgent string) {
   144  	if c.Header == nil {
   145  		c.Header = http.Header{}
   146  	}
   147  	c.Header.Set("User-Agent", userAgent)
   148  }
   149  
   150  // Do sends the request to the remote server, attempting to resolve
   151  // authentication if 'Authorization' header is not set.
   152  //
   153  // On authentication failure due to bad credential,
   154  //   - Do returns error if it fails to fetch token for bearer auth.
   155  //   - Do returns the registry response without error for basic auth.
   156  func (c *Client) Do(originalReq *http.Request) (*http.Response, error) {
   157  	if auth := originalReq.Header.Get("Authorization"); auth != "" {
   158  		return c.send(originalReq)
   159  	}
   160  
   161  	ctx := originalReq.Context()
   162  	req := originalReq.Clone(ctx)
   163  
   164  	// attempt cached auth token
   165  	var attemptedKey string
   166  	cache := c.cache()
   167  	registry := originalReq.Host
   168  	scheme, err := cache.GetScheme(ctx, registry)
   169  	if err == nil {
   170  		switch scheme {
   171  		case SchemeBasic:
   172  			token, err := cache.GetToken(ctx, registry, SchemeBasic, "")
   173  			if err == nil {
   174  				req.Header.Set("Authorization", "Basic "+token)
   175  			}
   176  		case SchemeBearer:
   177  			scopes := GetScopes(ctx)
   178  			attemptedKey = strings.Join(scopes, " ")
   179  			token, err := cache.GetToken(ctx, registry, SchemeBearer, attemptedKey)
   180  			if err == nil {
   181  				req.Header.Set("Authorization", "Bearer "+token)
   182  			}
   183  		}
   184  	}
   185  
   186  	resp, err := c.send(req)
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  	if resp.StatusCode != http.StatusUnauthorized {
   191  		return resp, nil
   192  	}
   193  
   194  	// attempt again with credentials for recognized schemes
   195  	challenge := resp.Header.Get("Www-Authenticate")
   196  	scheme, params := parseChallenge(challenge)
   197  	switch scheme {
   198  	case SchemeBasic:
   199  		resp.Body.Close()
   200  
   201  		token, err := cache.Set(ctx, registry, SchemeBasic, "", func(ctx context.Context) (string, error) {
   202  			return c.fetchBasicAuth(ctx, registry)
   203  		})
   204  		if err != nil {
   205  			return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err)
   206  		}
   207  
   208  		req = originalReq.Clone(ctx)
   209  		req.Header.Set("Authorization", "Basic "+token)
   210  	case SchemeBearer:
   211  		resp.Body.Close()
   212  
   213  		// merge hinted scopes with challenged scopes
   214  		scopes := GetScopes(ctx)
   215  		if scope := params["scope"]; scope != "" {
   216  			scopes = append(scopes, strings.Split(scope, " ")...)
   217  			scopes = CleanScopes(scopes)
   218  		}
   219  		key := strings.Join(scopes, " ")
   220  
   221  		// attempt the cache again if there is a scope change
   222  		if key != attemptedKey {
   223  			if token, err := cache.GetToken(ctx, registry, SchemeBearer, key); err == nil {
   224  				req = originalReq.Clone(ctx)
   225  				req.Header.Set("Authorization", "Bearer "+token)
   226  				if err := rewindRequestBody(req); err != nil {
   227  					return nil, err
   228  				}
   229  
   230  				resp, err := c.send(req)
   231  				if err != nil {
   232  					return nil, err
   233  				}
   234  				if resp.StatusCode != http.StatusUnauthorized {
   235  					return resp, nil
   236  				}
   237  				resp.Body.Close()
   238  			}
   239  		}
   240  
   241  		// attempt with credentials
   242  		realm := params["realm"]
   243  		service := params["service"]
   244  		token, err := cache.Set(ctx, registry, SchemeBearer, key, func(ctx context.Context) (string, error) {
   245  			return c.fetchBearerToken(ctx, registry, realm, service, scopes)
   246  		})
   247  		if err != nil {
   248  			return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err)
   249  		}
   250  
   251  		req = originalReq.Clone(ctx)
   252  		req.Header.Set("Authorization", "Bearer "+token)
   253  	default:
   254  		return resp, nil
   255  	}
   256  	if err := rewindRequestBody(req); err != nil {
   257  		return nil, err
   258  	}
   259  
   260  	return c.send(req)
   261  }
   262  
   263  // fetchBasicAuth fetches a basic auth token for the basic challenge.
   264  func (c *Client) fetchBasicAuth(ctx context.Context, registry string) (string, error) {
   265  	cred, err := c.credential(ctx, registry)
   266  	if err != nil {
   267  		return "", fmt.Errorf("failed to resolve credential: %w", err)
   268  	}
   269  	if cred == EmptyCredential {
   270  		return "", errors.New("credential required for basic auth")
   271  	}
   272  	if cred.Username == "" || cred.Password == "" {
   273  		return "", errors.New("missing username or password for basic auth")
   274  	}
   275  	auth := cred.Username + ":" + cred.Password
   276  	return base64.StdEncoding.EncodeToString([]byte(auth)), nil
   277  }
   278  
   279  // fetchBearerToken fetches an access token for the bearer challenge.
   280  func (c *Client) fetchBearerToken(ctx context.Context, registry, realm, service string, scopes []string) (string, error) {
   281  	cred, err := c.credential(ctx, registry)
   282  	if err != nil {
   283  		return "", err
   284  	}
   285  	if cred.AccessToken != "" {
   286  		return cred.AccessToken, nil
   287  	}
   288  	if cred == EmptyCredential || (cred.RefreshToken == "" && !c.ForceAttemptOAuth2) {
   289  		return c.fetchDistributionToken(ctx, realm, service, scopes, cred.Username, cred.Password)
   290  	}
   291  	return c.fetchOAuth2Token(ctx, realm, service, scopes, cred)
   292  }
   293  
   294  // fetchDistributionToken fetches an access token as defined by the distribution
   295  // specification.
   296  // It fetches anonymous tokens if no credential is provided.
   297  // References:
   298  // - https://docs.docker.com/registry/spec/auth/jwt/
   299  // - https://docs.docker.com/registry/spec/auth/token/
   300  func (c *Client) fetchDistributionToken(ctx context.Context, realm, service string, scopes []string, username, password string) (string, error) {
   301  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm, nil)
   302  	if err != nil {
   303  		return "", err
   304  	}
   305  	if username != "" || password != "" {
   306  		req.SetBasicAuth(username, password)
   307  	}
   308  	q := req.URL.Query()
   309  	if service != "" {
   310  		q.Set("service", service)
   311  	}
   312  	for _, scope := range scopes {
   313  		q.Add("scope", scope)
   314  	}
   315  	req.URL.RawQuery = q.Encode()
   316  
   317  	resp, err := c.send(req)
   318  	if err != nil {
   319  		return "", err
   320  	}
   321  	defer resp.Body.Close()
   322  	if resp.StatusCode != http.StatusOK {
   323  		return "", errutil.ParseErrorResponse(resp)
   324  	}
   325  
   326  	// As specified in https://docs.docker.com/registry/spec/auth/token/ section
   327  	// "Token Response Fields", the token is either in `token` or
   328  	// `access_token`. If both present, they are identical.
   329  	var result struct {
   330  		Token       string `json:"token"`
   331  		AccessToken string `json:"access_token"`
   332  	}
   333  	lr := io.LimitReader(resp.Body, maxResponseBytes)
   334  	if err := json.NewDecoder(lr).Decode(&result); err != nil {
   335  		return "", fmt.Errorf("%s %q: failed to decode response: %w", resp.Request.Method, resp.Request.URL, err)
   336  	}
   337  	if result.AccessToken != "" {
   338  		return result.AccessToken, nil
   339  	}
   340  	if result.Token != "" {
   341  		return result.Token, nil
   342  	}
   343  	return "", fmt.Errorf("%s %q: empty token returned", resp.Request.Method, resp.Request.URL)
   344  }
   345  
   346  // fetchOAuth2Token fetches an OAuth2 access token.
   347  // Reference: https://docs.docker.com/registry/spec/auth/oauth/
   348  func (c *Client) fetchOAuth2Token(ctx context.Context, realm, service string, scopes []string, cred Credential) (string, error) {
   349  	form := url.Values{}
   350  	if cred.RefreshToken != "" {
   351  		form.Set("grant_type", "refresh_token")
   352  		form.Set("refresh_token", cred.RefreshToken)
   353  	} else if cred.Username != "" && cred.Password != "" {
   354  		form.Set("grant_type", "password")
   355  		form.Set("username", cred.Username)
   356  		form.Set("password", cred.Password)
   357  	} else {
   358  		return "", errors.New("missing username or password for bearer auth")
   359  	}
   360  	form.Set("service", service)
   361  	clientID := c.ClientID
   362  	if clientID == "" {
   363  		clientID = defaultClientID
   364  	}
   365  	form.Set("client_id", clientID)
   366  	if len(scopes) != 0 {
   367  		form.Set("scope", strings.Join(scopes, " "))
   368  	}
   369  	body := strings.NewReader(form.Encode())
   370  
   371  	req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm, body)
   372  	if err != nil {
   373  		return "", err
   374  	}
   375  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   376  
   377  	resp, err := c.send(req)
   378  	if err != nil {
   379  		return "", err
   380  	}
   381  	defer resp.Body.Close()
   382  	if resp.StatusCode != http.StatusOK {
   383  		return "", errutil.ParseErrorResponse(resp)
   384  	}
   385  
   386  	var result struct {
   387  		AccessToken string `json:"access_token"`
   388  	}
   389  	lr := io.LimitReader(resp.Body, maxResponseBytes)
   390  	if err := json.NewDecoder(lr).Decode(&result); err != nil {
   391  		return "", fmt.Errorf("%s %q: failed to decode response: %w", resp.Request.Method, resp.Request.URL, err)
   392  	}
   393  	if result.AccessToken != "" {
   394  		return result.AccessToken, nil
   395  	}
   396  	return "", fmt.Errorf("%s %q: empty token returned", resp.Request.Method, resp.Request.URL)
   397  }
   398  
   399  // rewindRequestBody tries to rewind the request body if exists.
   400  func rewindRequestBody(req *http.Request) error {
   401  	if req.Body == nil || req.Body == http.NoBody {
   402  		return nil
   403  	}
   404  	if req.GetBody == nil {
   405  		return fmt.Errorf("%s %q: request body is not rewindable", req.Method, req.URL)
   406  	}
   407  	body, err := req.GetBody()
   408  	if err != nil {
   409  		return fmt.Errorf("%s %q: failed to get request body: %w", req.Method, req.URL, err)
   410  	}
   411  	req.Body = body
   412  	return nil
   413  }