cuelang.org/go@v0.10.1/mod/modconfig/modconfig.go (about)

     1  // Package modconfig provides access to the standard CUE
     2  // module configuration, including registry access and authorization.
     3  package modconfig
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"fmt"
     9  	"io/fs"
    10  	"net/http"
    11  	"os"
    12  	"strings"
    13  	"sync"
    14  
    15  	"cuelabs.dev/go/oci/ociregistry"
    16  	"cuelabs.dev/go/oci/ociregistry/ociauth"
    17  	"cuelabs.dev/go/oci/ociregistry/ociclient"
    18  	"golang.org/x/oauth2"
    19  
    20  	"cuelang.org/go/internal/cueconfig"
    21  	"cuelang.org/go/internal/cueversion"
    22  	"cuelang.org/go/internal/mod/modload"
    23  	"cuelang.org/go/internal/mod/modresolve"
    24  	"cuelang.org/go/mod/modcache"
    25  	"cuelang.org/go/mod/modregistry"
    26  	"cuelang.org/go/mod/module"
    27  )
    28  
// Registry is used to access CUE modules from external sources.
//
// It has the same method set as the internal modload.Registry interface
// (this is checked statically in this file), so values of either type
// can be used interchangeably.
type Registry interface {
	// Requirements returns a list of the modules required by the given module
	// version.
	Requirements(ctx context.Context, m module.Version) ([]module.Version, error)

	// Fetch returns the location of the contents for the given module
	// version, downloading it if necessary.
	Fetch(ctx context.Context, m module.Version) (module.SourceLoc, error)

	// ModuleVersions returns all the versions for the module with the
	// given path, which should contain a major version.
	ModuleVersions(ctx context.Context, mpath string) ([]string, error)
}
    43  
// We don't want to make modload part of the cue/load API,
// so we define the Registry type above independently. We still
// want the two types to be interchangeable, so assert statically,
// in both directions, that each is assignable to the other.
var (
	_ Registry         = modload.Registry(nil)
	_ modload.Registry = Registry(nil)
)
    51  
// DefaultRegistry is the default registry host.
// NewResolver passes it to the registry configuration parsers
// alongside the value of $CUE_REGISTRY.
const DefaultRegistry = "registry.cue.works"
    54  
// Resolver implements [modregistry.Resolver] in terms of the
// CUE registry configuration file and auth configuration.
//
// A Resolver is created with [NewResolver].
type Resolver struct {
	// resolver maps a module path and version to a registry location;
	// newRegistry creates the client used to talk to a registry host.
	resolver    modresolve.LocationResolver
	newRegistry func(host string, insecure bool) (ociregistry.Interface, error)

	// mu guards registries, which holds the clients created so far,
	// keyed by registry host, so that each host gets a single client.
	mu         sync.Mutex
	registries map[string]ociregistry.Interface
}
    64  
// Config provides the starting point for the configuration.
// The zero value is valid and selects the defaults described
// on each field.
type Config struct {
	// TODO allow for a custom resolver to be passed in.

	// Transport is used to make the underlying HTTP requests.
	// If it's nil, [http.DefaultTransport] will be used.
	Transport http.RoundTripper

	// Env provides environment variable values. If this is nil,
	// the current process's environment will be used.
	//
	// Each entry has the form "KEY=value". When a key appears more
	// than once, the last entry wins. A non-nil empty slice means
	// an empty environment.
	Env []string

	// ClientType is used as part of the User-Agent header
	// that's added in each outgoing HTTP request.
	// If it's empty, it defaults to "cuelang.org/go".
	ClientType string
}
    82  
    83  // NewResolver returns an implementation of [modregistry.Resolver]
    84  // that uses cfg to guide registry resolution. If cfg is nil, it's
    85  // equivalent to passing pointer to a zero Config struct.
    86  //
    87  // It consults the same environment variables used by the
    88  // cue command.
    89  //
    90  // The contents of the configuration will not be mutated.
    91  func NewResolver(cfg *Config) (*Resolver, error) {
    92  	cfg = newRef(cfg)
    93  	cfg.Transport = cueversion.NewTransport(cfg.ClientType, cfg.Transport)
    94  	getenv := getenvFunc(cfg.Env)
    95  	var configData []byte
    96  	var configPath string
    97  	cueRegistry := getenv("CUE_REGISTRY")
    98  	kind, rest, _ := strings.Cut(cueRegistry, ":")
    99  	switch kind {
   100  	case "file":
   101  		data, err := os.ReadFile(rest)
   102  		if err != nil {
   103  			return nil, err
   104  		}
   105  		configData, configPath = data, rest
   106  	case "inline":
   107  		configData, configPath = []byte(rest), "$CUE_REGISTRY"
   108  	case "simple":
   109  		cueRegistry = rest
   110  	}
   111  	var resolver modresolve.LocationResolver
   112  	var err error
   113  	if configPath != "" {
   114  		resolver, err = modresolve.ParseConfig(configData, configPath, DefaultRegistry)
   115  	} else {
   116  		resolver, err = modresolve.ParseCUERegistry(cueRegistry, DefaultRegistry)
   117  	}
   118  	if err != nil {
   119  		return nil, fmt.Errorf("bad value for $CUE_REGISTRY: %v", err)
   120  	}
   121  	return &Resolver{
   122  		resolver: resolver,
   123  		newRegistry: func(host string, insecure bool) (ociregistry.Interface, error) {
   124  			return ociclient.New(host, &ociclient.Options{
   125  				Insecure: insecure,
   126  				Transport: &cueLoginsTransport{
   127  					getenv: getenv,
   128  					cfg:    cfg,
   129  				},
   130  			})
   131  		},
   132  		registries: make(map[string]ociregistry.Interface),
   133  	}, nil
   134  }
   135  
// Host represents a registry host name and whether
// it should be accessed via a secure connection or not.
// It is an alias for the corresponding modresolve type.
type Host = modresolve.Host

// AllHosts returns all the registry hosts that the resolver might resolve to,
// ordered lexically by hostname.
// The result comes directly from the underlying location resolver.
func (r *Resolver) AllHosts() []Host {
	return r.resolver.AllHosts()
}
   145  
// HostLocation represents a registry host and a location within it.
// It is an alias for the corresponding modresolve type.
type HostLocation = modresolve.Location

// ResolveToLocation returns the host location for the given module path and version
// without creating a Registry instance for it.
// The boolean result reports whether the module could be resolved.
func (r *Resolver) ResolveToLocation(mpath string, version string) (HostLocation, bool) {
	return r.resolver.ResolveToLocation(mpath, version)
}
   154  
   155  // ResolveToRegistry implements [modregistry.Resolver.ResolveToRegistry].
   156  func (r *Resolver) ResolveToRegistry(mpath string, version string) (modregistry.RegistryLocation, error) {
   157  	loc, ok := r.resolver.ResolveToLocation(mpath, version)
   158  	if !ok {
   159  		// This can happen when mpath is invalid, which should not
   160  		// happen in practice, as the only caller is modregistry which
   161  		// vets module paths before calling Resolve.
   162  		//
   163  		// It can also happen when the user has explicitly configured a "none"
   164  		// registry to avoid falling back to a default registry.
   165  		return modregistry.RegistryLocation{}, fmt.Errorf("cannot resolve %s (version %q) to registry: %w", mpath, version, modregistry.ErrRegistryNotFound)
   166  	}
   167  	r.mu.Lock()
   168  	defer r.mu.Unlock()
   169  	reg := r.registries[loc.Host]
   170  	if reg == nil {
   171  		reg1, err := r.newRegistry(loc.Host, loc.Insecure)
   172  		if err != nil {
   173  			return modregistry.RegistryLocation{}, fmt.Errorf("cannot make client: %v", err)
   174  		}
   175  		r.registries[loc.Host] = reg1
   176  		reg = reg1
   177  	}
   178  	return modregistry.RegistryLocation{
   179  		Registry:   reg,
   180  		Repository: loc.Repository,
   181  		Tag:        loc.Tag,
   182  	}, nil
   183  }
   184  
// cueLoginsTransport implements [http.RoundTripper] by using
// tokens from the CUE login information when available, falling
// back to using the standard [ociauth] transport implementation.
//
// All setup is deferred to the first call to RoundTrip, so a bad
// configuration file is only reported once a request is made.
type cueLoginsTransport struct {
	// cfg supplies the base HTTP transport and the environment used
	// to load the OCI auth configuration; getenv is used to locate
	// the logins file.
	cfg    *Config
	getenv func(string) string

	// initOnce guards the one-time initialization of initErr, logins,
	// transport and cachedTransports.
	initOnce sync.Once
	initErr  error
	// loginsMu guards the logins pointer below, which can be replaced
	// after initialization when a login is updated.
	// Note that an instance of cueconfig.Logins is read-only and
	// does not have to be guarded.
	loginsMu sync.Mutex
	logins   *cueconfig.Logins
	// transport holds the underlying transport. This wraps
	// t.cfg.Transport.
	transport http.RoundTripper

	// mu guards the fields below.
	mu sync.Mutex

	// cachedTransports holds a transport per host.
	// This is needed because the oauth2 API requires a
	// different client for each host. Each of these transports
	// wraps the transport above.
	// The map is only allocated when a logins file was loaded.
	cachedTransports map[string]http.RoundTripper
}
   213  
   214  func (t *cueLoginsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   215  	// Return an error lazily on the first request because if the
   216  	// user isn't doing anything that requires a registry, we
   217  	// shouldn't complain about reading a bad configuration file.
   218  	if err := t.init(); err != nil {
   219  		return nil, err
   220  	}
   221  
   222  	t.loginsMu.Lock()
   223  	logins := t.logins
   224  	t.loginsMu.Unlock()
   225  
   226  	if logins == nil {
   227  		return t.transport.RoundTrip(req)
   228  	}
   229  	// TODO: note that a CUE registry may include a path prefix,
   230  	// so using solely the host will not work with such a path.
   231  	// Can we do better here, perhaps keeping the path prefix up to "/v2/"?
   232  	host := req.URL.Host
   233  	login, ok := logins.Registries[host]
   234  	if !ok {
   235  		return t.transport.RoundTrip(req)
   236  	}
   237  
   238  	t.mu.Lock()
   239  	transport := t.cachedTransports[host]
   240  	if transport == nil {
   241  		tok := cueconfig.TokenFromLogin(login)
   242  		oauthCfg := cueconfig.RegistryOAuthConfig(Host{
   243  			Name:     host,
   244  			Insecure: req.URL.Scheme == "http",
   245  		})
   246  
   247  		// Make the oauth client use the transport that was set up
   248  		// in init.
   249  		ctx := context.WithValue(req.Context(), oauth2.HTTPClient, &http.Client{
   250  			Transport: t.transport,
   251  		})
   252  		transport = oauth2.NewClient(ctx,
   253  			&cachingTokenSource{
   254  				updateFunc: func(tok *oauth2.Token) error {
   255  					return t.updateLogin(host, tok)
   256  				},
   257  				base: oauthCfg.TokenSource(ctx, tok),
   258  				t:    tok,
   259  			},
   260  		).Transport
   261  		t.cachedTransports[host] = transport
   262  	}
   263  	// Unlock immediately so we don't hold the lock for the entire
   264  	// request, which would preclude any concurrency when
   265  	// making HTTP requests.
   266  	t.mu.Unlock()
   267  	return transport.RoundTrip(req)
   268  }
   269  
   270  func (t *cueLoginsTransport) updateLogin(host string, new *oauth2.Token) error {
   271  	// Reload the logins file in case another process changed it in the meantime.
   272  	loginsPath, err := cueconfig.LoginConfigPath(t.getenv)
   273  	if err != nil {
   274  		// TODO: this should never fail. Log a warning.
   275  		return nil
   276  	}
   277  
   278  	// Lock the logins for the entire duration of the update to avoid races
   279  	t.loginsMu.Lock()
   280  	defer t.loginsMu.Unlock()
   281  
   282  	logins, err := cueconfig.UpdateRegistryLogin(loginsPath, host, new)
   283  	if err != nil {
   284  		return err
   285  	}
   286  
   287  	t.logins = logins
   288  
   289  	return nil
   290  }
   291  
// init runs _init at most once and returns its result.
// Every later call returns the same error (or nil) without
// repeating the work.
func (t *cueLoginsTransport) init() error {
	t.initOnce.Do(func() {
		t.initErr = t._init()
	})
	return t.initErr
}
   298  
   299  func (t *cueLoginsTransport) _init() error {
   300  	// If a registry was authenticated via `cue login`, use that.
   301  	// If not, fall back to authentication via Docker's config.json.
   302  	// Note that the order below is backwards, since we layer interfaces.
   303  
   304  	config, err := ociauth.LoadWithEnv(nil, t.cfg.Env)
   305  	if err != nil {
   306  		return fmt.Errorf("cannot load OCI auth configuration: %v", err)
   307  	}
   308  	t.transport = ociauth.NewStdTransport(ociauth.StdTransportParams{
   309  		Config:    config,
   310  		Transport: t.cfg.Transport,
   311  	})
   312  
   313  	// If we can't locate a logins.json file at all, then we'll continue.
   314  	// We only refuse to continue if we find an invalid logins.json file.
   315  	loginsPath, err := cueconfig.LoginConfigPath(t.getenv)
   316  	if err != nil {
   317  		// TODO: this should never fail. Log a warning.
   318  		return nil
   319  	}
   320  	logins, err := cueconfig.ReadLogins(loginsPath)
   321  	if errors.Is(err, fs.ErrNotExist) {
   322  		return nil
   323  	}
   324  	if err != nil {
   325  		return fmt.Errorf("cannot load CUE registry logins: %v", err)
   326  	}
   327  	t.logins = logins
   328  	t.cachedTransports = make(map[string]http.RoundTripper)
   329  	return nil
   330  }
   331  
   332  // NewRegistry returns an implementation of the Registry
   333  // interface suitable for passing to [load.Instances].
   334  // It uses the standard CUE cache directory.
   335  func NewRegistry(cfg *Config) (Registry, error) {
   336  	cfg = newRef(cfg)
   337  	resolver, err := NewResolver(cfg)
   338  	if err != nil {
   339  		return nil, err
   340  	}
   341  	cacheDir, err := cueconfig.CacheDir(getenvFunc(cfg.Env))
   342  	if err != nil {
   343  		return nil, err
   344  	}
   345  	return modcache.New(modregistry.NewClientWithResolver(resolver), cacheDir)
   346  }
   347  
   348  func getenvFunc(env []string) func(string) string {
   349  	if env == nil {
   350  		return os.Getenv
   351  	}
   352  	return func(key string) string {
   353  		for i := len(env) - 1; i >= 0; i-- {
   354  			if e := env[i]; len(e) >= len(key)+1 && e[len(key)] == '=' && e[:len(key)] == key {
   355  				return e[len(key)+1:]
   356  			}
   357  		}
   358  		return ""
   359  	}
   360  }
   361  
   362  func newRef[T any](x *T) *T {
   363  	var x1 T
   364  	if x != nil {
   365  		x1 = *x
   366  	}
   367  	return &x1
   368  }
   369  
// cachingTokenSource works similar to oauth2.ReuseTokenSource, except that it
// also exposes a hook to get a hold of the refreshed token, so that it can be
// stored in persistent storage.
type cachingTokenSource struct {
	// updateFunc is called with each token newly obtained from base.
	updateFunc func(tok *oauth2.Token) error
	base       oauth2.TokenSource // called when t is expired

	mu sync.Mutex // guards t
	t  *oauth2.Token
}
   380  
   381  func (s *cachingTokenSource) Token() (*oauth2.Token, error) {
   382  	s.mu.Lock()
   383  	t := s.t
   384  
   385  	if t.Valid() {
   386  		s.mu.Unlock()
   387  		return t, nil
   388  	}
   389  
   390  	t, err := s.base.Token()
   391  	if err != nil {
   392  		s.mu.Unlock()
   393  		return nil, err
   394  	}
   395  
   396  	s.t = t
   397  	s.mu.Unlock()
   398  
   399  	err = s.updateFunc(t)
   400  	if err != nil {
   401  		return nil, err
   402  	}
   403  
   404  	return t, nil
   405  }