cuelang.org/go@v0.10.1/internal/registrytest/registry.go (about)

     1  package registrytest
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"io/fs"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"net/url"
    14  	"regexp"
    15  	"strings"
    16  
    17  	"cuelabs.dev/go/oci/ociregistry"
    18  	"cuelabs.dev/go/oci/ociregistry/ocifilter"
    19  	"cuelabs.dev/go/oci/ociregistry/ocimem"
    20  	"cuelabs.dev/go/oci/ociregistry/ociserver"
    21  	"golang.org/x/tools/txtar"
    22  
    23  	"cuelang.org/go/mod/modfile"
    24  	"cuelang.org/go/mod/modregistry"
    25  	"cuelang.org/go/mod/module"
    26  	"cuelang.org/go/mod/modzip"
    27  )
    28  
// AuthConfig specifies authorization requirements for the server.
// [New] decodes it from an auth.json file at the root of its
// filesystem; the JSON keys are given by the struct tags below.
type AuthConfig struct {
	// Username and Password hold the basic auth credentials.
	// If UseTokenServer is true, these apply to the token server
	// rather than to the registry itself.
	// Without a token server they are only checked when BearerToken
	// is empty.
	// NOTE(review): tokenHandler ignores its *AuthConfig argument,
	// so with UseTokenServer these are not currently consulted.
	Username string `json:"username"`
	Password string `json:"password"`

	// BearerToken holds a bearer token to use as auth.
	// If UseTokenServer is true, this applies to the token server
	// rather than to the registry itself.
	// Without a token server, a non-empty BearerToken takes precedence
	// over Username/Password.
	BearerToken string `json:"bearerToken"`

	// UseTokenServer starts a token server and directs client
	// requests to acquire auth tokens from that server.
	UseTokenServer bool `json:"useTokenServer"`

	// ACL holds the ACL for an authenticated client.
	// If it's nil, the user is allowed full access.
	// Note: there's only one ACL because we only
	// support a single authenticated user.
	ACL *ACL `json:"acl,omitempty"`

	// Use401InsteadOf403 causes the server to send a 401
	// response even when the credentials are present and correct.
	// Concretely, a repository access refused by the ACL fails with
	// ociregistry.ErrUnauthorized instead of ociregistry.ErrDenied.
	// Note that its JSON key is "always401".
	Use401InsteadOf403 bool `json:"always401"`
}
    56  
// ACL determines which repositories an authenticated user can access.
// Both Allow and Deny hold a list of regular expressions that
// are matched against the name of the repository being accessed,
// for example:
//
//	foo/bar
//
// Matching is unanchored: a pattern matches if it matches any part
// of the name, so use ^ and $ to constrain a pattern to whole names.
// Access is granted only when at least one Allow pattern matches and
// no Deny pattern does.
type ACL struct {
	// Allow holds the list of allowed repository patterns for a user.
	// If none match (including when the list is empty), the user is
	// forbidden.
	Allow []string
	// Deny holds the list of denied repository patterns for a user.
	// If any match, the user is forbidden.
	Deny []string
}
    74  
    75  // Upload uploads the modules found inside fsys (stored
    76  // in the format described by [New]) to the given registry.
    77  func Upload(ctx context.Context, r ociregistry.Interface, fsys fs.FS) error {
    78  	_, err := upload(ctx, r, fsys)
    79  	return err
    80  }
    81  
    82  func upload(ctx context.Context, r ociregistry.Interface, fsys fs.FS) (authConfig []byte, err error) {
    83  	client := modregistry.NewClient(r)
    84  	mods, authConfigData, err := getModules(fsys)
    85  	if err != nil {
    86  		return nil, fmt.Errorf("invalid modules: %v", err)
    87  	}
    88  
    89  	if err := pushContent(ctx, client, mods); err != nil {
    90  		return nil, fmt.Errorf("cannot push modules: %v", err)
    91  	}
    92  	return authConfigData, nil
    93  }
    94  
    95  // New starts a registry instance that serves modules found inside fsys.
    96  // It serves the OCI registry protocol.
    97  // If prefix is non-empty, all module paths will be prefixed by that,
    98  // separated by a slash (/).
    99  //
   100  // Each module should be inside a directory named path_vers, where
   101  // slashes in path have been replaced with underscores and should
   102  // contain a cue.mod/module.cue file holding the module info.
   103  //
   104  // If there's a file named auth.json in the root directory,
   105  // it will cause access to the server to be gated by the
   106  // specified authorization. See the [AuthConfig] type for
   107  // details.
   108  //
   109  // The Registry should be closed after use.
   110  func New(fsys fs.FS, prefix string) (*Registry, error) {
   111  	r := ocimem.NewWithConfig(&ocimem.Config{ImmutableTags: true})
   112  
   113  	authConfigData, err := upload(context.Background(), ocifilter.Sub(r, prefix), fsys)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  	var authConfig *AuthConfig
   118  	if authConfigData != nil {
   119  		if err := json.Unmarshal(authConfigData, &authConfig); err != nil {
   120  			return nil, fmt.Errorf("invalid auth.json: %v", err)
   121  		}
   122  	}
   123  	return NewServer(ocifilter.ReadOnly(r), authConfig)
   124  }
   125  
   126  // NewServer is like New except that instead of uploading
   127  // the contents of a filesystem, it just serves the contents
   128  // of the given registry guarded by the given auth configuration.
   129  // If auth is nil, no authentication will be required.
   130  func NewServer(r ociregistry.Interface, auth *AuthConfig) (*Registry, error) {
   131  	var tokenSrv *httptest.Server
   132  	if auth != nil && auth.UseTokenServer {
   133  		tokenSrv = httptest.NewServer(tokenHandler(auth))
   134  	}
   135  	r, err := authzRegistry(auth, r)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	srv := httptest.NewServer(&registryHandler{
   140  		auth:     auth,
   141  		registry: ociserver.New(r, nil),
   142  		tokenSrv: tokenSrv,
   143  	})
   144  	u, err := url.Parse(srv.URL)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	return &Registry{
   149  		srv:      srv,
   150  		host:     u.Host,
   151  		tokenSrv: tokenSrv,
   152  	}, nil
   153  }
   154  
   155  // authzRegistry wraps r by checking whether the client has authorization
   156  // to read any given repository.
   157  func authzRegistry(auth *AuthConfig, r ociregistry.Interface) (ociregistry.Interface, error) {
   158  	if auth == nil {
   159  		return r, nil
   160  	}
   161  	allow := func(repoName string) bool {
   162  		return true
   163  	}
   164  	if auth.ACL != nil {
   165  		allowCheck, err := regexpMatcher(auth.ACL.Allow)
   166  		if err != nil {
   167  			return nil, fmt.Errorf("invalid allow list: %v", err)
   168  		}
   169  		denyCheck, err := regexpMatcher(auth.ACL.Deny)
   170  		if err != nil {
   171  			return nil, fmt.Errorf("invalid deny list: %v", err)
   172  		}
   173  		allow = func(repoName string) bool {
   174  			return allowCheck(repoName) && !denyCheck(repoName)
   175  		}
   176  	}
   177  	return ocifilter.AccessChecker(r, func(repoName string, access ocifilter.AccessKind) (_err error) {
   178  		if !allow(repoName) {
   179  			if auth.Use401InsteadOf403 {
   180  				// TODO this response should be associated with a
   181  				// Www-Authenticate header, but this won't do that.
   182  				// Given that the ociauth logic _should_ turn
   183  				// this back into a 403 error again, perhaps
   184  				// we're OK.
   185  				return ociregistry.ErrUnauthorized
   186  			}
   187  			// TODO should we be a bit more sophisticated and only
   188  			// return ErrDenied when the repository doesn't exist?
   189  			return ociregistry.ErrDenied
   190  		}
   191  		return nil
   192  	}), nil
   193  }
   194  
   195  func regexpMatcher(patStrs []string) (func(string) bool, error) {
   196  	pats := make([]*regexp.Regexp, len(patStrs))
   197  	for i, s := range patStrs {
   198  		pat, err := regexp.Compile(s)
   199  		if err != nil {
   200  			return nil, fmt.Errorf("invalid regexp in ACL: %v", err)
   201  		}
   202  		pats[i] = pat
   203  	}
   204  	return func(name string) bool {
   205  		for _, pat := range pats {
   206  			if pat.MatchString(name) {
   207  				return true
   208  			}
   209  		}
   210  		return false
   211  	}, nil
   212  }
   213  
// registryHandler is the HTTP handler of the registry server. It
// enforces the configured authentication (if any) before handing
// requests to the underlying registry handler.
type registryHandler struct {
	// auth is the auth configuration; nil means no authentication
	// is required.
	auth *AuthConfig
	// registry serves the OCI registry API itself.
	registry http.Handler
	// tokenSrv is the token server clients are directed to; nil
	// means credentials are checked directly (see serveDirectAuth).
	tokenSrv *httptest.Server
}
   219  
// Bearer tokens recognized by registryHandler when a token server is in use.
const (
	// registryAuthToken is the token issued by the token server;
	// requests presenting it are passed on to the registry.
	registryAuthToken = "ok-token-for-registrytest"
	// registryUnauthToken is recognized but always refused with
	// ociregistry.ErrDenied rather than a 401 challenge (presumably
	// so tests can simulate an authenticated but unauthorized client).
	registryUnauthToken = "unauth-token-for-registrytest"
)
   224  
   225  func (h *registryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   226  	if h.auth == nil {
   227  		h.registry.ServeHTTP(w, req)
   228  		return
   229  	}
   230  	if h.tokenSrv == nil {
   231  		h.serveDirectAuth(w, req)
   232  		return
   233  	}
   234  
   235  	// Auth with token server.
   236  	wwwAuth := fmt.Sprintf("Bearer realm=%q,service=registrytest", h.tokenSrv.URL)
   237  	authHeader := req.Header.Get("Authorization")
   238  	if authHeader == "" {
   239  		w.Header().Set("Www-Authenticate", wwwAuth)
   240  		writeError(w, ociregistry.ErrUnauthorized)
   241  		return
   242  	}
   243  	kind, token, ok := strings.Cut(authHeader, " ")
   244  	if !ok || kind != "Bearer" {
   245  		w.Header().Set("Www-Authenticate", wwwAuth)
   246  		writeError(w, ociregistry.ErrUnauthorized)
   247  		return
   248  	}
   249  	switch token {
   250  	case registryAuthToken:
   251  		// User is authorized.
   252  	case registryUnauthToken:
   253  		writeError(w, ociregistry.ErrDenied)
   254  		return
   255  	default:
   256  		// If we don't recognize the token, then presumably
   257  		// the client isn't authenticated so it's 401 not 403.
   258  		w.Header().Set("Www-Authenticate", wwwAuth)
   259  		writeError(w, ociregistry.ErrUnauthorized)
   260  		return
   261  	}
   262  	// If the underlying registry returns a 401 error,
   263  	// we need to add the Www-Authenticate header.
   264  	// As there's no way to get ociserver to do it,
   265  	// we hack it by wrapping the ResponseWriter
   266  	// with an implementation that does.
   267  	h.registry.ServeHTTP(&authHeaderWriter{
   268  		wwwAuth:        wwwAuth,
   269  		ResponseWriter: w,
   270  	}, req)
   271  }
   272  
   273  func (h *registryHandler) serveDirectAuth(w http.ResponseWriter, req *http.Request) {
   274  	auth := req.Header.Get("Authorization")
   275  	if auth == "" {
   276  		if h.auth.BearerToken != "" {
   277  			// Note that this lacks information like the realm,
   278  			// but we don't need it for our test cases yet.
   279  			w.Header().Set("Www-Authenticate", "Bearer service=registry")
   280  		} else {
   281  			w.Header().Set("Www-Authenticate", "Basic service=registry")
   282  		}
   283  		writeError(w, fmt.Errorf("%w: no credentials", ociregistry.ErrUnauthorized))
   284  		return
   285  	}
   286  	if h.auth.BearerToken != "" {
   287  		token, ok := strings.CutPrefix(auth, "Bearer ")
   288  		if !ok || token != h.auth.BearerToken {
   289  			writeError(w, fmt.Errorf("%w: invalid bearer credentials", ociregistry.ErrUnauthorized))
   290  			return
   291  		}
   292  	} else {
   293  		username, password, ok := req.BasicAuth()
   294  		if !ok || username != h.auth.Username || password != h.auth.Password {
   295  			writeError(w, fmt.Errorf("%w: invalid user-password credentials", ociregistry.ErrUnauthorized))
   296  			return
   297  		}
   298  	}
   299  	h.registry.ServeHTTP(w, req)
   300  }
   301  
   302  func tokenHandler(*AuthConfig) http.Handler {
   303  	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   304  		if req.Method != "POST" {
   305  			http.Error(w, "only POST supported", http.StatusMethodNotAllowed)
   306  			return
   307  		}
   308  		req.ParseForm()
   309  		if req.Form.Get("service") != "registrytest" {
   310  			http.Error(w, "invalid service", http.StatusBadRequest)
   311  			return
   312  		}
   313  		if req.Form.Get("grant_type") != "refresh_token" {
   314  			http.Error(w, "invalid grant type", http.StatusBadRequest)
   315  			return
   316  		}
   317  		refreshToken := req.Form.Get("refresh_token")
   318  		if refreshToken != "registrytest-refresh" {
   319  			http.Error(w, fmt.Sprintf("invalid refresh token %q", refreshToken), http.StatusForbidden)
   320  			return
   321  		}
   322  		// See ociauth.wireToken for the full JSON format.
   323  		data, _ := json.Marshal(map[string]string{
   324  			"token": registryAuthToken,
   325  		})
   326  		w.Header().Set("Content-Type", "application/json")
   327  		w.Write(data)
   328  	})
   329  }
   330  
   331  func writeError(w http.ResponseWriter, err error) {
   332  	data, httpStatus := ociregistry.MarshalError(err)
   333  	w.Header().Set("Content-Type", "application/json")
   334  	w.WriteHeader(httpStatus)
   335  	w.Write(data)
   336  }
   337  
   338  func pushContent(ctx context.Context, client *modregistry.Client, mods map[module.Version]*moduleContent) error {
   339  	pushed := make(map[module.Version]bool)
   340  	// Iterate over modules in deterministic order.
   341  	// TODO use maps.Keys when available.
   342  	vs := make([]module.Version, 0, len(mods))
   343  	for v := range mods {
   344  		vs = append(vs, v)
   345  	}
   346  	module.Sort(vs)
   347  	for _, v := range vs {
   348  		err := visitDepthFirst(mods, v, func(v module.Version, m *moduleContent) error {
   349  			if pushed[v] {
   350  				return nil
   351  			}
   352  			var zipContent bytes.Buffer
   353  			if err := m.writeZip(&zipContent); err != nil {
   354  				return err
   355  			}
   356  			if err := client.PutModule(ctx, v, bytes.NewReader(zipContent.Bytes()), int64(zipContent.Len())); err != nil {
   357  				return err
   358  			}
   359  			pushed[v] = true
   360  			return nil
   361  		})
   362  		if err != nil {
   363  			return err
   364  		}
   365  	}
   366  	return nil
   367  }
   368  
   369  func visitDepthFirst(mods map[module.Version]*moduleContent, v module.Version, f func(module.Version, *moduleContent) error) error {
   370  	m := mods[v]
   371  	if m == nil {
   372  		return nil
   373  	}
   374  	for _, depv := range m.modFile.DepVersions() {
   375  		if err := visitDepthFirst(mods, depv, f); err != nil {
   376  			return err
   377  		}
   378  	}
   379  	return f(v, m)
   380  }
   381  
// Registry is a running test registry created by [New] or [NewServer].
// It should be closed with Close after use.
type Registry struct {
	// srv serves the OCI registry API.
	srv *httptest.Server
	// tokenSrv is the token server; nil unless AuthConfig.UseTokenServer
	// was set.
	tokenSrv *httptest.Server
	// host is the host:port of srv, as returned by Host.
	host string
}
   387  
   388  func (r *Registry) Close() {
   389  	r.srv.Close()
   390  	if r.tokenSrv != nil {
   391  		r.tokenSrv.Close()
   392  	}
   393  }
   394  
   395  type authHeaderWriter struct {
   396  	wwwAuth string
   397  	http.ResponseWriter
   398  }
   399  
   400  func (w *authHeaderWriter) WriteHeader(code int) {
   401  	if code == http.StatusUnauthorized && w.Header().Get("Www-Authenticate") == "" {
   402  		w.Header().Set("Www-Authenticate", w.wwwAuth)
   403  	}
   404  	w.ResponseWriter.WriteHeader(code)
   405  }
   406  
   407  // Host returns the hostname for the registry server;
   408  // for example localhost:13455.
   409  //
   410  // The connection can be assumed to be insecure.
   411  func (r *Registry) Host() string {
   412  	return r.host
   413  }
   414  
   415  func getModules(fsys fs.FS) (map[module.Version]*moduleContent, []byte, error) {
   416  	var authConfig []byte
   417  	modules := make(map[string]*moduleContent)
   418  	if err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
   419  		if err != nil {
   420  			// If a filesystem has no entries at all,
   421  			// return zero modules without an error.
   422  			if path == "." && errors.Is(err, fs.ErrNotExist) {
   423  				return fs.SkipAll
   424  			}
   425  			return err
   426  		}
   427  		if d.IsDir() {
   428  			return nil // we're only interested in regular files, not their parent directories
   429  		}
   430  		if path == "auth.json" {
   431  			authConfig, err = fs.ReadFile(fsys, path)
   432  			if err != nil {
   433  				return err
   434  			}
   435  			return nil
   436  		}
   437  		modver, rest, ok := strings.Cut(path, "/")
   438  		if !ok {
   439  			return fmt.Errorf("registry should only contain directories, but found regular file %q", path)
   440  		}
   441  		content := modules[modver]
   442  		if content == nil {
   443  			content = &moduleContent{}
   444  			modules[modver] = content
   445  		}
   446  		data, err := fs.ReadFile(fsys, path)
   447  		if err != nil {
   448  			return err
   449  		}
   450  		content.files = append(content.files, txtar.File{
   451  			Name: rest,
   452  			Data: data,
   453  		})
   454  		return nil
   455  	}); err != nil {
   456  		return nil, nil, err
   457  	}
   458  	for modver, content := range modules {
   459  		if err := content.init(modver); err != nil {
   460  			return nil, nil, fmt.Errorf("cannot initialize module %q: %v", modver, err)
   461  		}
   462  	}
   463  	byVer := map[module.Version]*moduleContent{}
   464  	for _, m := range modules {
   465  		byVer[m.version] = m
   466  	}
   467  	return byVer, authConfig, nil
   468  }
   469  
// moduleContent holds the contents of one module directory from the
// filesystem given to [New] or [Upload].
type moduleContent struct {
	// version is set by init from the module path in module.cue and
	// the version suffix of the directory name.
	version module.Version
	// files holds the module's files, named relative to the module
	// directory.
	files []txtar.File
	// modFile is the parsed cue.mod/module.cue; set by init.
	modFile *modfile.File
}
   475  
   476  func (c *moduleContent) writeZip(w io.Writer) error {
   477  	return modzip.Create(w, c.version, c.files, txtarFileIO{})
   478  }
   479  
   480  func (c *moduleContent) init(versDir string) error {
   481  	found := false
   482  	for _, f := range c.files {
   483  		if f.Name != "cue.mod/module.cue" {
   484  			continue
   485  		}
   486  		modf, err := modfile.Parse(f.Data, f.Name)
   487  		if err != nil {
   488  			return err
   489  		}
   490  		if found {
   491  			return fmt.Errorf("multiple module.cue files")
   492  		}
   493  		mod := strings.ReplaceAll(modf.ModulePath(), "/", "_") + "_"
   494  		vers := strings.TrimPrefix(versDir, mod)
   495  		if len(vers) == len(versDir) {
   496  			return fmt.Errorf("module path %q in module.cue does not match directory %q", modf.QualifiedModule(), versDir)
   497  		}
   498  		v, err := module.NewVersion(modf.QualifiedModule(), vers)
   499  		if err != nil {
   500  			return fmt.Errorf("cannot make module version: %v", err)
   501  		}
   502  		c.version = v
   503  		c.modFile = modf
   504  		found = true
   505  	}
   506  	if !found {
   507  		return fmt.Errorf("no module.cue file found in %q", versDir)
   508  	}
   509  	return nil
   510  }