cuelang.org/go@v0.10.1/mod/modconfig/modconfig_test.go (about)

     1  package modconfig
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/fs"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"os"
    12  	"path"
    13  	"path/filepath"
    14  	"slices"
    15  	"strings"
    16  	"sync/atomic"
    17  	"testing"
    18  	"time"
    19  
    20  	"cuelabs.dev/go/oci/ociregistry/ocimem"
    21  	"cuelabs.dev/go/oci/ociregistry/ociserver"
    22  	"github.com/go-quicktest/qt"
    23  	"golang.org/x/sync/errgroup"
    24  	"golang.org/x/tools/txtar"
    25  
    26  	"cuelang.org/go/internal/cueconfig"
    27  	"cuelang.org/go/internal/cueversion"
    28  	"cuelang.org/go/internal/registrytest"
    29  	"cuelang.org/go/mod/modcache"
    30  	"cuelang.org/go/mod/module"
    31  )
    32  
    33  // TODO: the test below acts as a smoke test for the functionality here,
    34  // but more of the behavior is tested in the cmd/cue script tests.
    35  // We should do more of it here too.
    36  
    37  func TestNewRegistry(t *testing.T) {
    38  	modules := txtar.Parse([]byte(`
    39  -- r1/foo.example_v0.0.1/cue.mod/module.cue --
    40  module: "foo.example@v0"
    41  language: version: "v0.8.0"
    42  deps: "bar.example@v0": v: "v0.0.1"
    43  -- r1/foo.example_v0.0.1/bar/bar.cue --
    44  package bar
    45  -- r1/bar.example_v0.0.1/cue.mod/module.cue --
    46  module: "bar.example@v0"
    47  language: version: "v0.8.0"
    48  -- r1/bar.example_v0.0.1/y/y.cue --
    49  package y
    50  
    51  -- r2/auth.json --
    52  {
    53  	"username": "bob",
    54  	"password": "somePassword"
    55  }
    56  -- r2/bar.example_v0.0.1/cue.mod/module.cue --
    57  module: "bar.example@v0"
    58  language: version: "v0.8.0"
    59  -- r2/bar.example_v0.0.1/x/x.cue --
    60  package x
    61  `))
    62  	fsys, err := txtar.FS(modules)
    63  	qt.Assert(t, qt.IsNil(err))
    64  	r1fs, err := fs.Sub(fsys, "r1")
    65  	qt.Assert(t, qt.IsNil(err))
    66  	r1, err := registrytest.New(r1fs, "")
    67  	qt.Assert(t, qt.IsNil(err))
    68  	r2fs, err := fs.Sub(fsys, "r2")
    69  	qt.Assert(t, qt.IsNil(err))
    70  	r2, err := registrytest.New(r2fs, "")
    71  	qt.Assert(t, qt.IsNil(err))
    72  
    73  	dir := t.TempDir()
    74  	t.Setenv("DOCKER_CONFIG", dir)
    75  	dockerCfg, err := json.Marshal(dockerConfig{
    76  		Auths: map[string]authConfig{
    77  			r2.Host(): {
    78  				Username: "bob",
    79  				Password: "somePassword",
    80  			},
    81  		},
    82  	})
    83  	qt.Assert(t, qt.IsNil(err))
    84  	err = os.WriteFile(filepath.Join(dir, "config.json"), dockerCfg, 0o666)
    85  	qt.Assert(t, qt.IsNil(err))
    86  
    87  	t.Setenv("CUE_REGISTRY",
    88  		fmt.Sprintf("foo.example=%s+insecure,%s+insecure",
    89  			r1.Host(),
    90  			r2.Host(),
    91  		))
    92  	cacheDir := filepath.Join(dir, "cache")
    93  	t.Setenv("CUE_CACHE_DIR", cacheDir)
    94  	t.Cleanup(func() {
    95  		modcache.RemoveAll(cacheDir)
    96  	})
    97  
    98  	var transportInvoked atomic.Bool
    99  	r, err := NewRegistry(&Config{
   100  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
   101  			transportInvoked.Store(true)
   102  			return http.DefaultTransport.RoundTrip(req)
   103  		}),
   104  	})
   105  	qt.Assert(t, qt.IsNil(err))
   106  	ctx := context.Background()
   107  	gotRequirements, err := r.Requirements(ctx, module.MustNewVersion("foo.example@v0", "v0.0.1"))
   108  	qt.Assert(t, qt.IsNil(err))
   109  	qt.Assert(t, qt.DeepEquals(gotRequirements, []module.Version{
   110  		module.MustNewVersion("bar.example@v0", "v0.0.1"),
   111  	}))
   112  
   113  	loc, err := r.Fetch(ctx, module.MustNewVersion("bar.example@v0", "v0.0.1"))
   114  	qt.Assert(t, qt.IsNil(err))
   115  	data, err := fs.ReadFile(loc.FS, path.Join(loc.Dir, "x/x.cue"))
   116  	qt.Assert(t, qt.IsNil(err))
   117  	qt.Assert(t, qt.Equals(string(data), "package x\n"))
   118  
   119  	// Check that we can make a Resolver with the same configuration.
   120  	resolver, err := NewResolver(nil)
   121  	qt.Assert(t, qt.IsNil(err))
   122  	gotAllHosts := resolver.AllHosts()
   123  	wantAllHosts := []Host{{Name: r1.Host(), Insecure: true}, {Name: r2.Host(), Insecure: true}}
   124  
   125  	byHostname := func(a, b Host) int { return strings.Compare(a.Name, b.Name) }
   126  	slices.SortFunc(gotAllHosts, byHostname)
   127  	slices.SortFunc(wantAllHosts, byHostname)
   128  
   129  	qt.Assert(t, qt.DeepEquals(gotAllHosts, wantAllHosts))
   130  
   131  	// Check that the underlying custom transport was used.
   132  	qt.Assert(t, qt.IsTrue(transportInvoked.Load()))
   133  }
   134  
   135  func TestDefaultTransportSetsUserAgent(t *testing.T) {
   136  	// This test also checks that providing a nil Config.Transport
   137  	// does the right thing.
   138  
   139  	regFS, err := txtar.FS(txtar.Parse([]byte(`
   140  -- bar.example_v0.0.1/cue.mod/module.cue --
   141  module: "bar.example@v0"
   142  language: version: "v0.8.0"
   143  -- bar.example_v0.0.1/x/x.cue --
   144  package x
   145  `)))
   146  	qt.Assert(t, qt.IsNil(err))
   147  	ctx := context.Background()
   148  	rmem := ocimem.NewWithConfig(&ocimem.Config{ImmutableTags: true})
   149  	err = registrytest.Upload(ctx, rmem, regFS)
   150  	qt.Assert(t, qt.IsNil(err))
   151  	rh := ociserver.New(rmem, nil)
   152  	agent := cueversion.UserAgent("cuelang.org/go")
   153  	checked := false
   154  	checkUserAgentHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   155  		qt.Check(t, qt.Equals(req.UserAgent(), agent))
   156  		checked = true
   157  		rh.ServeHTTP(w, req)
   158  	})
   159  	srv := httptest.NewServer(checkUserAgentHandler)
   160  	u, err := url.Parse(srv.URL)
   161  	qt.Assert(t, qt.IsNil(err))
   162  
   163  	dir := t.TempDir()
   164  	t.Setenv("DOCKER_CONFIG", dir)
   165  	t.Setenv("CUE_REGISTRY", u.Host+"+insecure")
   166  	cacheDir := filepath.Join(dir, "cache")
   167  	t.Setenv("CUE_CACHE_DIR", cacheDir)
   168  	t.Cleanup(func() {
   169  		modcache.RemoveAll(cacheDir)
   170  	})
   171  
   172  	r, err := NewRegistry(nil)
   173  	qt.Assert(t, qt.IsNil(err))
   174  	gotRequirements, err := r.Requirements(ctx, module.MustNewVersion("bar.example@v0", "v0.0.1"))
   175  	qt.Assert(t, qt.IsNil(err))
   176  	qt.Assert(t, qt.HasLen(gotRequirements, 0))
   177  
   178  	qt.Assert(t, qt.IsTrue(checked))
   179  }
   180  
   181  // TestConcurrentTokenRefresh verifies that concurrent OAuth token refreshes,
   182  // including logins.json updates, are properly synchronized.
   183  func TestConcurrentTokenRefresh(t *testing.T) {
   184  	// Start N registry instances, each containing one CUE module and running
   185  	// in its own HTTP server instance. Each instance is protected with its
   186  	// own OAuth token, which is initially expired, requiring a refresh token
   187  	// request upon first invocation.
   188  	var registries [20]struct {
   189  		mod  string
   190  		host string
   191  	}
   192  	var counter int32 = 0
   193  	for i := range registries {
   194  		reg := &registries[i]
   195  		reg.mod = fmt.Sprintf("foo.mod%02d", i)
   196  		fsys, err := txtar.FS(txtar.Parse([]byte(fmt.Sprintf(`
   197  -- %s_v0.0.1/cue.mod/module.cue --
   198  module: "%s@v0"
   199  language: version: "v0.8.0"
   200  -- %s_v0.0.1/bar/bar.cue --
   201  package bar
   202  `, reg.mod, reg.mod, reg.mod))))
   203  		qt.Assert(t, qt.IsNil(err))
   204  		mux := http.NewServeMux()
   205  		r := ocimem.New()
   206  		err = registrytest.Upload(context.Background(), r, fsys)
   207  		qt.Assert(t, qt.IsNil(err))
   208  		rh := ociserver.New(r, nil)
   209  		mux.HandleFunc("/v2/", func(w http.ResponseWriter, r *http.Request) {
   210  			auth := r.Header.Get("Authorization")
   211  			if !strings.HasPrefix(auth, fmt.Sprintf("Bearer access_%d_", i)) {
   212  				w.WriteHeader(401)
   213  				w.Write([]byte(fmt.Sprintf("server %d: unexpected auth header: %s", i, auth)))
   214  				return
   215  			}
   216  			rh.ServeHTTP(w, r)
   217  		})
   218  		mux.HandleFunc("/login/oauth/token", func(w http.ResponseWriter, r *http.Request) {
   219  			ctr := atomic.AddInt32(&counter, 1)
   220  			writeJSON(w, 200, wireToken{
   221  				AccessToken:  fmt.Sprintf("access_%d_%d", i, ctr),
   222  				TokenType:    "Bearer",
   223  				RefreshToken: fmt.Sprintf("refresh_%d", ctr),
   224  				ExpiresIn:    300,
   225  			})
   226  		})
   227  		srv := httptest.NewServer(mux)
   228  		t.Cleanup(srv.Close)
   229  		u, err := url.Parse(srv.URL)
   230  		qt.Assert(t, qt.IsNil(err))
   231  		reg.host = u.Host
   232  	}
   233  
   234  	expiry := time.Now()
   235  	logins := &cueconfig.Logins{
   236  		Registries: map[string]cueconfig.RegistryLogin{},
   237  	}
   238  	registryConf := ""
   239  	for i, reg := range registries {
   240  		logins.Registries[reg.host] = cueconfig.RegistryLogin{
   241  			AccessToken:  fmt.Sprintf("access_%d_x", i),
   242  			TokenType:    "Bearer",
   243  			RefreshToken: "refresh_x",
   244  			Expiry:       &expiry,
   245  		}
   246  		if registryConf != "" {
   247  			registryConf += ","
   248  		}
   249  		registryConf += fmt.Sprintf("%s=%s+insecure", reg.mod, reg.host)
   250  	}
   251  
   252  	dir := t.TempDir()
   253  	configDir := filepath.Join(dir, "config")
   254  	t.Setenv("CUE_CONFIG_DIR", configDir)
   255  	err := os.MkdirAll(configDir, 0o777)
   256  	qt.Assert(t, qt.IsNil(err))
   257  
   258  	// Check logins.json validation.
   259  	logins.Registries["blank"] = cueconfig.RegistryLogin{TokenType: "Bearer"}
   260  	err = cueconfig.WriteLogins(filepath.Join(configDir, "logins.json"), logins)
   261  	delete(logins.Registries, "blank")
   262  	qt.Assert(t, qt.IsNil(err))
   263  	_, err = cueconfig.ReadLogins(filepath.Join(configDir, "logins.json"))
   264  	qt.Assert(t, qt.ErrorMatches(err, "invalid .*logins.json: missing access_token for registry blank"))
   265  
   266  	// Check write-read round-trip.
   267  	err = cueconfig.WriteLogins(filepath.Join(configDir, "logins.json"), logins)
   268  	qt.Assert(t, qt.IsNil(err))
   269  	logins2, err := cueconfig.ReadLogins(filepath.Join(configDir, "logins.json"))
   270  	qt.Assert(t, qt.IsNil(err))
   271  	qt.Assert(t, qt.DeepEquals(logins2, logins))
   272  
   273  	t.Setenv("CUE_REGISTRY", registryConf)
   274  	cacheDir := filepath.Join(dir, "cache")
   275  	t.Setenv("CUE_CACHE_DIR", cacheDir)
   276  	t.Cleanup(func() {
   277  		modcache.RemoveAll(cacheDir)
   278  	})
   279  
   280  	r, err := NewRegistry(nil)
   281  	qt.Assert(t, qt.IsNil(err))
   282  
   283  	g := new(errgroup.Group)
   284  	for i := range registries {
   285  		mod := registries[i].mod
   286  		g.Go(func() error {
   287  			ctx := context.Background()
   288  			loc, err := r.Fetch(ctx, module.MustNewVersion(mod+"@v0", "v0.0.1"))
   289  			if err != nil {
   290  				return err
   291  			}
   292  			data, err := fs.ReadFile(loc.FS, path.Join(loc.Dir, "bar/bar.cue"))
   293  			if err != nil {
   294  				return err
   295  			}
   296  			if string(data) != "package bar\n" {
   297  				return fmt.Errorf("unexpected data: %q", string(data))
   298  			}
   299  			return nil
   300  		})
   301  	}
   302  	err = g.Wait()
   303  	qt.Assert(t, qt.IsNil(err))
   304  	qt.Assert(t, qt.Equals(int(counter), len(registries)))
   305  }
   306  
// dockerConfig describes the minimal subset of the docker
// configuration file necessary to check that authentication
// is correctly hooked up.
//
// Auths is keyed by registry host (e.g. the value of a test
// registry's Host method) and holds the credentials for that host.
type dockerConfig struct {
	Auths map[string]authConfig `json:"auths"`
}
   313  
// authConfig contains authorization information for connecting to a Registry.
// Only the username and password fields of a docker auth entry are
// modelled; empty fields are omitted from the encoded JSON.
type authConfig struct {
	Username string `json:"username,omitempty"`
	Password string `json:"password,omitempty"`
}
   319  
   320  type roundTripperFunc func(*http.Request) (*http.Response, error)
   321  
   322  func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   323  	return f(req)
   324  }
   325  
   326  func writeJSON(w http.ResponseWriter, statusCode int, v any) {
   327  	b, err := json.Marshal(v)
   328  	if err != nil {
   329  		// should never happen
   330  		panic(err)
   331  	}
   332  	w.Header().Set("Content-Type", "application/json")
   333  	w.WriteHeader(statusCode)
   334  	w.Write(b)
   335  }
   336  
// wireToken describes the JSON encoding for an OAuth 2.0 token
// as specified in RFC 6749 (section 5.1, successful access token response).
// ExpiresIn is the token lifetime in seconds, as defined by that section.
// Zero-valued fields are omitted from the encoded JSON.
type wireToken struct {
	AccessToken  string `json:"access_token,omitempty"`
	TokenType    string `json:"token_type,omitempty"`
	RefreshToken string `json:"refresh_token,omitempty"`
	ExpiresIn    int    `json:"expires_in,omitempty"`
}