cuelang.org/go@v0.13.0/mod/modregistrytest/registry.go (about)

     1  // Package modregistrytest provides helpers for testing packages
     2  // which interact with CUE registries.
     3  //
     4  // WARNING: THIS PACKAGE IS EXPERIMENTAL.
     5  // ITS API MAY CHANGE AT ANY TIME.
     6  package modregistrytest
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"encoding/json"
    12  	"errors"
    13  	"fmt"
    14  	"io"
    15  	"io/fs"
    16  	"maps"
    17  	"net/http"
    18  	"net/http/httptest"
    19  	"net/url"
    20  	"regexp"
    21  	"slices"
    22  	"strings"
    23  
    24  	"cuelabs.dev/go/oci/ociregistry"
    25  	"cuelabs.dev/go/oci/ociregistry/ocifilter"
    26  	"cuelabs.dev/go/oci/ociregistry/ocimem"
    27  	"cuelabs.dev/go/oci/ociregistry/ociserver"
    28  	"golang.org/x/tools/txtar"
    29  
    30  	"cuelang.org/go/mod/modfile"
    31  	"cuelang.org/go/mod/modregistry"
    32  	"cuelang.org/go/mod/module"
    33  	"cuelang.org/go/mod/modzip"
    34  )
    35  
// AuthConfig specifies authorization requirements for the server.
//
// It is the Go form of the auth.json file accepted by [New], which
// decodes it with encoding/json using the field tags below.
type AuthConfig struct {
	// Username and Password hold the basic auth credentials.
	// If UseTokenServer is true, these apply to the token server
	// rather than to the registry itself.
	//
	// NOTE(review): the token handler in this file accepts an
	// *AuthConfig but never reads it; it only validates a fixed
	// refresh token. Confirm whether that is intended.
	Username string `json:"username"`
	Password string `json:"password"`

	// BearerToken holds a bearer token to use as auth.
	// If UseTokenServer is true, this applies to the token server
	// rather than to the registry itself.
	//
	// When no token server is in use and BearerToken is non-empty,
	// the registry expects "Authorization: Bearer <token>" and
	// Username/Password are not consulted.
	BearerToken string `json:"bearerToken"`

	// UseTokenServer starts a token server and directs client
	// requests to acquire auth tokens from that server.
	UseTokenServer bool `json:"useTokenServer"`

	// ACL holds the ACL for an authenticated client.
	// If it's nil, the user is allowed full access.
	// Note: there's only one ACL because we only
	// support a single authenticated user.
	ACL *ACL `json:"acl,omitempty"`

	// Use401InsteadOf403 causes the server to send a 401
	// response even when the credentials are present and correct.
	// Concretely, an access that the ACL rejects is reported as
	// "unauthorized" rather than "denied".
	// Note that the JSON key for this field is "always401".
	Use401InsteadOf403 bool `json:"always401"`
}
    63  
// ACL determines what endpoints an authenticated user can access.
// Both Allow and Deny hold a list of regular expressions that
// are matched against an HTTP request formatted as a string:
//
//	METHOD URL_PATH
//
// For example:
//
//	GET /v2/foo/bar
//
// NOTE(review): authzRegistry in this file applies these patterns to
// the repository name only, not to a "METHOD URL_PATH" string;
// confirm which of the two is the intended contract.
//
// Patterns are compiled with [regexp.Compile] and are not implicitly
// anchored: a pattern matches if it matches any part of the string.
type ACL struct {
	// Allow holds the list of allowed paths for a user.
	// If none match, the user is forbidden.
	// In particular, an empty Allow list forbids everything.
	Allow []string
	// Deny holds the list of denied paths for a user.
	// If any match, the user is forbidden.
	// Deny takes precedence over Allow.
	Deny []string
}
    81  
    82  // Upload uploads the modules found inside fsys (stored
    83  // in the format described by [New]) to the given registry.
    84  func Upload(ctx context.Context, r ociregistry.Interface, fsys fs.FS) error {
    85  	_, err := upload(ctx, r, fsys)
    86  	return err
    87  }
    88  
    89  func upload(ctx context.Context, r ociregistry.Interface, fsys fs.FS) (authConfig []byte, err error) {
    90  	client := modregistry.NewClient(r)
    91  	mods, authConfigData, err := getModules(fsys)
    92  	if err != nil {
    93  		return nil, fmt.Errorf("invalid modules: %v", err)
    94  	}
    95  
    96  	if err := pushContent(ctx, client, mods); err != nil {
    97  		return nil, fmt.Errorf("cannot push modules: %v", err)
    98  	}
    99  	return authConfigData, nil
   100  }
   101  
   102  // New starts a registry instance that serves modules found inside fsys.
   103  // It serves the OCI registry protocol.
   104  // If prefix is non-empty, all module paths will be prefixed by that,
   105  // separated by a slash (/).
   106  //
   107  // Each module should be inside a directory named path_vers, where
   108  // slashes in path have been replaced with underscores and should
   109  // contain a cue.mod/module.cue file holding the module info.
   110  //
   111  // If there's a file named auth.json in the root directory,
   112  // it will cause access to the server to be gated by the
   113  // specified authorization. See the [AuthConfig] type for
   114  // details.
   115  //
   116  // The Registry should be closed after use.
   117  func New(fsys fs.FS, prefix string) (*Registry, error) {
   118  	r := ocimem.NewWithConfig(&ocimem.Config{ImmutableTags: true})
   119  
   120  	authConfigData, err := upload(context.Background(), ocifilter.Sub(r, prefix), fsys)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  	var authConfig *AuthConfig
   125  	if authConfigData != nil {
   126  		if err := json.Unmarshal(authConfigData, &authConfig); err != nil {
   127  			return nil, fmt.Errorf("invalid auth.json: %v", err)
   128  		}
   129  	}
   130  	return NewServer(ocifilter.ReadOnly(r), authConfig)
   131  }
   132  
   133  // NewServer is like New except that instead of uploading
   134  // the contents of a filesystem, it just serves the contents
   135  // of the given registry guarded by the given auth configuration.
   136  // If auth is nil, no authentication will be required.
   137  func NewServer(r ociregistry.Interface, auth *AuthConfig) (*Registry, error) {
   138  	var tokenSrv *httptest.Server
   139  	if auth != nil && auth.UseTokenServer {
   140  		tokenSrv = httptest.NewServer(tokenHandler(auth))
   141  	}
   142  	r, err := authzRegistry(auth, r)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  	srv := httptest.NewServer(&registryHandler{
   147  		auth:     auth,
   148  		registry: ociserver.New(r, nil),
   149  		tokenSrv: tokenSrv,
   150  	})
   151  	u, err := url.Parse(srv.URL)
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  	return &Registry{
   156  		srv:      srv,
   157  		host:     u.Host,
   158  		tokenSrv: tokenSrv,
   159  	}, nil
   160  }
   161  
   162  // authzRegistry wraps r by checking whether the client has authorization
   163  // to read any given repository.
   164  func authzRegistry(auth *AuthConfig, r ociregistry.Interface) (ociregistry.Interface, error) {
   165  	if auth == nil {
   166  		return r, nil
   167  	}
   168  	allow := func(repoName string) bool {
   169  		return true
   170  	}
   171  	if auth.ACL != nil {
   172  		allowCheck, err := regexpMatcher(auth.ACL.Allow)
   173  		if err != nil {
   174  			return nil, fmt.Errorf("invalid allow list: %v", err)
   175  		}
   176  		denyCheck, err := regexpMatcher(auth.ACL.Deny)
   177  		if err != nil {
   178  			return nil, fmt.Errorf("invalid deny list: %v", err)
   179  		}
   180  		allow = func(repoName string) bool {
   181  			return allowCheck(repoName) && !denyCheck(repoName)
   182  		}
   183  	}
   184  	return ocifilter.AccessChecker(r, func(repoName string, access ocifilter.AccessKind) (_err error) {
   185  		if !allow(repoName) {
   186  			if auth.Use401InsteadOf403 {
   187  				// TODO this response should be associated with a
   188  				// Www-Authenticate header, but this won't do that.
   189  				// Given that the ociauth logic _should_ turn
   190  				// this back into a 403 error again, perhaps
   191  				// we're OK.
   192  				return ociregistry.ErrUnauthorized
   193  			}
   194  			// TODO should we be a bit more sophisticated and only
   195  			// return ErrDenied when the repository doesn't exist?
   196  			return ociregistry.ErrDenied
   197  		}
   198  		return nil
   199  	}), nil
   200  }
   201  
   202  func regexpMatcher(patStrs []string) (func(string) bool, error) {
   203  	pats := make([]*regexp.Regexp, len(patStrs))
   204  	for i, s := range patStrs {
   205  		pat, err := regexp.Compile(s)
   206  		if err != nil {
   207  			return nil, fmt.Errorf("invalid regexp in ACL: %v", err)
   208  		}
   209  		pats[i] = pat
   210  	}
   211  	return func(name string) bool {
   212  		for _, pat := range pats {
   213  			if pat.MatchString(name) {
   214  				return true
   215  			}
   216  		}
   217  		return false
   218  	}, nil
   219  }
   220  
// registryHandler is the HTTP handler for the test registry.
// It enforces the configured authentication before passing
// requests on to the underlying registry handler.
type registryHandler struct {
	// auth holds the auth requirements; nil means that
	// no authentication is required.
	auth     *AuthConfig
	// registry serves the OCI registry protocol itself.
	registry http.Handler
	// tokenSrv is the token server clients are directed to.
	// If it's nil, credentials are checked directly against auth.
	tokenSrv *httptest.Server
}
   226  
// Fixed bearer tokens understood by the registry handler when
// a token server is in use.
const (
	// registryAuthToken is the token handed out by the token
	// server; the registry treats its bearer as authorized.
	registryAuthToken   = "ok-token-for-registrytest"
	// registryUnauthToken is recognized by the registry but
	// always results in a "denied" error response.
	registryUnauthToken = "unauth-token-for-registrytest"
)
   231  
   232  func (h *registryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   233  	if h.auth == nil {
   234  		h.registry.ServeHTTP(w, req)
   235  		return
   236  	}
   237  	if h.tokenSrv == nil {
   238  		h.serveDirectAuth(w, req)
   239  		return
   240  	}
   241  
   242  	// Auth with token server.
   243  	wwwAuth := fmt.Sprintf("Bearer realm=%q,service=registrytest", h.tokenSrv.URL)
   244  	authHeader := req.Header.Get("Authorization")
   245  	if authHeader == "" {
   246  		w.Header().Set("Www-Authenticate", wwwAuth)
   247  		writeError(w, ociregistry.ErrUnauthorized)
   248  		return
   249  	}
   250  	kind, token, ok := strings.Cut(authHeader, " ")
   251  	if !ok || kind != "Bearer" {
   252  		w.Header().Set("Www-Authenticate", wwwAuth)
   253  		writeError(w, ociregistry.ErrUnauthorized)
   254  		return
   255  	}
   256  	switch token {
   257  	case registryAuthToken:
   258  		// User is authorized.
   259  	case registryUnauthToken:
   260  		writeError(w, ociregistry.ErrDenied)
   261  		return
   262  	default:
   263  		// If we don't recognize the token, then presumably
   264  		// the client isn't authenticated so it's 401 not 403.
   265  		w.Header().Set("Www-Authenticate", wwwAuth)
   266  		writeError(w, ociregistry.ErrUnauthorized)
   267  		return
   268  	}
   269  	// If the underlying registry returns a 401 error,
   270  	// we need to add the Www-Authenticate header.
   271  	// As there's no way to get ociserver to do it,
   272  	// we hack it by wrapping the ResponseWriter
   273  	// with an implementation that does.
   274  	h.registry.ServeHTTP(&authHeaderWriter{
   275  		wwwAuth:        wwwAuth,
   276  		ResponseWriter: w,
   277  	}, req)
   278  }
   279  
   280  func (h *registryHandler) serveDirectAuth(w http.ResponseWriter, req *http.Request) {
   281  	auth := req.Header.Get("Authorization")
   282  	if auth == "" {
   283  		if h.auth.BearerToken != "" {
   284  			// Note that this lacks information like the realm,
   285  			// but we don't need it for our test cases yet.
   286  			w.Header().Set("Www-Authenticate", "Bearer service=registry")
   287  		} else {
   288  			w.Header().Set("Www-Authenticate", "Basic service=registry")
   289  		}
   290  		writeError(w, fmt.Errorf("%w: no credentials", ociregistry.ErrUnauthorized))
   291  		return
   292  	}
   293  	if h.auth.BearerToken != "" {
   294  		token, ok := strings.CutPrefix(auth, "Bearer ")
   295  		if !ok || token != h.auth.BearerToken {
   296  			writeError(w, fmt.Errorf("%w: invalid bearer credentials", ociregistry.ErrUnauthorized))
   297  			return
   298  		}
   299  	} else {
   300  		username, password, ok := req.BasicAuth()
   301  		if !ok || username != h.auth.Username || password != h.auth.Password {
   302  			writeError(w, fmt.Errorf("%w: invalid user-password credentials", ociregistry.ErrUnauthorized))
   303  			return
   304  		}
   305  	}
   306  	h.registry.ServeHTTP(w, req)
   307  }
   308  
   309  func tokenHandler(*AuthConfig) http.Handler {
   310  	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   311  		if req.Method != "POST" {
   312  			http.Error(w, "only POST supported", http.StatusMethodNotAllowed)
   313  			return
   314  		}
   315  		req.ParseForm()
   316  		if req.Form.Get("service") != "registrytest" {
   317  			http.Error(w, "invalid service", http.StatusBadRequest)
   318  			return
   319  		}
   320  		if req.Form.Get("grant_type") != "refresh_token" {
   321  			http.Error(w, "invalid grant type", http.StatusBadRequest)
   322  			return
   323  		}
   324  		refreshToken := req.Form.Get("refresh_token")
   325  		if refreshToken != "registrytest-refresh" {
   326  			http.Error(w, fmt.Sprintf("invalid refresh token %q", refreshToken), http.StatusForbidden)
   327  			return
   328  		}
   329  		// See ociauth.wireToken for the full JSON format.
   330  		data, _ := json.Marshal(map[string]string{
   331  			"token": registryAuthToken,
   332  		})
   333  		w.Header().Set("Content-Type", "application/json")
   334  		w.Write(data)
   335  	})
   336  }
   337  
   338  func writeError(w http.ResponseWriter, err error) {
   339  	data, httpStatus := ociregistry.MarshalError(err)
   340  	w.Header().Set("Content-Type", "application/json")
   341  	w.WriteHeader(httpStatus)
   342  	w.Write(data)
   343  }
   344  
   345  func pushContent(ctx context.Context, client *modregistry.Client, mods map[module.Version]*moduleContent) error {
   346  	pushed := make(map[module.Version]bool)
   347  	// Iterate over modules in deterministic order.
   348  	for _, v := range slices.SortedFunc(maps.Keys(mods), module.Version.Compare) {
   349  		err := visitDepthFirst(mods, v, func(v module.Version, m *moduleContent) error {
   350  			if pushed[v] {
   351  				return nil
   352  			}
   353  			var zipContent bytes.Buffer
   354  			if err := m.writeZip(&zipContent); err != nil {
   355  				return err
   356  			}
   357  			if err := client.PutModule(ctx, v, bytes.NewReader(zipContent.Bytes()), int64(zipContent.Len())); err != nil {
   358  				return err
   359  			}
   360  			pushed[v] = true
   361  			return nil
   362  		})
   363  		if err != nil {
   364  			return err
   365  		}
   366  	}
   367  	return nil
   368  }
   369  
   370  func visitDepthFirst(mods map[module.Version]*moduleContent, v module.Version, f func(module.Version, *moduleContent) error) error {
   371  	m := mods[v]
   372  	if m == nil {
   373  		return nil
   374  	}
   375  	for _, depv := range m.modFile.DepVersions() {
   376  		if err := visitDepthFirst(mods, depv, f); err != nil {
   377  			return err
   378  		}
   379  	}
   380  	return f(v, m)
   381  }
   382  
// Registry represents a running test registry server,
// as returned by [New] and [NewServer].
type Registry struct {
	// srv is the HTTP server for the registry itself.
	srv      *httptest.Server
	// tokenSrv is the token server; it's nil when
	// no token server is in use.
	tokenSrv *httptest.Server
	// host holds the host:port part of the registry server's URL.
	host     string
}

// Close shuts down the registry server and, if there is one,
// the associated token server.
func (r *Registry) Close() {
	r.srv.Close()
	if r.tokenSrv != nil {
		r.tokenSrv.Close()
	}
}
   395  
   396  type authHeaderWriter struct {
   397  	wwwAuth string
   398  	http.ResponseWriter
   399  }
   400  
   401  func (w *authHeaderWriter) WriteHeader(code int) {
   402  	if code == http.StatusUnauthorized && w.Header().Get("Www-Authenticate") == "" {
   403  		w.Header().Set("Www-Authenticate", w.wwwAuth)
   404  	}
   405  	w.ResponseWriter.WriteHeader(code)
   406  }
   407  
// Host returns the hostname for the registry server;
// for example localhost:13455.
// The value is the host part of the server's URL,
// recorded when the registry was started.
//
// The connection can be assumed to be insecure.
func (r *Registry) Host() string {
	return r.host
}
   415  
   416  func getModules(fsys fs.FS) (map[module.Version]*moduleContent, []byte, error) {
   417  	var authConfig []byte
   418  	modules := make(map[string]*moduleContent)
   419  	if err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
   420  		if err != nil {
   421  			// If a filesystem has no entries at all,
   422  			// return zero modules without an error.
   423  			if path == "." && errors.Is(err, fs.ErrNotExist) {
   424  				return fs.SkipAll
   425  			}
   426  			return err
   427  		}
   428  		if d.IsDir() {
   429  			return nil // we're only interested in regular files, not their parent directories
   430  		}
   431  		if path == "auth.json" {
   432  			authConfig, err = fs.ReadFile(fsys, path)
   433  			if err != nil {
   434  				return err
   435  			}
   436  			return nil
   437  		}
   438  		modver, rest, ok := strings.Cut(path, "/")
   439  		if !ok {
   440  			return fmt.Errorf("registry should only contain directories, but found regular file %q", path)
   441  		}
   442  		content := modules[modver]
   443  		if content == nil {
   444  			content = &moduleContent{}
   445  			modules[modver] = content
   446  		}
   447  		data, err := fs.ReadFile(fsys, path)
   448  		if err != nil {
   449  			return err
   450  		}
   451  		content.files = append(content.files, txtar.File{
   452  			Name: rest,
   453  			Data: data,
   454  		})
   455  		return nil
   456  	}); err != nil {
   457  		return nil, nil, err
   458  	}
   459  	for modver, content := range modules {
   460  		if err := content.init(modver); err != nil {
   461  			return nil, nil, fmt.Errorf("cannot initialize module %q: %v", modver, err)
   462  		}
   463  	}
   464  	byVer := map[module.Version]*moduleContent{}
   465  	for _, m := range modules {
   466  		byVer[m.version] = m
   467  	}
   468  	return byVer, authConfig, nil
   469  }
   470  
// moduleContent holds the contents of a single module
// as read from a module directory.
type moduleContent struct {
	// version holds the module's path and version; it's set by init.
	version module.Version
	// files holds the module's files, with names
	// relative to the module directory.
	files   []txtar.File
	// modFile holds the parsed cue.mod/module.cue file; it's set by init.
	modFile *modfile.File
}

// writeZip writes the module's files to w as a module zip archive
// for c.version, using modzip.Create.
func (c *moduleContent) writeZip(w io.Writer) error {
	return modzip.Create(w, c.version, c.files, txtarFileIO{})
}
   480  
   481  func (c *moduleContent) init(versDir string) error {
   482  	found := false
   483  	for _, f := range c.files {
   484  		if f.Name != "cue.mod/module.cue" {
   485  			continue
   486  		}
   487  		modf, err := modfile.Parse(f.Data, f.Name)
   488  		if err != nil {
   489  			return err
   490  		}
   491  		if found {
   492  			return fmt.Errorf("multiple module.cue files")
   493  		}
   494  		mod := strings.ReplaceAll(modf.ModulePath(), "/", "_") + "_"
   495  		vers := strings.TrimPrefix(versDir, mod)
   496  		if len(vers) == len(versDir) {
   497  			return fmt.Errorf("module path %q in module.cue does not match directory %q", modf.QualifiedModule(), versDir)
   498  		}
   499  		v, err := module.NewVersion(modf.QualifiedModule(), vers)
   500  		if err != nil {
   501  			return fmt.Errorf("cannot make module version: %v", err)
   502  		}
   503  		c.version = v
   504  		c.modFile = modf
   505  		found = true
   506  	}
   507  	if !found {
   508  		return fmt.Errorf("no module.cue file found in %q", versDir)
   509  	}
   510  	return nil
   511  }