github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/docker/registry/internal/transports.go (about)

     1  // Copyright 2021 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package internal
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/url"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/docker/distribution/registry/client/auth/challenge"
    16  	"github.com/juju/errors"
    17  )
    18  
    19  type dynamicTransportFunc func() (http.RoundTripper, error)
    20  
    21  // RoundTrip executes a single HTTP transaction, returning a Response for the provided Request.
    22  func (f dynamicTransportFunc) RoundTrip(req *http.Request) (*http.Response, error) {
    23  	transport, err := f()
    24  	if err != nil {
    25  		return nil, err
    26  	}
    27  	return transport.RoundTrip(req)
    28  }
    29  
// challengeTransport is an http.RoundTripper that reacts to
// 401 Unauthorized responses by working through the WWW-Authenticate
// challenges they carry.
type challengeTransport struct {
	// baseTransport sends the first request when no transport has been
	// selected yet, and is the transport wrapped by the scheme-specific
	// transports built while answering a challenge.
	baseTransport http.RoundTripper
	// currentTransport is the transport that most recently produced a
	// non-401 response after a challenge; when set it is used first.
	currentTransport http.RoundTripper

	// Credentials handed to the scheme-specific transports.
	username  string
	password  string
	authToken string
}
    38  
    39  func newChallengeTransport(
    40  	transport http.RoundTripper, username string, password string, authToken string,
    41  ) http.RoundTripper {
    42  	return &challengeTransport{
    43  		baseTransport: transport,
    44  		username:      username,
    45  		password:      password,
    46  		authToken:     authToken,
    47  	}
    48  }
    49  
    50  func (t *challengeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    51  	transport := t.baseTransport
    52  	if t.currentTransport != nil {
    53  		transport = t.currentTransport
    54  	}
    55  	resp, err := transport.RoundTrip(req)
    56  	if err != nil {
    57  		return nil, errors.Trace(err)
    58  	}
    59  	originalResp := resp
    60  	if !isUnauthorizedResponse(originalResp) {
    61  		return resp, nil
    62  	}
    63  	for _, c := range challenge.ResponseChallenges(originalResp) {
    64  		if err != nil {
    65  			logger.Warningf("authentication failed: %s", err.Error())
    66  			err = nil
    67  		}
    68  		switch strings.ToLower(c.Scheme) {
    69  		case "bearer":
    70  			tokenTransport := &tokenTransport{
    71  				transport: t.baseTransport,
    72  				username:  t.password,
    73  				password:  t.password,
    74  				authToken: t.authToken,
    75  			}
    76  			err = tokenTransport.refreshOAuthToken(originalResp)
    77  			if err != nil {
    78  				continue
    79  			}
    80  			transport = tokenTransport
    81  		case "basic":
    82  			transport = newBasicTransport(t.baseTransport, t.username, t.password, t.authToken)
    83  		default:
    84  			err = fmt.Errorf("unknown WWW-Authenticate challenge scheme: %s", c.Scheme)
    85  			continue
    86  		}
    87  		resp, err = transport.RoundTrip(req)
    88  		if err == nil && !isUnauthorizedResponse(resp) {
    89  			t.currentTransport = transport
    90  			return resp, nil
    91  		}
    92  	}
    93  	if err != nil {
    94  		return nil, errors.Trace(err)
    95  	}
    96  	if t.password == "" && t.authToken == "" {
    97  		return nil, errors.NewUnauthorized(err, "authorization is required for a private registry")
    98  	}
    99  	return resp, nil
   100  }
   101  
// basicTransport is an http.RoundTripper that adds a "Basic" Authorization
// header to each request before passing it to the wrapped transport.
type basicTransport struct {
	// transport is the wrapped RoundTripper that actually sends the request.
	transport http.RoundTripper
	// username and password are sent as HTTP basic auth when authToken is
	// empty and at least one of them is set.
	username string
	password string
	// authToken, when non-empty, is sent verbatim after the scheme and takes
	// precedence over username and password.
	authToken string
}
   108  
   109  func newBasicTransport(
   110  	transport http.RoundTripper, username string, password string, authToken string,
   111  ) http.RoundTripper {
   112  	return &basicTransport{
   113  		transport: transport,
   114  		username:  username,
   115  		password:  password,
   116  		authToken: authToken,
   117  	}
   118  }
   119  
   120  func (basicTransport) scheme() string {
   121  	return "Basic"
   122  }
   123  
   124  func (t basicTransport) authorizeRequest(req *http.Request) error {
   125  	if t.authToken != "" {
   126  		req.Header.Set("Authorization", fmt.Sprintf("%s %s", t.scheme(), t.authToken))
   127  		return nil
   128  	}
   129  	if t.username != "" || t.password != "" {
   130  		req.SetBasicAuth(t.username, t.password)
   131  	}
   132  	return nil
   133  }
   134  
   135  // RoundTrip executes a single HTTP transaction, returning a Response for the provided Request.
   136  func (t basicTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   137  	if err := t.authorizeRequest(req); err != nil {
   138  		return nil, errors.Trace(err)
   139  	}
   140  	resp, err := t.transport.RoundTrip(req)
   141  	logger.Tracef("basicTransport %q, resp.Header => %#v, %q", req.URL, resp.Header, resp.Status)
   142  	return resp, errors.Trace(err)
   143  }
   144  
// tokenTransport is an http.RoundTripper that authenticates requests with a
// "Bearer" OAuth token, fetching a new token when a request is answered with
// 401 Unauthorized.
type tokenTransport struct {
	// transport is the wrapped RoundTripper used both for the real requests
	// and for the token refresh request.
	transport http.RoundTripper
	// username, password and authToken are the credentials presented (via
	// basic auth) to the token endpoint when refreshing the OAuth token.
	username  string
	password  string
	authToken string
	// oauthToken is the bearer token attached to outgoing requests; empty
	// means no Authorization header is added.
	oauthToken string
	// reuseOAuthToken keeps oauthToken across calls to RoundTrip; when false
	// the token is cleared after every RoundTrip.
	reuseOAuthToken bool
}
   153  
   154  func newTokenTransport(
   155  	transport http.RoundTripper, username, password, authToken, oauthToken string, reuseOAuthToken bool,
   156  ) http.RoundTripper {
   157  	return &tokenTransport{
   158  		transport:       transport,
   159  		username:        username,
   160  		password:        password,
   161  		authToken:       authToken,
   162  		oauthToken:      oauthToken,
   163  		reuseOAuthToken: reuseOAuthToken,
   164  	}
   165  }
   166  
   167  func (tokenTransport) scheme() string {
   168  	return "Bearer"
   169  }
   170  
   171  func getChallengeParameters(scheme string, resp *http.Response) map[string]string {
   172  	logger.Tracef(
   173  		"getting chanllenge parametter for %q with scheme %q from %q",
   174  		resp.Request.URL.String(),
   175  		scheme, resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")],
   176  	)
   177  	for _, c := range challenge.ResponseChallenges(resp) {
   178  		if strings.EqualFold(c.Scheme, scheme) {
   179  			return c.Parameters
   180  		}
   181  	}
   182  	logger.Tracef("failed to get challenge parameters for %q schema -> %v", scheme, resp.Header)
   183  	return nil
   184  }
   185  
   186  type tokenResponse struct {
   187  	Token        string    `json:"token"`
   188  	AccessToken  string    `json:"access_token"`
   189  	RefreshToken string    `json:"refresh_token"`
   190  	ExpiresIn    int       `json:"expires_in"`
   191  	IssuedAt     time.Time `json:"issued_at"`
   192  	Scope        string    `json:"scope"`
   193  }
   194  
   195  func (t tokenResponse) token() string {
   196  	if t.AccessToken != "" {
   197  		return t.AccessToken
   198  	}
   199  	if t.Token != "" {
   200  		return t.Token
   201  	}
   202  	return ""
   203  }
   204  
   205  func (t *tokenTransport) refreshOAuthToken(failedResp *http.Response) error {
   206  	parameters := getChallengeParameters(t.scheme(), failedResp)
   207  	if len(parameters) == 0 {
   208  		return errors.NewForbidden(nil, "failed to refresh bearer token")
   209  	}
   210  	realm, ok := parameters["realm"]
   211  	if !ok {
   212  		return errors.New("no realm specified for token auth challenge")
   213  	}
   214  	service, ok := parameters["service"]
   215  	if !ok {
   216  		return errors.New("no service specified for token auth challenge")
   217  	}
   218  	scope, ok := parameters["scope"]
   219  	if !ok {
   220  		logger.Tracef("no scope specified for token auth challenge")
   221  	}
   222  
   223  	url, err := url.Parse(realm)
   224  	if err != nil {
   225  		return errors.Trace(err)
   226  	}
   227  	q := url.Query()
   228  	if scope != "" {
   229  		q.Set("scope", scope)
   230  	}
   231  	q.Set("service", service)
   232  	url.RawQuery = q.Encode()
   233  
   234  	request, err := http.NewRequest("GET", url.String(), nil)
   235  	if err != nil {
   236  		return errors.Trace(err)
   237  	}
   238  	tokenRefreshTransport := newBasicTransport(t.transport, t.username, t.password, t.authToken)
   239  	resp, err := tokenRefreshTransport.RoundTrip(request)
   240  	if err != nil {
   241  		return errors.Trace(err)
   242  	}
   243  	if resp.StatusCode != http.StatusOK {
   244  		_, err = handleErrorResponse(resp)
   245  		return errors.Trace(err)
   246  	}
   247  
   248  	decoder := json.NewDecoder(resp.Body)
   249  	var tr tokenResponse
   250  	if err = decoder.Decode(&tr); err != nil {
   251  		return fmt.Errorf("unable to decode token response: %s", err)
   252  	}
   253  	t.oauthToken = tr.token()
   254  	return nil
   255  }
   256  
   257  func (t *tokenTransport) authorizeRequest(req *http.Request) error {
   258  	if t.oauthToken != "" {
   259  		req.Header.Set("Authorization", fmt.Sprintf("%s %s", t.scheme(), t.oauthToken))
   260  	}
   261  	return nil
   262  }
   263  
   264  // RoundTrip executes a single HTTP transaction, returning a Response for the provided Request.
   265  func (t *tokenTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   266  	defer func() {
   267  		if !t.reuseOAuthToken {
   268  			// We usually do not re-use the OAuth token because each API call might have different scope.
   269  			// But some of the provider use long life token and there is no need to refresh.
   270  			t.oauthToken = ""
   271  		}
   272  	}()
   273  
   274  	if err := t.authorizeRequest(req); err != nil {
   275  		return nil, errors.Trace(err)
   276  	}
   277  	resp, err := t.transport.RoundTrip(req)
   278  	if err != nil {
   279  		return nil, errors.Trace(err)
   280  	}
   281  	if isUnauthorizedResponse(resp) {
   282  		// refresh token and retry.
   283  		return t.retry(req, resp)
   284  	}
   285  	return resp, errors.Trace(err)
   286  }
   287  
   288  func (t *tokenTransport) retry(req *http.Request, prevResp *http.Response) (*http.Response, error) {
   289  	logger.Tracef(
   290  		"retrying req URL %q, previous response header %#v, status %v",
   291  		req.URL, prevResp.Header, prevResp.Status,
   292  	)
   293  
   294  	if err := t.refreshOAuthToken(prevResp); err != nil {
   295  		return nil, errors.Annotatef(err, "refreshing OAuth token")
   296  	}
   297  	if err := t.authorizeRequest(req); err != nil {
   298  		return nil, errors.Trace(err)
   299  	}
   300  	resp, err := t.transport.RoundTrip(req)
   301  	if isUnauthorizedResponse(resp) {
   302  		if t.password == "" && t.authToken == "" {
   303  			return nil, errors.NewUnauthorized(err, "authorization is required for a private registry")
   304  		}
   305  	}
   306  	return resp, errors.Trace(err)
   307  }
   308  
   309  func isUnauthorizedResponse(resp *http.Response) bool {
   310  	return resp != nil && resp.StatusCode == http.StatusUnauthorized
   311  }
   312  
// errorTransport is an http.RoundTripper that turns responses with a status
// code of 400 or above into errors.
type errorTransport struct {
	// transport is the wrapped RoundTripper that actually sends the request.
	transport http.RoundTripper
}
   316  
   317  func newErrorTransport(transport http.RoundTripper) http.RoundTripper {
   318  	return &errorTransport{transport: transport}
   319  }
   320  
   321  // RoundTrip executes a single HTTP transaction, returning a Response for the provided Request.
   322  func (t errorTransport) RoundTrip(request *http.Request) (*http.Response, error) {
   323  	resp, err := t.transport.RoundTrip(request)
   324  	if err != nil {
   325  		return resp, errors.Trace(err)
   326  	}
   327  	if resp.StatusCode < 400 {
   328  		return resp, nil
   329  	}
   330  	logger.Tracef("errorTransport %q, err -> %v", request.URL, err)
   331  	return handleErrorResponse(resp)
   332  }
   333  
   334  func handleErrorResponse(resp *http.Response) (*http.Response, error) {
   335  	if resp.StatusCode < 400 {
   336  		return resp, nil
   337  	}
   338  	defer resp.Body.Close()
   339  	body, err := io.ReadAll(resp.Body)
   340  	if err != nil {
   341  		return nil, errors.Annotatef(err, "reading bad response body with status code %d", resp.StatusCode)
   342  	}
   343  	errMsg := fmt.Sprintf("non-successful response status=%d", resp.StatusCode)
   344  	if logger.IsTraceEnabled() {
   345  		logger.Tracef("%s, url %q, body=%q", errMsg, resp.Request.URL.String(), body)
   346  	}
   347  	errNew := errors.Errorf
   348  	switch resp.StatusCode {
   349  	case http.StatusForbidden:
   350  		errNew = errors.Forbiddenf
   351  	case http.StatusUnauthorized:
   352  		errNew = errors.Unauthorizedf
   353  	case http.StatusNotFound:
   354  		errNew = errors.NotFoundf
   355  	}
   356  	return nil, errNew(errMsg)
   357  }
   358  
   359  func unwrapNetError(err error) error {
   360  	if err == nil {
   361  		return nil
   362  	}
   363  	if neturlErr, ok := err.(*url.Error); ok {
   364  		return errors.Annotatef(neturlErr.Unwrap(), "%s %q", neturlErr.Op, neturlErr.URL)
   365  	}
   366  	return err
   367  }