cuelabs.dev/go/oci/ociregistry@v0.0.0-20240906074133-82eb438dd565/ociauth/auth_test.go (about)

     1  package ociauth
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"net/url"
    13  	"strings"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/go-quicktest/qt"
    18  )
    19  
    20  func TestBasicAuth(t *testing.T) {
    21  	ts := newTargetServer(t, func(req *http.Request) *httpError {
    22  		username, password, _ := req.BasicAuth()
    23  		if username != "testuser" || password != "testpassword" {
    24  			return &httpError{
    25  				statusCode: http.StatusUnauthorized,
    26  				header: http.Header{
    27  					"Www-Authenticate": {"Basic"},
    28  				},
    29  			}
    30  		}
    31  		return nil
    32  	})
    33  	client := &http.Client{
    34  		Transport: NewStdTransport(StdTransportParams{
    35  			Config: configFunc(func(host string) (ConfigEntry, error) {
    36  				if host != ts.Host {
    37  					return ConfigEntry{}, nil
    38  				}
    39  				return ConfigEntry{
    40  					Username: "testuser",
    41  					Password: "testpassword",
    42  				}, nil
    43  			}),
    44  		}),
    45  	}
    46  	assertRequest(context.Background(), t, ts, "/test", client, Scope{})
    47  }
    48  
    49  func TestBearerAuth(t *testing.T) {
    50  	testScope := ParseScope("repository:foo:push,pull")
    51  	authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
    52  		username, password, ok := req.BasicAuth()
    53  		if !ok || username != "testuser" || password != "testpassword" {
    54  			return nil, &httpError{
    55  				statusCode: http.StatusUnauthorized,
    56  			}
    57  		}
    58  		requestedScope := ParseScope(req.Form.Get("scope"))
    59  		if !runNonFatal(t, func(t testing.TB) {
    60  			qt.Assert(t, qt.DeepEquals(requestedScope, testScope))
    61  			qt.Assert(t, qt.DeepEquals(req.Form["service"], []string{"someService"}))
    62  		}) {
    63  			return nil, &httpError{
    64  				statusCode: http.StatusInternalServerError,
    65  			}
    66  		}
    67  		return &wireToken{
    68  			Token: token{requestedScope}.String(),
    69  		}, nil
    70  	})
    71  	ts := newTargetServer(t, func(req *http.Request) *httpError {
    72  		if req.Header.Get("Authorization") == "" {
    73  			return &httpError{
    74  				statusCode: http.StatusUnauthorized,
    75  				header: http.Header{
    76  					"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, testScope)},
    77  				},
    78  			}
    79  		}
    80  		runNonFatal(t, func(t testing.TB) {
    81  			qt.Assert(t, qt.DeepEquals(authScopeFromRequest(t, req), testScope))
    82  		})
    83  		return nil
    84  	})
    85  	client := &http.Client{
    86  		Transport: NewStdTransport(StdTransportParams{
    87  			Config: configFunc(func(host string) (ConfigEntry, error) {
    88  				if host != ts.Host {
    89  					return ConfigEntry{}, nil
    90  				}
    91  				return ConfigEntry{
    92  					Username: "testuser",
    93  					Password: "testpassword",
    94  				}, nil
    95  			}),
    96  		}),
    97  	}
    98  	assertRequest(context.Background(), t, ts, "/test", client, Scope{})
    99  }
   100  
   101  func TestBearerAuthAdditionalScope(t *testing.T) {
   102  	// This tests the scenario where there's a larger scope in the context
   103  	// than the required scope.
   104  	requiredScope := ParseScope("repository:foo:push,pull")
   105  	additionalScope := ParseScope("repository:bar:pull somethingElse")
   106  	authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
   107  		username, password, ok := req.BasicAuth()
   108  		if !ok || username != "testuser" || password != "testpassword" {
   109  			return nil, &httpError{
   110  				statusCode: http.StatusUnauthorized,
   111  			}
   112  		}
   113  		requestedScope := ParseScope(strings.Join(req.Form["scope"], " "))
   114  		if !runNonFatal(t, func(t testing.TB) {
   115  			qt.Assert(t, qt.DeepEquals(requestedScope, requiredScope.Union(additionalScope)))
   116  			qt.Assert(t, qt.DeepEquals(req.Form["service"], []string{"someService"}))
   117  		}) {
   118  		}
   119  		return &wireToken{
   120  			Token: token{requestedScope}.String(),
   121  		}, nil
   122  	})
   123  	ts := newTargetServer(t, func(req *http.Request) *httpError {
   124  		if req.Header.Get("Authorization") == "" {
   125  			return &httpError{
   126  				statusCode: http.StatusUnauthorized,
   127  				header: http.Header{
   128  					"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, requiredScope)},
   129  				},
   130  			}
   131  		}
   132  		runNonFatal(t, func(t testing.TB) {
   133  			qt.Assert(t, qt.DeepEquals(authScopeFromRequest(t, req), requiredScope.Union(additionalScope)))
   134  		})
   135  		return nil
   136  	})
   137  	client := &http.Client{
   138  		Transport: NewStdTransport(StdTransportParams{
   139  			Config: configFunc(func(host string) (ConfigEntry, error) {
   140  				if host != ts.Host {
   141  					return ConfigEntry{}, nil
   142  				}
   143  				return ConfigEntry{
   144  					Username: "testuser",
   145  					Password: "testpassword",
   146  				}, nil
   147  			}),
   148  		}),
   149  	}
   150  	ctx := ContextWithScope(context.Background(), additionalScope)
   151  	assertRequest(ctx, t, ts, "/test", client, Scope{})
   152  }
   153  
   154  func TestBearerAuthRequiresExactScope(t *testing.T) {
   155  	// This tests the scenario where an auth server requires exactly the
   156  	// scope that was present in the challenge.
   157  	requiredScope := ParseScope("repository:foo:pull,push")
   158  	exactScope := "other repository:foo:push,pull"
   159  	exactScopeAsToken := base64.StdEncoding.EncodeToString([]byte("token-" + exactScope))
   160  	authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
   161  		username, password, ok := req.BasicAuth()
   162  		if !ok || username != "testuser" || password != "testpassword" {
   163  			return nil, &httpError{
   164  				statusCode: http.StatusUnauthorized,
   165  			}
   166  		}
   167  		requestedScope := strings.Join(req.Form["scope"], " ")
   168  		if requestedScope != exactScope {
   169  			return nil, &httpError{
   170  				statusCode: http.StatusUnauthorized,
   171  			}
   172  		}
   173  		return &wireToken{
   174  			Token: exactScopeAsToken,
   175  		}, nil
   176  	})
   177  	ts := newTargetServer(t, func(req *http.Request) *httpError {
   178  		if req.Header.Get("Authorization") == "" {
   179  			return &httpError{
   180  				statusCode: http.StatusUnauthorized,
   181  				header: http.Header{
   182  					"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, exactScope)},
   183  				},
   184  			}
   185  		}
   186  		qt.Check(t, qt.Equals(req.Header.Get("Authorization"), "Bearer "+exactScopeAsToken))
   187  		return nil
   188  	})
   189  	client := &http.Client{
   190  		Transport: NewStdTransport(StdTransportParams{
   191  			Config: configFunc(func(host string) (ConfigEntry, error) {
   192  				if host != ts.Host {
   193  					return ConfigEntry{}, nil
   194  				}
   195  				return ConfigEntry{
   196  					Username: "testuser",
   197  					Password: "testpassword",
   198  				}, nil
   199  			}),
   200  		}),
   201  	}
   202  	assertRequest(context.Background(), t, ts, "/test", client, requiredScope)
   203  }
   204  
   205  func TestAuthNotAvailableAfterChallenge(t *testing.T) {
   206  	// This tests the scenario where the target server returns a challenge
   207  	// that we can't meet.
   208  	requestCount := 0
   209  	ts := newTargetServer(t, func(req *http.Request) *httpError {
   210  		if req.Header.Get("Authorization") == "" {
   211  			requestCount++
   212  			return &httpError{
   213  				statusCode: http.StatusUnauthorized,
   214  				header: http.Header{
   215  					"Www-Authenticate": []string{"Basic service=someService"},
   216  				},
   217  			}
   218  		}
   219  		t.Errorf("authorization unexpectedly presented")
   220  		return nil
   221  	})
   222  	client := &http.Client{
   223  		Transport: NewStdTransport(StdTransportParams{
   224  			Config: configFunc(func(host string) (ConfigEntry, error) {
   225  				return ConfigEntry{}, nil
   226  			}),
   227  		}),
   228  	}
   229  	req, err := http.NewRequestWithContext(context.Background(), "GET", ts.String()+"/test", nil)
   230  	qt.Assert(t, qt.IsNil(err))
   231  	resp, err := client.Do(req)
   232  	qt.Assert(t, qt.IsNil(err))
   233  	defer resp.Body.Close()
   234  	qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusUnauthorized))
   235  	qt.Check(t, qt.Equals(requestCount, 1))
   236  }
   237  
   238  func Test401ResponseWithJustAcquiredToken(t *testing.T) {
   239  	// This tests the scenario where a server returns a 401 response
   240  	// when the client has just successfully acquired a token from
   241  	// the auth server.
   242  	//
   243  	// In this case, a "correct" server should return
   244  	// either 403 (access to the resource is forbidden because the
   245  	// client's credentials are not sufficient) or 404 (either the
   246  	// repository really doesn't exist or the credentials are insufficient
   247  	// and the server doesn't allow clients to see whether repositories
   248  	// they don't have access to might exist).
   249  	//
   250  	// However, some real-world servers instead return a 401 response
   251  	// erroneously indicating that the client needs to acquire
   252  	// authorization credentials, even though they have in fact just
   253  	// done so.
   254  	//
   255  	// As a workaround for this case, we treat the response as a 404.
   256  
   257  	testScope := ParseScope("repository:foo:pull")
   258  	authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
   259  		requestedScope := ParseScope(req.Form.Get("scope"))
   260  		if !runNonFatal(t, func(t testing.TB) {
   261  			qt.Assert(t, qt.DeepEquals(requestedScope, testScope))
   262  			qt.Assert(t, qt.DeepEquals(req.Form["service"], []string{"someService"}))
   263  		}) {
   264  			return nil, &httpError{
   265  				statusCode: http.StatusInternalServerError,
   266  			}
   267  		}
   268  		return &wireToken{
   269  			Token: token{requestedScope}.String(),
   270  		}, nil
   271  	})
   272  	ts := newTargetServer(t, func(req *http.Request) *httpError {
   273  		if req.Header.Get("Authorization") == "" {
   274  			return &httpError{
   275  				statusCode: http.StatusUnauthorized,
   276  				header: http.Header{
   277  					"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, testScope)},
   278  				},
   279  			}
   280  		}
   281  		if !runNonFatal(t, func(t testing.TB) {
   282  			qt.Assert(t, qt.DeepEquals(authScopeFromRequest(t, req), testScope))
   283  		}) {
   284  			return &httpError{
   285  				statusCode: http.StatusInternalServerError,
   286  			}
   287  		}
   288  		return &httpError{
   289  			statusCode: http.StatusUnauthorized,
   290  			header: http.Header{
   291  				"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, testScope)},
   292  			},
   293  		}
   294  	})
   295  	client := &http.Client{
   296  		Transport: NewStdTransport(StdTransportParams{
   297  			Config: configFunc(func(host string) (ConfigEntry, error) {
   298  				return ConfigEntry{}, nil
   299  			}),
   300  		}),
   301  	}
   302  	req, err := http.NewRequestWithContext(context.Background(), "GET", ts.String()+"/test", nil)
   303  	qt.Assert(t, qt.IsNil(err))
   304  	resp, err := client.Do(req)
   305  	qt.Assert(t, qt.IsNil(err))
   306  	defer resp.Body.Close()
   307  	qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusForbidden))
   308  }
   309  
   310  func Test401ResponseWithNonAcquiredToken(t *testing.T) {
   311  	// This tests the scenario where a server returns a 401 response
   312  	// when the client has provided credentials already present in
   313  	// the configuration file.
   314  	//
   315  	// In this case, we don't want to trigger the fake-403-response
   316  	// behaviour test for in Test401ResponseWithJustAcquiredToken.
   317  
   318  	ts := newTargetServer(t, func(req *http.Request) *httpError {
   319  		if req.Header.Get("Authorization") == "" {
   320  			return &httpError{
   321  				statusCode: http.StatusUnauthorized,
   322  				header: http.Header{
   323  					"Www-Authenticate": []string{"Basic"},
   324  				},
   325  				body: "no auth creds provided",
   326  			}
   327  		}
   328  		return &httpError{
   329  			statusCode: http.StatusUnauthorized,
   330  			header: http.Header{
   331  				"Www-Authenticate": []string{"Basic"},
   332  			},
   333  			body: "password mismatch",
   334  		}
   335  	})
   336  	client := &http.Client{
   337  		Transport: NewStdTransport(StdTransportParams{
   338  			Config: configFunc(func(host string) (ConfigEntry, error) {
   339  				return ConfigEntry{
   340  					Username: "someuser",
   341  					Password: "somepassword",
   342  				}, nil
   343  			}),
   344  		}),
   345  	}
   346  	req, err := http.NewRequestWithContext(context.Background(), "GET", ts.String()+"/test", nil)
   347  	qt.Assert(t, qt.IsNil(err))
   348  	resp, err := client.Do(req)
   349  	qt.Assert(t, qt.IsNil(err))
   350  	defer resp.Body.Close()
   351  	data, _ := io.ReadAll(resp.Body)
   352  	qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusUnauthorized))
   353  	qt.Assert(t, qt.Equals(string(data), "password mismatch"))
   354  }
   355  
   356  func TestConfigHasAccessToken(t *testing.T) {
   357  	accessToken := "somevalue"
   358  	ts := newTargetServer(t, func(req *http.Request) *httpError {
   359  		if req.Header.Get("Authorization") == "" {
   360  			t.Errorf("no authorization presented")
   361  			return &httpError{
   362  				statusCode: http.StatusUnauthorized,
   363  			}
   364  		}
   365  		qt.Check(t, qt.Equals(req.Header.Get("Authorization"), "Bearer "+accessToken))
   366  		return nil
   367  	})
   368  	client := &http.Client{
   369  		Transport: NewStdTransport(StdTransportParams{
   370  			Config: configFunc(func(host string) (ConfigEntry, error) {
   371  				if host == ts.Host {
   372  					return ConfigEntry{
   373  						AccessToken: accessToken,
   374  					}, nil
   375  				}
   376  				return ConfigEntry{}, nil
   377  			}),
   378  		}),
   379  	}
   380  	assertRequest(context.Background(), t, ts, "/test", client, Scope{})
   381  }
   382  
   383  func TestConfigErrorNilRequestBody(t *testing.T) {
   384  	// stdTransport used to panic when given a nil request body
   385  	// if something failed before it called the underlying transport,
   386  	// as it would always try to close the body even when nil.
   387  	ts := newTargetServer(t, func(req *http.Request) *httpError {
   388  		return &httpError{statusCode: http.StatusUnauthorized}
   389  	})
   390  	client := &http.Client{
   391  		Transport: NewStdTransport(StdTransportParams{
   392  			Config: configFunc(func(host string) (ConfigEntry, error) {
   393  				return ConfigEntry{}, fmt.Errorf("always fails")
   394  			}),
   395  		}),
   396  	}
   397  	_, err := client.Get(ts.String() + "/test")
   398  	qt.Assert(t, qt.ErrorMatches(err, `.*cannot acquire auth.*always fails`))
   399  }
   400  
   401  func TestLaterRequestCanUseEarlierTokenWithLargerScope(t *testing.T) {
   402  	authCount := 0
   403  	authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
   404  		authCount++
   405  		return &wireToken{
   406  			Token: token{ParseScope(strings.Join(req.Form["scope"], " "))}.String(),
   407  		}, nil
   408  	})
   409  	ts := newTargetServer(t, func(req *http.Request) *httpError {
   410  		resource := strings.TrimPrefix(req.URL.Path, "/test/")
   411  		requiredScope := NewScope(ResourceScope{
   412  			ResourceType: TypeRepository,
   413  			Resource:     resource,
   414  			Action:       ActionPull,
   415  		})
   416  		if req.Header.Get("Authorization") == "" {
   417  			return &httpError{
   418  				statusCode: http.StatusUnauthorized,
   419  				header: http.Header{
   420  					"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, requiredScope)},
   421  				},
   422  			}
   423  		}
   424  		runNonFatal(t, func(t testing.TB) {
   425  			requestScope := authScopeFromRequest(t, req)
   426  			qt.Assert(t, qt.IsTrue(requestScope.Contains(requiredScope)), qt.Commentf("request scope: %q; required scope: %q", requestScope, requiredScope))
   427  		})
   428  		return nil
   429  	})
   430  	client := &http.Client{
   431  		Transport: NewStdTransport(StdTransportParams{
   432  			Config: configFunc(func(host string) (ConfigEntry, error) {
   433  				return ConfigEntry{}, nil
   434  			}),
   435  		}),
   436  	}
   437  	ctx := ContextWithScope(context.Background(), ParseScope("repository:foo1:pull repository:foo2:pull"))
   438  	assertRequest(ctx, t, ts, "/test/foo1", client, Scope{})
   439  	assertRequest(ctx, t, ts, "/test/foo2", client, Scope{})
   440  	// One token fetch should have been sufficient for both requests.
   441  	qt.Assert(t, qt.Equals(authCount, 1))
   442  }
   443  
   444  func TestAuthServerRejectsRequestsWithTooMuchScope(t *testing.T) {
   445  	// This tests the scenario described in the comment in registry.acquireAccessToken.
   446  	userHasScope := ParseScope("repository:foo:pull")
   447  
   448  	authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
   449  		requestedScope := ParseScope(strings.Join(req.Form["scope"], " "))
   450  		if !userHasScope.Contains(requestedScope) {
   451  			// Client is asking for more scope than the authenticated user
   452  			// has access to. Technically this should be OK, but some
   453  			// servers don't like it.
   454  			return nil, &httpError{
   455  				statusCode: http.StatusUnauthorized,
   456  			}
   457  		}
   458  		return &wireToken{
   459  			Token: token{requestedScope}.String(),
   460  		}, nil
   461  	})
   462  	ts := newTargetServer(t, func(req *http.Request) *httpError {
   463  		requiredScope := ParseScope("repository:foo:pull")
   464  		if req.Header.Get("Authorization") == "" {
   465  			return &httpError{
   466  				statusCode: http.StatusUnauthorized,
   467  				header: http.Header{
   468  					"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, requiredScope)},
   469  				},
   470  			}
   471  		}
   472  		runNonFatal(t, func(t testing.TB) {
   473  			qt.Assert(t, qt.IsTrue(authScopeFromRequest(t, req).Contains(requiredScope)))
   474  		})
   475  		return nil
   476  	})
   477  	client := &http.Client{
   478  		Transport: NewStdTransport(StdTransportParams{
   479  			Config: configFunc(func(host string) (ConfigEntry, error) {
   480  				return ConfigEntry{}, nil
   481  			}),
   482  		}),
   483  	}
   484  	ctx := ContextWithScope(context.Background(), ParseScope("repository:foo:pull repository:bar:pull"))
   485  	assertRequest(ctx, t, ts, "/test", client, Scope{})
   486  }
   487  
   488  func TestAuthRequestUsesRefreshTokenFromConfig(t *testing.T) {
   489  	authCount := 0
   490  	authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
   491  		authCount++
   492  		if !runNonFatal(t, func(t testing.TB) {
   493  			qt.Assert(t, qt.Equals(req.Form.Get("grant_type"), "refresh_token"))
   494  			qt.Assert(t, qt.Not(qt.Equals(req.Form.Get("client_id"), "")))
   495  			qt.Assert(t, qt.Equals(req.Form.Get("service"), "someService"))
   496  			qt.Assert(t, qt.Equals(req.Form.Get("refresh_token"), "someRefreshToken"))
   497  		}) {
   498  			return nil, &httpError{
   499  				statusCode: http.StatusInternalServerError,
   500  			}
   501  		}
   502  		requestedScope := ParseScope(strings.Join(req.Form["scope"], " "))
   503  		// Return an access token that expires soon so that we can let it expire
   504  		// so the client will be forced to acquire a new one with the original
   505  		// refresh token.
   506  		return &wireToken{
   507  			Token:     token{requestedScope}.String(),
   508  			ExpiresIn: 2, // Two seconds from now.
   509  		}, nil
   510  	})
   511  	requiredScope := ParseScope("repository:foo:pull")
   512  	ts := newTargetServer(t, func(req *http.Request) *httpError {
   513  		if req.Header.Get("Authorization") == "" {
   514  			return &httpError{
   515  				statusCode: http.StatusUnauthorized,
   516  				header: http.Header{
   517  					"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, requiredScope)},
   518  				},
   519  			}
   520  		}
   521  		runNonFatal(t, func(t testing.TB) {
   522  			qt.Assert(t, qt.IsTrue(authScopeFromRequest(t, req).Contains(requiredScope)))
   523  		})
   524  		return nil
   525  	})
   526  	client := &http.Client{
   527  		Transport: NewStdTransport(StdTransportParams{
   528  			Config: configFunc(func(host string) (ConfigEntry, error) {
   529  				if host == ts.Host {
   530  					return ConfigEntry{
   531  						RefreshToken: "someRefreshToken",
   532  					}, nil
   533  				}
   534  				return ConfigEntry{}, nil
   535  			}),
   536  		}),
   537  	}
   538  	assertRequest(context.Background(), t, ts, "/test", client, requiredScope)
   539  
   540  	// Let the original access token expire and then make another request,
   541  	// which should force the client to acquire another token using
   542  	// the original refresh token.
   543  
   544  	// Note: the expiry algorithm always leaves at least a second leeway.
   545  	time.Sleep(1100 * time.Millisecond)
   546  	assertRequest(context.Background(), t, ts, "/test", client, requiredScope)
   547  	// Check that it actually has had to acquire two tokens.
   548  	qt.Assert(t, qt.Equals(authCount, 2))
   549  }
   550  
   551  func TestAuthRequestUsesRefreshTokenFromAuthServer(t *testing.T) {
   552  	authCount := 0
   553  	authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
   554  		authCount++
   555  		if !runNonFatal(t, func(t testing.TB) {
   556  			// The client should be using a different refresh token each time
   557  			qt.Assert(t, qt.Equals(req.Form.Get("refresh_token"), fmt.Sprintf("someRefreshToken%d", authCount)))
   558  		}) {
   559  			return nil, &httpError{
   560  				statusCode: http.StatusInternalServerError,
   561  			}
   562  		}
   563  		requestedScope := ParseScope(strings.Join(req.Form["scope"], " "))
   564  		// Return an access token that expires soon so that we can let it expire
   565  		// so the client will be forced to acquire a new one with the original
   566  		// refresh token.
   567  		return &wireToken{
   568  			RefreshToken: fmt.Sprintf("someRefreshToken%d", authCount+1),
   569  			Token:        token{requestedScope}.String(),
   570  		}, nil
   571  	})
   572  	ts := newTargetServer(t, func(req *http.Request) *httpError {
   573  		resource := strings.TrimPrefix(req.URL.Path, "/test/")
   574  		requiredScope := NewScope(ResourceScope{
   575  			ResourceType: TypeRepository,
   576  			Resource:     resource,
   577  			Action:       ActionPull,
   578  		})
   579  		if req.Header.Get("Authorization") == "" {
   580  			return &httpError{
   581  				statusCode: http.StatusUnauthorized,
   582  				header: http.Header{
   583  					"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, requiredScope)},
   584  				},
   585  			}
   586  		}
   587  		runNonFatal(t, func(t testing.TB) {
   588  			requestScope := authScopeFromRequest(t, req)
   589  			qt.Assert(t, qt.IsTrue(requestScope.Contains(requiredScope)), qt.Commentf("request scope: %q; required scope: %q", requestScope, requiredScope))
   590  		})
   591  		return nil
   592  	})
   593  	client := &http.Client{
   594  		Transport: NewStdTransport(StdTransportParams{
   595  			Config: configFunc(func(host string) (ConfigEntry, error) {
   596  				if host == ts.Host {
   597  					return ConfigEntry{
   598  						RefreshToken: "someRefreshToken1",
   599  					}, nil
   600  				}
   601  				return ConfigEntry{}, nil
   602  			}),
   603  		}),
   604  	}
   605  	// Each time we make a new request, we'll be asking for a new scope
   606  	// because we're getting a new resource each time, so that will
   607  	// make another request to the auth server, which will return
   608  	// a new refresh token each time.
   609  	numRequests := 4
   610  	for i := 0; i < numRequests; i++ {
   611  		repo := fmt.Sprintf("foo%d", i)
   612  		assertRequest(context.Background(), t, ts, fmt.Sprintf("/test/foo%d", i), client, NewScope(ResourceScope{
   613  			ResourceType: TypeRepository,
   614  			Resource:     repo,
   615  			Action:       ActionPull,
   616  		}))
   617  	}
   618  	qt.Assert(t, qt.Equals(authCount, numRequests))
   619  }
   620  
   621  func assertRequest(ctx context.Context, t testing.TB, tsURL *url.URL, path string, client *http.Client, needScope Scope) {
   622  	ctx = ContextWithRequestInfo(ctx, RequestInfo{
   623  		RequiredScope: needScope,
   624  	})
   625  	// Try the request twice as the second time often exercises other
   626  	// code paths as caches are warmed up.
   627  	assertRequest1(ctx, t, tsURL, path, client)
   628  	assertRequest1(ctx, t, tsURL, path, client)
   629  }
   630  
   631  func assertRequest1(ctx context.Context, t testing.TB, tsURL *url.URL, path string, client *http.Client) {
   632  	req, err := http.NewRequestWithContext(ctx, "POST", tsURL.String()+path, strings.NewReader("test body"))
   633  	qt.Assert(t, qt.IsNil(err))
   634  	// Set ContentLength to -1 to prevent net/http from calling GetBody automatically,
   635  	// thus testing the GetBody-calling code inside registry.doRequest.
   636  	req.ContentLength = -1
   637  	resp, err := client.Do(req)
   638  	qt.Assert(t, qt.IsNil(err))
   639  	defer resp.Body.Close()
   640  	qt.Assert(t, qt.Equals(resp.StatusCode, http.StatusOK))
   641  	data, _ := io.ReadAll(resp.Body)
   642  	qt.Assert(t, qt.Equals(string(data), "test ok"))
   643  }
   644  
   645  // newAuthServer returns the URL for an auth server that uses auth to service authorization
   646  // requests. If that returns a nil *httpError, the first return parameter is marshaled
   647  // as a JSON response body; otherwise the error is returned.
   648  func newAuthServer(t *testing.T, auth func(req *http.Request) (any, *httpError)) *url.URL {
   649  	authSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   650  		t.Logf("-> authSrv %s %v {", req.Method, req.URL)
   651  		req.ParseForm()
   652  		bodyJSON, herr := auth(req)
   653  		if herr != nil {
   654  			herr.send(w)
   655  			t.Logf("} <- error %#v", herr)
   656  			return
   657  		}
   658  		w.Header().Set("Content-Type", "application/json")
   659  		w.WriteHeader(200)
   660  		data, err := json.Marshal(bodyJSON)
   661  		if err != nil {
   662  			panic(err)
   663  		}
   664  		w.Write(data)
   665  		t.Logf("} <- json %s", data)
   666  	}))
   667  	t.Cleanup(authSrv.Close)
   668  	return mustParseURL(authSrv.URL)
   669  }
   670  
   671  // newTargetServer returns the URL for a test target server that uses the targetGate
   672  // parameter to gate requests to the /test endpoint: if targetGate returns nil for a request
   673  // to that endpoint, the request will succeed.
   674  //
   675  // It also returns the URL for an auth server that uses auth to service authorization
   676  // requests. If that returns a nil *httpError, the first return parameter is marshaled
   677  // as a JSON response body; otherwise the error is returned.
   678  func newTargetServer(
   679  	t *testing.T,
   680  	targetGate func(req *http.Request) *httpError,
   681  ) *url.URL {
   682  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   683  		t.Logf("-> targetSrv %s %v auth=%q {", req.Method, req.URL, req.Header.Get("Authorization"))
   684  		herr := targetGate(req)
   685  		if herr != nil {
   686  			herr.send(w)
   687  			t.Logf("} <- error %#v", herr)
   688  			return
   689  		}
   690  		if req.URL.Path != "/test" && !strings.HasPrefix(req.URL.Path, "/test/") {
   691  			t.Logf("} <- error (wrong path)")
   692  			http.Error(w, "only /test is allowed", http.StatusNotFound)
   693  			return
   694  		}
   695  		if req.Method != "POST" {
   696  			t.Logf("} <- error (wrong method)")
   697  			http.Error(w, "only method POST is allowed", http.StatusMethodNotAllowed)
   698  			return
   699  		}
   700  		data, _ := io.ReadAll(req.Body)
   701  		if gotBody := string(data); gotBody != "test body" {
   702  			t.Logf("} <- error (wrong body %q)", gotBody)
   703  			http.Error(w, "wrong body", http.StatusForbidden)
   704  			return
   705  		}
   706  		t.Logf("} <- OK")
   707  		w.Write([]byte("test ok"))
   708  	}))
   709  	t.Cleanup(srv.Close)
   710  	return mustParseURL(srv.URL)
   711  }
   712  
   713  func mustParseURL(s string) *url.URL {
   714  	u, err := url.Parse(s)
   715  	if err != nil {
   716  		panic(err)
   717  	}
   718  	return u
   719  }
   720  
   721  type httpError struct {
   722  	header     http.Header
   723  	statusCode int
   724  	body       string
   725  }
   726  
   727  func (e *httpError) send(w http.ResponseWriter) {
   728  	for k, v := range e.header {
   729  		w.Header()[k] = v
   730  	}
   731  	w.WriteHeader(e.statusCode)
   732  	w.Write([]byte(e.body))
   733  }
   734  
   735  type configFunc func(host string) (ConfigEntry, error)
   736  
   737  func (f configFunc) EntryForRegistry(host string) (ConfigEntry, error) {
   738  	return f(host)
   739  }
   740  
   741  type token struct {
   742  	scope Scope
   743  }
   744  
   745  func authScopeFromRequest(t testing.TB, req *http.Request) Scope {
   746  	h, ok := req.Header["Authorization"]
   747  	if !ok {
   748  		t.Fatal("no Authorization found in request")
   749  	}
   750  	if len(h) != 1 {
   751  		t.Fatal("multiple Authorization headers found in request")
   752  	}
   753  	tokStr, ok := strings.CutPrefix(h[0], "Bearer ")
   754  	if !ok {
   755  		t.Fatalf("token %q is not bearer token", h)
   756  	}
   757  	tok, err := parseToken(tokStr)
   758  	qt.Assert(t, qt.IsNil(err))
   759  	return tok.scope
   760  }
   761  
   762  func parseToken(s string) (token, error) {
   763  	data, err := base64.StdEncoding.DecodeString(s)
   764  	if err != nil {
   765  		return token{}, fmt.Errorf("invalid token %q: %v", s, err)
   766  	}
   767  	scope, ok := strings.CutPrefix(string(data), "token-")
   768  	if !ok {
   769  		return token{}, fmt.Errorf("invalid token prefix")
   770  	}
   771  	return token{
   772  		scope: ParseScope(scope),
   773  	}, nil
   774  }
   775  
   776  func (tok token) String() string {
   777  	return base64.StdEncoding.EncodeToString([]byte("token-" + tok.scope.String()))
   778  }
   779  
   780  // runNonFatal runs the given function within t
   781  // but will not call Fatal on t even if Fatal is called
   782  // on the t passed to f. It reports whether all
   783  // checks succeeded.
   784  //
   785  // This makes it suitable for passing to assertion-based
   786  // functions inside goroutines where it's not ok to
   787  // call Fatal.
   788  func runNonFatal(t *testing.T, f func(t testing.TB)) (ok bool) {
   789  	defer func() {
   790  		switch e := recover(); e {
   791  		case errFailNow, errSkipNow:
   792  			ok = false
   793  		case nil:
   794  		default:
   795  			panic(e)
   796  		}
   797  	}()
   798  	f(nonFatalT{t})
   799  	return !t.Failed()
   800  }
   801  
   802  var (
   803  	errFailNow = errors.New("failing now")
   804  	errSkipNow = errors.New("skipping now")
   805  )
   806  
   807  type nonFatalT struct {
   808  	*testing.T
   809  }
   810  
   811  func (t nonFatalT) FailNow() {
   812  	t.Helper()
   813  	t.Fail()
   814  	panic(errFailNow)
   815  }
   816  
   817  func (t nonFatalT) Fatal(args ...any) {
   818  	t.Helper()
   819  	t.Error(args...)
   820  	t.FailNow()
   821  }
   822  
   823  func (t nonFatalT) Fatalf(format string, args ...any) {
   824  	t.Helper()
   825  	t.Errorf(format, args...)
   826  	t.FailNow()
   827  }
   828  
   829  func (t nonFatalT) Skip(args ...any) {
   830  	t.Helper()
   831  	t.Log(args...)
   832  	t.SkipNow()
   833  }
   834  
   835  func (t nonFatalT) SkipNow() {
   836  	panic(errSkipNow)
   837  }
   838  
   839  func (t nonFatalT) Skipf(format string, args ...any) {
   840  	t.Helper()
   841  	t.Logf(format, args...)
   842  	t.SkipNow()
   843  }