cuelang.org/go@v0.13.0/mod/modconfig/modconfig_test.go (about)

     1  package modconfig
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/fs"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"os"
    12  	"path"
    13  	"path/filepath"
    14  	"slices"
    15  	"strings"
    16  	"sync/atomic"
    17  	"testing"
    18  	"time"
    19  
    20  	"cuelabs.dev/go/oci/ociregistry/ocimem"
    21  	"cuelabs.dev/go/oci/ociregistry/ociserver"
    22  	"github.com/go-quicktest/qt"
    23  	"golang.org/x/oauth2"
    24  	"golang.org/x/sync/errgroup"
    25  	"golang.org/x/tools/txtar"
    26  
    27  	"cuelang.org/go/internal/cueconfig"
    28  	"cuelang.org/go/internal/cueversion"
    29  	"cuelang.org/go/mod/modcache"
    30  	"cuelang.org/go/mod/modregistrytest"
    31  	"cuelang.org/go/mod/module"
    32  )
    33  
    34  // TODO: the test below acts as a smoke test for the functionality here,
    35  // but more of the behavior is tested in the cmd/cue script tests.
    36  // We should do more of it here too.
    37  
    38  func TestNewRegistry(t *testing.T) {
    39  	modules := txtar.Parse([]byte(`
    40  -- r1/foo.example_v0.0.1/cue.mod/module.cue --
    41  module: "foo.example@v0"
    42  language: version: "v0.8.0"
    43  deps: "bar.example@v0": v: "v0.0.1"
    44  -- r1/foo.example_v0.0.1/bar/bar.cue --
    45  package bar
    46  -- r1/bar.example_v0.0.1/cue.mod/module.cue --
    47  module: "bar.example@v0"
    48  language: version: "v0.8.0"
    49  -- r1/bar.example_v0.0.1/y/y.cue --
    50  package y
    51  
    52  -- r2/auth.json --
    53  {
    54  	"username": "bob",
    55  	"password": "somePassword"
    56  }
    57  -- r2/bar.example_v0.0.1/cue.mod/module.cue --
    58  module: "bar.example@v0"
    59  language: version: "v0.8.0"
    60  -- r2/bar.example_v0.0.1/x/x.cue --
    61  package x
    62  `))
    63  	fsys, err := txtar.FS(modules)
    64  	qt.Assert(t, qt.IsNil(err))
    65  	r1fs, err := fs.Sub(fsys, "r1")
    66  	qt.Assert(t, qt.IsNil(err))
    67  	r1, err := modregistrytest.New(r1fs, "")
    68  	qt.Assert(t, qt.IsNil(err))
    69  	r2fs, err := fs.Sub(fsys, "r2")
    70  	qt.Assert(t, qt.IsNil(err))
    71  	r2, err := modregistrytest.New(r2fs, "")
    72  	qt.Assert(t, qt.IsNil(err))
    73  
    74  	dir := t.TempDir()
    75  	t.Setenv("DOCKER_CONFIG", dir)
    76  	dockerCfg, err := json.Marshal(dockerConfig{
    77  		Auths: map[string]authConfig{
    78  			r2.Host(): {
    79  				Username: "bob",
    80  				Password: "somePassword",
    81  			},
    82  		},
    83  	})
    84  	qt.Assert(t, qt.IsNil(err))
    85  	err = os.WriteFile(filepath.Join(dir, "config.json"), dockerCfg, 0o666)
    86  	qt.Assert(t, qt.IsNil(err))
    87  
    88  	t.Setenv("CUE_REGISTRY",
    89  		fmt.Sprintf("foo.example=%s+insecure,%s+insecure",
    90  			r1.Host(),
    91  			r2.Host(),
    92  		))
    93  	cacheDir := filepath.Join(dir, "cache")
    94  	t.Setenv("CUE_CACHE_DIR", cacheDir)
    95  	t.Cleanup(func() {
    96  		modcache.RemoveAll(cacheDir)
    97  	})
    98  
    99  	var transportInvoked atomic.Bool
   100  	r, err := NewRegistry(&Config{
   101  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
   102  			transportInvoked.Store(true)
   103  			return http.DefaultTransport.RoundTrip(req)
   104  		}),
   105  	})
   106  	qt.Assert(t, qt.IsNil(err))
   107  	ctx := context.Background()
   108  	gotRequirements, err := r.Requirements(ctx, module.MustNewVersion("foo.example@v0", "v0.0.1"))
   109  	qt.Assert(t, qt.IsNil(err))
   110  	qt.Assert(t, qt.DeepEquals(gotRequirements, []module.Version{
   111  		module.MustNewVersion("bar.example@v0", "v0.0.1"),
   112  	}))
   113  
   114  	loc, err := r.Fetch(ctx, module.MustNewVersion("bar.example@v0", "v0.0.1"))
   115  	qt.Assert(t, qt.IsNil(err))
   116  	data, err := fs.ReadFile(loc.FS, path.Join(loc.Dir, "x/x.cue"))
   117  	qt.Assert(t, qt.IsNil(err))
   118  	qt.Assert(t, qt.Equals(string(data), "package x\n"))
   119  
   120  	// Check that we can make a Resolver with the same configuration.
   121  	resolver, err := NewResolver(nil)
   122  	qt.Assert(t, qt.IsNil(err))
   123  	gotAllHosts := resolver.AllHosts()
   124  	wantAllHosts := []Host{{Name: r1.Host(), Insecure: true}, {Name: r2.Host(), Insecure: true}}
   125  
   126  	byHostname := func(a, b Host) int { return strings.Compare(a.Name, b.Name) }
   127  	slices.SortFunc(gotAllHosts, byHostname)
   128  	slices.SortFunc(wantAllHosts, byHostname)
   129  
   130  	qt.Assert(t, qt.DeepEquals(gotAllHosts, wantAllHosts))
   131  
   132  	// Check that the underlying custom transport was used.
   133  	qt.Assert(t, qt.IsTrue(transportInvoked.Load()))
   134  }
   135  
   136  func TestDefaultTransportSetsUserAgent(t *testing.T) {
   137  	// This test also checks that providing a nil Config.Transport
   138  	// does the right thing.
   139  
   140  	regFS, err := txtar.FS(txtar.Parse([]byte(`
   141  -- bar.example_v0.0.1/cue.mod/module.cue --
   142  module: "bar.example@v0"
   143  language: version: "v0.8.0"
   144  -- bar.example_v0.0.1/x/x.cue --
   145  package x
   146  `)))
   147  	qt.Assert(t, qt.IsNil(err))
   148  	ctx := context.Background()
   149  	rmem := ocimem.NewWithConfig(&ocimem.Config{ImmutableTags: true})
   150  	err = modregistrytest.Upload(ctx, rmem, regFS)
   151  	qt.Assert(t, qt.IsNil(err))
   152  	rh := ociserver.New(rmem, nil)
   153  	agent := cueversion.UserAgent("cuelang.org/go")
   154  	checked := false
   155  	checkUserAgentHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   156  		qt.Check(t, qt.Equals(req.UserAgent(), agent))
   157  		checked = true
   158  		rh.ServeHTTP(w, req)
   159  	})
   160  	srv := httptest.NewServer(checkUserAgentHandler)
   161  	u, err := url.Parse(srv.URL)
   162  	qt.Assert(t, qt.IsNil(err))
   163  
   164  	dir := t.TempDir()
   165  	t.Setenv("DOCKER_CONFIG", dir)
   166  	t.Setenv("CUE_REGISTRY", u.Host+"+insecure")
   167  	cacheDir := filepath.Join(dir, "cache")
   168  	t.Setenv("CUE_CACHE_DIR", cacheDir)
   169  	t.Cleanup(func() {
   170  		modcache.RemoveAll(cacheDir)
   171  	})
   172  
   173  	r, err := NewRegistry(nil)
   174  	qt.Assert(t, qt.IsNil(err))
   175  	gotRequirements, err := r.Requirements(ctx, module.MustNewVersion("bar.example@v0", "v0.0.1"))
   176  	qt.Assert(t, qt.IsNil(err))
   177  	qt.Assert(t, qt.HasLen(gotRequirements, 0))
   178  
   179  	qt.Assert(t, qt.IsTrue(checked))
   180  }
   181  
   182  // TestConcurrentTokenRefresh verifies that concurrent OAuth token refreshes,
   183  // including logins.json updates, are properly synchronized.
   184  func TestConcurrentTokenRefresh(t *testing.T) {
   185  	// Start N registry instances, each containing one CUE module and running
   186  	// in its own HTTP server instance. Each instance is protected with its
   187  	// own OAuth token, which is initially expired, requiring a refresh token
   188  	// request upon first invocation.
   189  	var registries [20]struct {
   190  		mod  string
   191  		host string
   192  	}
   193  	var counter int32 = 0
   194  	for i := range registries {
   195  		reg := &registries[i]
   196  		reg.mod = fmt.Sprintf("foo.mod%02d", i)
   197  		fsys, err := txtar.FS(txtar.Parse([]byte(fmt.Sprintf(`
   198  -- %s_v0.0.1/cue.mod/module.cue --
   199  module: "%s@v0"
   200  language: version: "v0.8.0"
   201  -- %s_v0.0.1/bar/bar.cue --
   202  package bar
   203  `, reg.mod, reg.mod, reg.mod))))
   204  		qt.Assert(t, qt.IsNil(err))
   205  		mux := http.NewServeMux()
   206  		r := ocimem.New()
   207  		err = modregistrytest.Upload(context.Background(), r, fsys)
   208  		qt.Assert(t, qt.IsNil(err))
   209  		rh := ociserver.New(r, nil)
   210  		mux.HandleFunc("/v2/", func(w http.ResponseWriter, r *http.Request) {
   211  			auth := r.Header.Get("Authorization")
   212  			if !strings.HasPrefix(auth, fmt.Sprintf("Bearer access_%d_", i)) {
   213  				w.WriteHeader(401)
   214  				fmt.Fprintf(w, "server %d: unexpected auth header: %s", i, auth)
   215  				return
   216  			}
   217  			rh.ServeHTTP(w, r)
   218  		})
   219  		mux.HandleFunc("/login/oauth/token", func(w http.ResponseWriter, r *http.Request) {
   220  			ctr := atomic.AddInt32(&counter, 1)
   221  			writeJSON(w, 200, oauth2.Token{
   222  				AccessToken:  fmt.Sprintf("access_%d_%d", i, ctr),
   223  				TokenType:    "Bearer",
   224  				RefreshToken: fmt.Sprintf("refresh_%d", ctr),
   225  				ExpiresIn:    300,
   226  			})
   227  		})
   228  		srv := httptest.NewServer(mux)
   229  		t.Cleanup(srv.Close)
   230  		u, err := url.Parse(srv.URL)
   231  		qt.Assert(t, qt.IsNil(err))
   232  		reg.host = u.Host
   233  	}
   234  
   235  	expiry := time.Now()
   236  	logins := &cueconfig.Logins{
   237  		Registries: map[string]cueconfig.RegistryLogin{},
   238  	}
   239  	registryConf := ""
   240  	for i, reg := range registries {
   241  		logins.Registries[reg.host] = cueconfig.RegistryLogin{
   242  			AccessToken:  fmt.Sprintf("access_%d_x", i),
   243  			TokenType:    "Bearer",
   244  			RefreshToken: "refresh_x",
   245  			Expiry:       &expiry,
   246  		}
   247  		if registryConf != "" {
   248  			registryConf += ","
   249  		}
   250  		registryConf += fmt.Sprintf("%s=%s+insecure", reg.mod, reg.host)
   251  	}
   252  
   253  	dir := t.TempDir()
   254  	configDir := filepath.Join(dir, "config")
   255  	t.Setenv("CUE_CONFIG_DIR", configDir)
   256  	err := os.MkdirAll(configDir, 0o777)
   257  	qt.Assert(t, qt.IsNil(err))
   258  
   259  	// Check logins.json validation.
   260  	logins.Registries["blank"] = cueconfig.RegistryLogin{TokenType: "Bearer"}
   261  	err = cueconfig.WriteLogins(filepath.Join(configDir, "logins.json"), logins)
   262  	delete(logins.Registries, "blank")
   263  	qt.Assert(t, qt.IsNil(err))
   264  	_, err = cueconfig.ReadLogins(filepath.Join(configDir, "logins.json"))
   265  	qt.Assert(t, qt.ErrorMatches(err, "invalid .*logins.json: missing access_token for registry blank"))
   266  
   267  	// Check write-read round-trip.
   268  	err = cueconfig.WriteLogins(filepath.Join(configDir, "logins.json"), logins)
   269  	qt.Assert(t, qt.IsNil(err))
   270  	logins2, err := cueconfig.ReadLogins(filepath.Join(configDir, "logins.json"))
   271  	qt.Assert(t, qt.IsNil(err))
   272  	qt.Assert(t, qt.DeepEquals(logins2, logins))
   273  
   274  	t.Setenv("CUE_REGISTRY", registryConf)
   275  	cacheDir := filepath.Join(dir, "cache")
   276  	t.Setenv("CUE_CACHE_DIR", cacheDir)
   277  	t.Cleanup(func() {
   278  		modcache.RemoveAll(cacheDir)
   279  	})
   280  
   281  	r, err := NewRegistry(nil)
   282  	qt.Assert(t, qt.IsNil(err))
   283  
   284  	g := new(errgroup.Group)
   285  	for i := range registries {
   286  		mod := registries[i].mod
   287  		g.Go(func() error {
   288  			ctx := context.Background()
   289  			loc, err := r.Fetch(ctx, module.MustNewVersion(mod+"@v0", "v0.0.1"))
   290  			if err != nil {
   291  				return err
   292  			}
   293  			data, err := fs.ReadFile(loc.FS, path.Join(loc.Dir, "bar/bar.cue"))
   294  			if err != nil {
   295  				return err
   296  			}
   297  			if string(data) != "package bar\n" {
   298  				return fmt.Errorf("unexpected data: %q", string(data))
   299  			}
   300  			return nil
   301  		})
   302  	}
   303  	err = g.Wait()
   304  	qt.Assert(t, qt.IsNil(err))
   305  	qt.Assert(t, qt.Equals(int(counter), len(registries)))
   306  }
   307  
   308  // dockerConfig describes the minimal subset of the docker
   309  // configuration file necessary to check that authentication
   310  // is correction hooked up.
   311  type dockerConfig struct {
   312  	Auths map[string]authConfig `json:"auths"`
   313  }
   314  
   315  // authConfig contains authorization information for connecting to a Registry.
   316  type authConfig struct {
   317  	Username string `json:"username,omitempty"`
   318  	Password string `json:"password,omitempty"`
   319  }
   320  
   321  type roundTripperFunc func(*http.Request) (*http.Response, error)
   322  
   323  func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   324  	return f(req)
   325  }
   326  
   327  func writeJSON(w http.ResponseWriter, statusCode int, v any) {
   328  	b, err := json.Marshal(v)
   329  	if err != nil {
   330  		// should never happen
   331  		panic(err)
   332  	}
   333  	w.Header().Set("Content-Type", "application/json")
   334  	w.WriteHeader(statusCode)
   335  	w.Write(b)
   336  }