zotregistry.dev/zot@v1.4.4-0.20240314164342-eec277e14d20/pkg/extensions/sync/httpclient/client.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"io"
     8  	"net/http"
     9  	"net/url"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	zerr "zotregistry.dev/zot/errors"
    15  	"zotregistry.dev/zot/pkg/common"
    16  	"zotregistry.dev/zot/pkg/log"
    17  )
    18  
    19  const (
    20  	minimumTokenLifetimeSeconds = 60 // in seconds
    21  	pingTimeout                 = 5 * time.Second
    22  	// tokenBuffer is used to renew a token before it actually expires
    23  	// to account for the time to process requests on the server.
    24  	tokenBuffer = 5 * time.Second
    25  )
    26  
    27  type authType int
    28  
    29  const (
    30  	noneAuth authType = iota
    31  	basicAuth
    32  	tokenAuth
    33  )
    34  
    35  type challengeParams struct {
    36  	realm   string
    37  	service string
    38  	scope   string
    39  	err     string
    40  }
    41  
    42  type bearerToken struct {
    43  	Token          string    `json:"token"`        //nolint: tagliatelle
    44  	AccessToken    string    `json:"access_token"` //nolint: tagliatelle
    45  	ExpiresIn      int       `json:"expires_in"`   //nolint: tagliatelle
    46  	IssuedAt       time.Time `json:"issued_at"`    //nolint: tagliatelle
    47  	expirationTime time.Time
    48  }
    49  
    50  func (token *bearerToken) isExpired() bool {
    51  	// use tokenBuffer to expire it a bit earlier
    52  	return time.Now().After(token.expirationTime.Add(-1 * tokenBuffer))
    53  }
    54  
    55  type Config struct {
    56  	URL       string
    57  	Username  string
    58  	Password  string
    59  	CertDir   string
    60  	TLSVerify bool
    61  }
    62  
    63  type Client struct {
    64  	config   *Config
    65  	client   *http.Client
    66  	url      *url.URL
    67  	authType authType
    68  	cache    *TokenCache
    69  	lock     *sync.RWMutex
    70  	log      log.Logger
    71  }
    72  
    73  func New(config Config, log log.Logger) (*Client, error) {
    74  	client := &Client{log: log, lock: new(sync.RWMutex)}
    75  
    76  	client.cache = NewTokenCache()
    77  
    78  	if err := client.SetConfig(config); err != nil {
    79  		return nil, err
    80  	}
    81  
    82  	return client, nil
    83  }
    84  
    85  func (httpClient *Client) GetConfig() *Config {
    86  	httpClient.lock.RLock()
    87  	defer httpClient.lock.RUnlock()
    88  
    89  	return httpClient.config
    90  }
    91  
    92  func (httpClient *Client) GetHostname() string {
    93  	httpClient.lock.RLock()
    94  	defer httpClient.lock.RUnlock()
    95  
    96  	return httpClient.url.Host
    97  }
    98  
    99  func (httpClient *Client) GetBaseURL() string {
   100  	httpClient.lock.RLock()
   101  	defer httpClient.lock.RUnlock()
   102  
   103  	return httpClient.url.String()
   104  }
   105  
   106  func (httpClient *Client) SetConfig(config Config) error {
   107  	httpClient.lock.Lock()
   108  	defer httpClient.lock.Unlock()
   109  
   110  	clientURL, err := url.Parse(config.URL)
   111  	if err != nil {
   112  		return err
   113  	}
   114  
   115  	httpClient.url = clientURL
   116  
   117  	client, err := common.CreateHTTPClient(config.TLSVerify, clientURL.Host, config.CertDir)
   118  	if err != nil {
   119  		return err
   120  	}
   121  
   122  	httpClient.client = client
   123  	httpClient.config = &config
   124  
   125  	return nil
   126  }
   127  
   128  func (httpClient *Client) Ping() bool {
   129  	httpClient.lock.Lock()
   130  	defer httpClient.lock.Unlock()
   131  
   132  	pingURL := *httpClient.url
   133  
   134  	pingURL = *pingURL.JoinPath("/v2/")
   135  
   136  	// for the ping function we want to timeout fast
   137  	ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
   138  	defer cancel()
   139  
   140  	//nolint: bodyclose
   141  	resp, _, err := httpClient.get(ctx, pingURL.String(), false)
   142  	if err != nil {
   143  		return false
   144  	}
   145  
   146  	httpClient.getAuthType(resp)
   147  
   148  	if resp.StatusCode >= http.StatusOK && resp.StatusCode <= http.StatusForbidden {
   149  		return true
   150  	}
   151  
   152  	httpClient.log.Error().Str("url", pingURL.String()).Int("statusCode", resp.StatusCode).
   153  		Str("component", "sync").Msg("failed to ping registry")
   154  
   155  	return false
   156  }
   157  
   158  func (httpClient *Client) MakeGetRequest(ctx context.Context, resultPtr interface{}, mediaType string,
   159  	route ...string,
   160  ) ([]byte, string, int, error) {
   161  	httpClient.lock.RLock()
   162  	defer httpClient.lock.RUnlock()
   163  
   164  	var namespace string
   165  
   166  	url := *httpClient.url
   167  	for idx, path := range route {
   168  		url = *url.JoinPath(path)
   169  
   170  		// we know that the second route argument is always the repo name.
   171  		// need it for caching tokens, it's not used in requests made to authz server.
   172  		if idx == 1 {
   173  			namespace = path
   174  		}
   175  	}
   176  
   177  	url.RawQuery = url.Query().Encode()
   178  	//nolint: bodyclose,contextcheck
   179  	resp, body, err := httpClient.makeAndDoRequest(http.MethodGet, mediaType, namespace, url.String())
   180  	if err != nil {
   181  		httpClient.log.Error().Err(err).Str("url", url.String()).Str("component", "sync").
   182  			Str("errorType", common.TypeOf(err)).
   183  			Msg("failed to make request")
   184  
   185  		return nil, "", -1, err
   186  	}
   187  
   188  	if resp.StatusCode != http.StatusOK {
   189  		return nil, "", resp.StatusCode, errors.New(string(body)) //nolint:goerr113
   190  	}
   191  
   192  	// read blob
   193  	if len(body) > 0 {
   194  		err = json.Unmarshal(body, &resultPtr)
   195  	}
   196  
   197  	return body, resp.Header.Get("Content-Type"), resp.StatusCode, err
   198  }
   199  
   200  func (httpClient *Client) getAuthType(resp *http.Response) {
   201  	authHeader := resp.Header.Get("www-authenticate")
   202  
   203  	authHeaderLower := strings.ToLower(authHeader)
   204  
   205  	//nolint: gocritic
   206  	if strings.Contains(authHeaderLower, "bearer") {
   207  		httpClient.authType = tokenAuth
   208  	} else if strings.Contains(authHeaderLower, "basic") {
   209  		httpClient.authType = basicAuth
   210  	} else {
   211  		httpClient.authType = noneAuth
   212  	}
   213  }
   214  
   215  func (httpClient *Client) setupAuth(req *http.Request, namespace string) error {
   216  	if httpClient.authType == tokenAuth {
   217  		token, err := httpClient.getToken(req.URL.String(), namespace)
   218  		if err != nil {
   219  			httpClient.log.Error().Err(err).Str("url", req.URL.String()).Str("component", "sync").
   220  				Str("errorType", common.TypeOf(err)).
   221  				Msg("failed to get token from authorization realm")
   222  
   223  			return err
   224  		}
   225  
   226  		req.Header.Set("Authorization", "Bearer "+token.Token)
   227  	} else if httpClient.authType == basicAuth {
   228  		req.SetBasicAuth(httpClient.config.Username, httpClient.config.Password)
   229  	}
   230  
   231  	return nil
   232  }
   233  
   234  func (httpClient *Client) get(ctx context.Context, url string, setAuth bool) (*http.Response, []byte, error) {
   235  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) //nolint
   236  	if err != nil {
   237  		return nil, nil, err
   238  	}
   239  
   240  	if setAuth && httpClient.config.Username != "" && httpClient.config.Password != "" {
   241  		req.SetBasicAuth(httpClient.config.Username, httpClient.config.Password)
   242  	}
   243  
   244  	return httpClient.doRequest(req)
   245  }
   246  
   247  func (httpClient *Client) doRequest(req *http.Request) (*http.Response, []byte, error) {
   248  	resp, err := httpClient.client.Do(req)
   249  	if err != nil {
   250  		httpClient.log.Error().Err(err).Str("url", req.URL.String()).Str("component", "sync").
   251  			Str("errorType", common.TypeOf(err)).
   252  			Msg("failed to make request")
   253  
   254  		return nil, nil, err
   255  	}
   256  
   257  	defer resp.Body.Close()
   258  
   259  	body, err := io.ReadAll(resp.Body)
   260  	if err != nil {
   261  		httpClient.log.Error().Err(err).Str("url", req.URL.String()).
   262  			Str("errorType", common.TypeOf(err)).
   263  			Msg("failed to read body")
   264  
   265  		return nil, nil, err
   266  	}
   267  
   268  	return resp, body, nil
   269  }
   270  
   271  func (httpClient *Client) makeAndDoRequest(method, mediaType, namespace, urlStr string,
   272  ) (*http.Response, []byte, error) {
   273  	req, err := http.NewRequest(method, urlStr, nil) //nolint
   274  	if err != nil {
   275  		return nil, nil, err
   276  	}
   277  
   278  	if err := httpClient.setupAuth(req, namespace); err != nil {
   279  		return nil, nil, err
   280  	}
   281  
   282  	if mediaType != "" {
   283  		req.Header.Set("Accept", mediaType)
   284  	}
   285  
   286  	resp, body, err := httpClient.doRequest(req)
   287  	if err != nil {
   288  		return nil, nil, err
   289  	}
   290  
   291  	// let's retry one time if we get an insufficient_scope error
   292  	if ok, challengeParams := needsRetryWithUpdatedScope(err, resp); ok {
   293  		var tokenURL *url.URL
   294  
   295  		var token *bearerToken
   296  
   297  		tokenURL, err = getTokenURLFromChallengeParams(challengeParams, httpClient.config.Username)
   298  		if err != nil {
   299  			return nil, nil, err
   300  		}
   301  
   302  		token, err = httpClient.getTokenFromURL(tokenURL.String(), namespace)
   303  		if err != nil {
   304  			return nil, nil, err
   305  		}
   306  
   307  		req.Header.Set("Authorization", "Bearer "+token.Token)
   308  
   309  		resp, body, err = httpClient.doRequest(req)
   310  	}
   311  
   312  	return resp, body, err
   313  }
   314  
   315  func (httpClient *Client) getTokenFromURL(urlStr, namespace string) (*bearerToken, error) {
   316  	//nolint: bodyclose
   317  	resp, body, err := httpClient.get(context.Background(), urlStr, true)
   318  	if err != nil {
   319  		return nil, err
   320  	}
   321  
   322  	if resp.StatusCode != http.StatusOK {
   323  		return nil, zerr.ErrUnauthorizedAccess
   324  	}
   325  
   326  	token, err := newBearerToken(body)
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  
   331  	// cache it
   332  	httpClient.cache.Set(namespace, token)
   333  
   334  	return token, nil
   335  }
   336  
   337  // Gets bearer token from Authorization realm.
   338  func (httpClient *Client) getToken(urlStr, namespace string) (*bearerToken, error) {
   339  	// first check cache
   340  	token := httpClient.cache.Get(namespace)
   341  	if token != nil && !token.isExpired() {
   342  		return token, nil
   343  	}
   344  
   345  	//nolint: bodyclose
   346  	resp, _, err := httpClient.get(context.Background(), urlStr, false)
   347  	if err != nil {
   348  		return nil, err
   349  	}
   350  
   351  	challengeParams, err := parseAuthHeader(resp)
   352  	if err != nil {
   353  		return nil, err
   354  	}
   355  
   356  	tokenURL, err := getTokenURLFromChallengeParams(challengeParams, httpClient.config.Username)
   357  	if err != nil {
   358  		return nil, err
   359  	}
   360  
   361  	return httpClient.getTokenFromURL(tokenURL.String(), namespace)
   362  }
   363  
   364  func newBearerToken(blob []byte) (*bearerToken, error) {
   365  	token := new(bearerToken)
   366  	if err := json.Unmarshal(blob, &token); err != nil {
   367  		return nil, err
   368  	}
   369  
   370  	if token.Token == "" {
   371  		token.Token = token.AccessToken
   372  	}
   373  
   374  	if token.ExpiresIn < minimumTokenLifetimeSeconds {
   375  		token.ExpiresIn = minimumTokenLifetimeSeconds
   376  	}
   377  
   378  	if token.IssuedAt.IsZero() {
   379  		token.IssuedAt = time.Now().UTC()
   380  	}
   381  
   382  	token.expirationTime = token.IssuedAt.Add(time.Duration(token.ExpiresIn) * time.Second)
   383  
   384  	return token, nil
   385  }
   386  
   387  func getTokenURLFromChallengeParams(params challengeParams, account string) (*url.URL, error) {
   388  	parsedRealm, err := url.Parse(params.realm)
   389  	if err != nil {
   390  		return nil, err
   391  	}
   392  
   393  	query := parsedRealm.Query()
   394  	query.Set("service", params.service)
   395  	query.Set("scope", params.scope)
   396  
   397  	if account != "" {
   398  		query.Set("account", account)
   399  	}
   400  
   401  	parsedRealm.RawQuery = query.Encode()
   402  
   403  	return parsedRealm, nil
   404  }
   405  
   406  func parseAuthHeader(resp *http.Response) (challengeParams, error) {
   407  	authHeader := resp.Header.Get("www-authenticate")
   408  
   409  	authHeaderSlice := strings.Split(authHeader, ",")
   410  
   411  	params := challengeParams{}
   412  
   413  	for _, elem := range authHeaderSlice {
   414  		if strings.Contains(strings.ToLower(elem), "bearer") {
   415  			elem = strings.Split(elem, " ")[1]
   416  		}
   417  
   418  		elem := strings.ReplaceAll(elem, "\"", "")
   419  
   420  		elemSplit := strings.Split(elem, "=")
   421  		if len(elemSplit) != 2 { //nolint: gomnd
   422  			return params, zerr.ErrParsingAuthHeader
   423  		}
   424  
   425  		authKey := elemSplit[0]
   426  
   427  		authValue := elemSplit[1]
   428  
   429  		switch authKey {
   430  		case "realm":
   431  			params.realm = authValue
   432  		case "service":
   433  			params.service = authValue
   434  		case "scope":
   435  			params.scope = authValue
   436  		case "error":
   437  			params.err = authValue
   438  		}
   439  	}
   440  
   441  	return params, nil
   442  }
   443  
   444  // Checks if the auth headers in the response contain an indication of a failed
   445  // authorization because of an "insufficient_scope" error.
   446  func needsRetryWithUpdatedScope(err error, resp *http.Response) (bool, challengeParams) {
   447  	params := challengeParams{}
   448  	if err == nil && resp.StatusCode == http.StatusUnauthorized {
   449  		params, err = parseAuthHeader(resp)
   450  		if err != nil {
   451  			return false, params
   452  		}
   453  
   454  		if params.err == "insufficient_scope" {
   455  			if params.scope != "" {
   456  				return true, params
   457  			}
   458  		}
   459  	}
   460  
   461  	return false, params
   462  }