cuelang.org/go@v0.10.1/mod/modconfig/modconfig_test.go (about) 1 package modconfig 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "io/fs" 8 "net/http" 9 "net/http/httptest" 10 "net/url" 11 "os" 12 "path" 13 "path/filepath" 14 "slices" 15 "strings" 16 "sync/atomic" 17 "testing" 18 "time" 19 20 "cuelabs.dev/go/oci/ociregistry/ocimem" 21 "cuelabs.dev/go/oci/ociregistry/ociserver" 22 "github.com/go-quicktest/qt" 23 "golang.org/x/sync/errgroup" 24 "golang.org/x/tools/txtar" 25 26 "cuelang.org/go/internal/cueconfig" 27 "cuelang.org/go/internal/cueversion" 28 "cuelang.org/go/internal/registrytest" 29 "cuelang.org/go/mod/modcache" 30 "cuelang.org/go/mod/module" 31 ) 32 33 // TODO: the test below acts as a smoke test for the functionality here, 34 // but more of the behavior is tested in the cmd/cue script tests. 35 // We should do more of it here too. 36 37 func TestNewRegistry(t *testing.T) { 38 modules := txtar.Parse([]byte(` 39 -- r1/foo.example_v0.0.1/cue.mod/module.cue -- 40 module: "foo.example@v0" 41 language: version: "v0.8.0" 42 deps: "bar.example@v0": v: "v0.0.1" 43 -- r1/foo.example_v0.0.1/bar/bar.cue -- 44 package bar 45 -- r1/bar.example_v0.0.1/cue.mod/module.cue -- 46 module: "bar.example@v0" 47 language: version: "v0.8.0" 48 -- r1/bar.example_v0.0.1/y/y.cue -- 49 package y 50 51 -- r2/auth.json -- 52 { 53 "username": "bob", 54 "password": "somePassword" 55 } 56 -- r2/bar.example_v0.0.1/cue.mod/module.cue -- 57 module: "bar.example@v0" 58 language: version: "v0.8.0" 59 -- r2/bar.example_v0.0.1/x/x.cue -- 60 package x 61 `)) 62 fsys, err := txtar.FS(modules) 63 qt.Assert(t, qt.IsNil(err)) 64 r1fs, err := fs.Sub(fsys, "r1") 65 qt.Assert(t, qt.IsNil(err)) 66 r1, err := registrytest.New(r1fs, "") 67 qt.Assert(t, qt.IsNil(err)) 68 r2fs, err := fs.Sub(fsys, "r2") 69 qt.Assert(t, qt.IsNil(err)) 70 r2, err := registrytest.New(r2fs, "") 71 qt.Assert(t, qt.IsNil(err)) 72 73 dir := t.TempDir() 74 t.Setenv("DOCKER_CONFIG", dir) 75 dockerCfg, err := json.Marshal(dockerConfig{ 76 Auths: map[string]authConfig{ 77 r2.Host(): { 78 Username: "bob", 79 Password: "somePassword", 80 }, 81 }, 82 }) 83 qt.Assert(t, qt.IsNil(err)) 84 err = os.WriteFile(filepath.Join(dir, "config.json"), dockerCfg, 0o666) 85 qt.Assert(t, qt.IsNil(err)) 86 87 t.Setenv("CUE_REGISTRY", 88 fmt.Sprintf("foo.example=%s+insecure,%s+insecure", 89 r1.Host(), 90 r2.Host(), 91 )) 92 cacheDir := filepath.Join(dir, "cache") 93 t.Setenv("CUE_CACHE_DIR", cacheDir) 94 t.Cleanup(func() { 95 modcache.RemoveAll(cacheDir) 96 }) 97 98 var transportInvoked atomic.Bool 99 r, err := NewRegistry(&Config{ 100 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { 101 transportInvoked.Store(true) 102 return http.DefaultTransport.RoundTrip(req) 103 }), 104 }) 105 qt.Assert(t, qt.IsNil(err)) 106 ctx := context.Background() 107 gotRequirements, err := r.Requirements(ctx, module.MustNewVersion("foo.example@v0", "v0.0.1")) 108 qt.Assert(t, qt.IsNil(err)) 109 qt.Assert(t, qt.DeepEquals(gotRequirements, []module.Version{ 110 module.MustNewVersion("bar.example@v0", "v0.0.1"), 111 })) 112 113 loc, err := r.Fetch(ctx, module.MustNewVersion("bar.example@v0", "v0.0.1")) 114 qt.Assert(t, qt.IsNil(err)) 115 data, err := fs.ReadFile(loc.FS, path.Join(loc.Dir, "x/x.cue")) 116 qt.Assert(t, qt.IsNil(err)) 117 qt.Assert(t, qt.Equals(string(data), "package x\n")) 118 119 // Check that we can make a Resolver with the same configuration. 120 resolver, err := NewResolver(nil) 121 qt.Assert(t, qt.IsNil(err)) 122 gotAllHosts := resolver.AllHosts() 123 wantAllHosts := []Host{{Name: r1.Host(), Insecure: true}, {Name: r2.Host(), Insecure: true}} 124 125 byHostname := func(a, b Host) int { return strings.Compare(a.Name, b.Name) } 126 slices.SortFunc(gotAllHosts, byHostname) 127 slices.SortFunc(wantAllHosts, byHostname) 128 129 qt.Assert(t, qt.DeepEquals(gotAllHosts, wantAllHosts)) 130 131 // Check that the underlying custom transport was used. 132 qt.Assert(t, qt.IsTrue(transportInvoked.Load())) 133 } 134 135 func TestDefaultTransportSetsUserAgent(t *testing.T) { 136 // This test also checks that providing a nil Config.Transport 137 // does the right thing. 138 139 regFS, err := txtar.FS(txtar.Parse([]byte(` 140 -- bar.example_v0.0.1/cue.mod/module.cue -- 141 module: "bar.example@v0" 142 language: version: "v0.8.0" 143 -- bar.example_v0.0.1/x/x.cue -- 144 package x 145 `))) 146 qt.Assert(t, qt.IsNil(err)) 147 ctx := context.Background() 148 rmem := ocimem.NewWithConfig(&ocimem.Config{ImmutableTags: true}) 149 err = registrytest.Upload(ctx, rmem, regFS) 150 qt.Assert(t, qt.IsNil(err)) 151 rh := ociserver.New(rmem, nil) 152 agent := cueversion.UserAgent("cuelang.org/go") 153 checked := false 154 checkUserAgentHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 155 qt.Check(t, qt.Equals(req.UserAgent(), agent)) 156 checked = true 157 rh.ServeHTTP(w, req) 158 }) 159 srv := httptest.NewServer(checkUserAgentHandler) 160 u, err := url.Parse(srv.URL) 161 qt.Assert(t, qt.IsNil(err)) 162 163 dir := t.TempDir() 164 t.Setenv("DOCKER_CONFIG", dir) 165 t.Setenv("CUE_REGISTRY", u.Host+"+insecure") 166 cacheDir := filepath.Join(dir, "cache") 167 t.Setenv("CUE_CACHE_DIR", cacheDir) 168 t.Cleanup(func() { 169 modcache.RemoveAll(cacheDir) 170 }) 171 172 r, err := NewRegistry(nil) 173 qt.Assert(t, qt.IsNil(err)) 174 gotRequirements, err := r.Requirements(ctx, module.MustNewVersion("bar.example@v0", "v0.0.1")) 175 qt.Assert(t, qt.IsNil(err)) 176 qt.Assert(t, qt.HasLen(gotRequirements, 0)) 177 178 qt.Assert(t, qt.IsTrue(checked)) 179 } 180 181 // TestConcurrentTokenRefresh verifies that concurrent OAuth token refreshes, 182 // including logins.json updates, are properly synchronized. 183 func TestConcurrentTokenRefresh(t *testing.T) { 184 // Start N registry instances, each containing one CUE module and running 185 // in its own HTTP server instance. Each instance is protected with its 186 // own OAuth token, which is initially expired, requiring a refresh token 187 // request upon first invocation. 188 var registries [20]struct { 189 mod string 190 host string 191 } 192 var counter int32 = 0 193 for i := range registries { 194 reg := ®istries[i] 195 reg.mod = fmt.Sprintf("foo.mod%02d", i) 196 fsys, err := txtar.FS(txtar.Parse([]byte(fmt.Sprintf(` 197 -- %s_v0.0.1/cue.mod/module.cue -- 198 module: "%s@v0" 199 language: version: "v0.8.0" 200 -- %s_v0.0.1/bar/bar.cue -- 201 package bar 202 `, reg.mod, reg.mod, reg.mod)))) 203 qt.Assert(t, qt.IsNil(err)) 204 mux := http.NewServeMux() 205 r := ocimem.New() 206 err = registrytest.Upload(context.Background(), r, fsys) 207 qt.Assert(t, qt.IsNil(err)) 208 rh := ociserver.New(r, nil) 209 mux.HandleFunc("/v2/", func(w http.ResponseWriter, r *http.Request) { 210 auth := r.Header.Get("Authorization") 211 if !strings.HasPrefix(auth, fmt.Sprintf("Bearer access_%d_", i)) { 212 w.WriteHeader(401) 213 w.Write([]byte(fmt.Sprintf("server %d: unexpected auth header: %s", i, auth))) 214 return 215 } 216 rh.ServeHTTP(w, r) 217 }) 218 mux.HandleFunc("/login/oauth/token", func(w http.ResponseWriter, r *http.Request) { 219 ctr := atomic.AddInt32(&counter, 1) 220 writeJSON(w, 200, wireToken{ 221 AccessToken: fmt.Sprintf("access_%d_%d", i, ctr), 222 TokenType: "Bearer", 223 RefreshToken: fmt.Sprintf("refresh_%d", ctr), 224 ExpiresIn: 300, 225 }) 226 }) 227 srv := httptest.NewServer(mux) 228 t.Cleanup(srv.Close) 229 u, err := url.Parse(srv.URL) 230 qt.Assert(t, qt.IsNil(err)) 231 reg.host = u.Host 232 } 233 234 expiry := time.Now() 235 logins := &cueconfig.Logins{ 236 Registries: map[string]cueconfig.RegistryLogin{}, 237 } 238 registryConf := "" 239 for i, reg := range registries { 240 logins.Registries[reg.host] = cueconfig.RegistryLogin{ 241 AccessToken: fmt.Sprintf("access_%d_x", i), 242 TokenType: "Bearer", 243 RefreshToken: "refresh_x", 244 Expiry: &expiry, 245 } 246 if registryConf != "" { 247 registryConf += "," 248 } 249 registryConf += fmt.Sprintf("%s=%s+insecure", reg.mod, reg.host) 250 } 251 252 dir := t.TempDir() 253 configDir := filepath.Join(dir, "config") 254 t.Setenv("CUE_CONFIG_DIR", configDir) 255 err := os.MkdirAll(configDir, 0o777) 256 qt.Assert(t, qt.IsNil(err)) 257 258 // Check logins.json validation. 259 logins.Registries["blank"] = cueconfig.RegistryLogin{TokenType: "Bearer"} 260 err = cueconfig.WriteLogins(filepath.Join(configDir, "logins.json"), logins) 261 delete(logins.Registries, "blank") 262 qt.Assert(t, qt.IsNil(err)) 263 _, err = cueconfig.ReadLogins(filepath.Join(configDir, "logins.json")) 264 qt.Assert(t, qt.ErrorMatches(err, "invalid .*logins.json: missing access_token for registry blank")) 265 266 // Check write-read round-trip. 267 err = cueconfig.WriteLogins(filepath.Join(configDir, "logins.json"), logins) 268 qt.Assert(t, qt.IsNil(err)) 269 logins2, err := cueconfig.ReadLogins(filepath.Join(configDir, "logins.json")) 270 qt.Assert(t, qt.IsNil(err)) 271 qt.Assert(t, qt.DeepEquals(logins2, logins)) 272 273 t.Setenv("CUE_REGISTRY", registryConf) 274 cacheDir := filepath.Join(dir, "cache") 275 t.Setenv("CUE_CACHE_DIR", cacheDir) 276 t.Cleanup(func() { 277 modcache.RemoveAll(cacheDir) 278 }) 279 280 r, err := NewRegistry(nil) 281 qt.Assert(t, qt.IsNil(err)) 282 283 g := new(errgroup.Group) 284 for i := range registries { 285 mod := registries[i].mod 286 g.Go(func() error { 287 ctx := context.Background() 288 loc, err := r.Fetch(ctx, module.MustNewVersion(mod+"@v0", "v0.0.1")) 289 if err != nil { 290 return err 291 } 292 data, err := fs.ReadFile(loc.FS, path.Join(loc.Dir, "bar/bar.cue")) 293 if err != nil { 294 return err 295 } 296 if string(data) != "package bar\n" { 297 return fmt.Errorf("unexpected data: %q", string(data)) 298 } 299 return nil 300 }) 301 } 302 err = g.Wait() 303 qt.Assert(t, qt.IsNil(err)) 304 qt.Assert(t, qt.Equals(int(counter), len(registries))) 305 } 306 307 // dockerConfig describes the minimal subset of the docker 308 // configuration file necessary to check that authentication 309 // is correction hooked up. 310 type dockerConfig struct { 311 Auths map[string]authConfig `json:"auths"` 312 } 313 314 // authConfig contains authorization information for connecting to a Registry. 315 type authConfig struct { 316 Username string `json:"username,omitempty"` 317 Password string `json:"password,omitempty"` 318 } 319 320 type roundTripperFunc func(*http.Request) (*http.Response, error) 321 322 func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 323 return f(req) 324 } 325 326 func writeJSON(w http.ResponseWriter, statusCode int, v any) { 327 b, err := json.Marshal(v) 328 if err != nil { 329 // should never happen 330 panic(err) 331 } 332 w.Header().Set("Content-Type", "application/json") 333 w.WriteHeader(statusCode) 334 w.Write(b) 335 } 336 337 // wireToken describes the JSON encoding for an OAuth 2.0 token 338 // as specified in the RFC 6749. 339 type wireToken struct { 340 AccessToken string `json:"access_token,omitempty"` 341 TokenType string `json:"token_type,omitempty"` 342 RefreshToken string `json:"refresh_token,omitempty"` 343 ExpiresIn int `json:"expires_in,omitempty"` 344 }