cuelang.org/go@v0.13.0/mod/modregistrytest/registry.go (about) 1 // Package modregistrytest provides helpers for testing packages 2 // which interact with CUE registries. 3 // 4 // WARNING: THIS PACKAGE IS EXPERIMENTAL. 5 // ITS API MAY CHANGE AT ANY TIME. 6 package modregistrytest 7 8 import ( 9 "bytes" 10 "context" 11 "encoding/json" 12 "errors" 13 "fmt" 14 "io" 15 "io/fs" 16 "maps" 17 "net/http" 18 "net/http/httptest" 19 "net/url" 20 "regexp" 21 "slices" 22 "strings" 23 24 "cuelabs.dev/go/oci/ociregistry" 25 "cuelabs.dev/go/oci/ociregistry/ocifilter" 26 "cuelabs.dev/go/oci/ociregistry/ocimem" 27 "cuelabs.dev/go/oci/ociregistry/ociserver" 28 "golang.org/x/tools/txtar" 29 30 "cuelang.org/go/mod/modfile" 31 "cuelang.org/go/mod/modregistry" 32 "cuelang.org/go/mod/module" 33 "cuelang.org/go/mod/modzip" 34 ) 35 36 // AuthConfig specifies authorization requirements for the server. 37 type AuthConfig struct { 38 // Username and Password hold the basic auth credentials. 39 // If UseTokenServer is true, these apply to the token server 40 // rather than to the registry itself. 41 Username string `json:"username"` 42 Password string `json:"password"` 43 44 // BearerToken holds a bearer token to use as auth. 45 // If UseTokenServer is true, this applies to the token server 46 // rather than to the registry itself. 47 BearerToken string `json:"bearerToken"` 48 49 // UseTokenServer starts a token server and directs client 50 // requests to acquire auth tokens from that server. 51 UseTokenServer bool `json:"useTokenServer"` 52 53 // ACL holds the ACL for an authenticated client. 54 // If it's nil, the user is allowed full access. 55 // Note: there's only one ACL because we only 56 // support a single authenticated user. 57 ACL *ACL `json:"acl,omitempty"` 58 59 // Use401InsteadOf403 causes the server to send a 401 60 // response even when the credentials are present and correct. 61 Use401InsteadOf403 bool `json:"always401"` 62 } 63 64 // ACL determines what endpoints an authenticated user can accesse 65 // Both Allow and Deny hold a list of regular expressions that 66 // are matched against an HTTP request formatted as a string: 67 // 68 // METHOD URL_PATH 69 // 70 // For example: 71 // 72 // GET /v2/foo/bar 73 type ACL struct { 74 // Allow holds the list of allowed paths for a user. 75 // If none match, the user is forbidden. 76 Allow []string 77 // Deny holds the list of denied paths for a user. 78 // If any match, the user is forbidden. 79 Deny []string 80 } 81 82 // Upload uploads the modules found inside fsys (stored 83 // in the format described by [New]) to the given registry. 84 func Upload(ctx context.Context, r ociregistry.Interface, fsys fs.FS) error { 85 _, err := upload(ctx, r, fsys) 86 return err 87 } 88 89 func upload(ctx context.Context, r ociregistry.Interface, fsys fs.FS) (authConfig []byte, err error) { 90 client := modregistry.NewClient(r) 91 mods, authConfigData, err := getModules(fsys) 92 if err != nil { 93 return nil, fmt.Errorf("invalid modules: %v", err) 94 } 95 96 if err := pushContent(ctx, client, mods); err != nil { 97 return nil, fmt.Errorf("cannot push modules: %v", err) 98 } 99 return authConfigData, nil 100 } 101 102 // New starts a registry instance that serves modules found inside fsys. 103 // It serves the OCI registry protocol. 104 // If prefix is non-empty, all module paths will be prefixed by that, 105 // separated by a slash (/). 106 // 107 // Each module should be inside a directory named path_vers, where 108 // slashes in path have been replaced with underscores and should 109 // contain a cue.mod/module.cue file holding the module info. 110 // 111 // If there's a file named auth.json in the root directory, 112 // it will cause access to the server to be gated by the 113 // specified authorization. See the [AuthConfig] type for 114 // details. 115 // 116 // The Registry should be closed after use. 117 func New(fsys fs.FS, prefix string) (*Registry, error) { 118 r := ocimem.NewWithConfig(&ocimem.Config{ImmutableTags: true}) 119 120 authConfigData, err := upload(context.Background(), ocifilter.Sub(r, prefix), fsys) 121 if err != nil { 122 return nil, err 123 } 124 var authConfig *AuthConfig 125 if authConfigData != nil { 126 if err := json.Unmarshal(authConfigData, &authConfig); err != nil { 127 return nil, fmt.Errorf("invalid auth.json: %v", err) 128 } 129 } 130 return NewServer(ocifilter.ReadOnly(r), authConfig) 131 } 132 133 // NewServer is like New except that instead of uploading 134 // the contents of a filesystem, it just serves the contents 135 // of the given registry guarded by the given auth configuration. 136 // If auth is nil, no authentication will be required. 137 func NewServer(r ociregistry.Interface, auth *AuthConfig) (*Registry, error) { 138 var tokenSrv *httptest.Server 139 if auth != nil && auth.UseTokenServer { 140 tokenSrv = httptest.NewServer(tokenHandler(auth)) 141 } 142 r, err := authzRegistry(auth, r) 143 if err != nil { 144 return nil, err 145 } 146 srv := httptest.NewServer(®istryHandler{ 147 auth: auth, 148 registry: ociserver.New(r, nil), 149 tokenSrv: tokenSrv, 150 }) 151 u, err := url.Parse(srv.URL) 152 if err != nil { 153 return nil, err 154 } 155 return &Registry{ 156 srv: srv, 157 host: u.Host, 158 tokenSrv: tokenSrv, 159 }, nil 160 } 161 162 // authzRegistry wraps r by checking whether the client has authorization 163 // to read any given repository. 164 func authzRegistry(auth *AuthConfig, r ociregistry.Interface) (ociregistry.Interface, error) { 165 if auth == nil { 166 return r, nil 167 } 168 allow := func(repoName string) bool { 169 return true 170 } 171 if auth.ACL != nil { 172 allowCheck, err := regexpMatcher(auth.ACL.Allow) 173 if err != nil { 174 return nil, fmt.Errorf("invalid allow list: %v", err) 175 } 176 denyCheck, err := regexpMatcher(auth.ACL.Deny) 177 if err != nil { 178 return nil, fmt.Errorf("invalid deny list: %v", err) 179 } 180 allow = func(repoName string) bool { 181 return allowCheck(repoName) && !denyCheck(repoName) 182 } 183 } 184 return ocifilter.AccessChecker(r, func(repoName string, access ocifilter.AccessKind) (_err error) { 185 if !allow(repoName) { 186 if auth.Use401InsteadOf403 { 187 // TODO this response should be associated with a 188 // Www-Authenticate header, but this won't do that. 189 // Given that the ociauth logic _should_ turn 190 // this back into a 403 error again, perhaps 191 // we're OK. 192 return ociregistry.ErrUnauthorized 193 } 194 // TODO should we be a bit more sophisticated and only 195 // return ErrDenied when the repository doesn't exist? 196 return ociregistry.ErrDenied 197 } 198 return nil 199 }), nil 200 } 201 202 func regexpMatcher(patStrs []string) (func(string) bool, error) { 203 pats := make([]*regexp.Regexp, len(patStrs)) 204 for i, s := range patStrs { 205 pat, err := regexp.Compile(s) 206 if err != nil { 207 return nil, fmt.Errorf("invalid regexp in ACL: %v", err) 208 } 209 pats[i] = pat 210 } 211 return func(name string) bool { 212 for _, pat := range pats { 213 if pat.MatchString(name) { 214 return true 215 } 216 } 217 return false 218 }, nil 219 } 220 221 type registryHandler struct { 222 auth *AuthConfig 223 registry http.Handler 224 tokenSrv *httptest.Server 225 } 226 227 const ( 228 registryAuthToken = "ok-token-for-registrytest" 229 registryUnauthToken = "unauth-token-for-registrytest" 230 ) 231 232 func (h *registryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { 233 if h.auth == nil { 234 h.registry.ServeHTTP(w, req) 235 return 236 } 237 if h.tokenSrv == nil { 238 h.serveDirectAuth(w, req) 239 return 240 } 241 242 // Auth with token server. 243 wwwAuth := fmt.Sprintf("Bearer realm=%q,service=registrytest", h.tokenSrv.URL) 244 authHeader := req.Header.Get("Authorization") 245 if authHeader == "" { 246 w.Header().Set("Www-Authenticate", wwwAuth) 247 writeError(w, ociregistry.ErrUnauthorized) 248 return 249 } 250 kind, token, ok := strings.Cut(authHeader, " ") 251 if !ok || kind != "Bearer" { 252 w.Header().Set("Www-Authenticate", wwwAuth) 253 writeError(w, ociregistry.ErrUnauthorized) 254 return 255 } 256 switch token { 257 case registryAuthToken: 258 // User is authorized. 259 case registryUnauthToken: 260 writeError(w, ociregistry.ErrDenied) 261 return 262 default: 263 // If we don't recognize the token, then presumably 264 // the client isn't authenticated so it's 401 not 403. 265 w.Header().Set("Www-Authenticate", wwwAuth) 266 writeError(w, ociregistry.ErrUnauthorized) 267 return 268 } 269 // If the underlying registry returns a 401 error, 270 // we need to add the Www-Authenticate header. 271 // As there's no way to get ociserver to do it, 272 // we hack it by wrapping the ResponseWriter 273 // with an implementation that does. 274 h.registry.ServeHTTP(&authHeaderWriter{ 275 wwwAuth: wwwAuth, 276 ResponseWriter: w, 277 }, req) 278 } 279 280 func (h *registryHandler) serveDirectAuth(w http.ResponseWriter, req *http.Request) { 281 auth := req.Header.Get("Authorization") 282 if auth == "" { 283 if h.auth.BearerToken != "" { 284 // Note that this lacks information like the realm, 285 // but we don't need it for our test cases yet. 286 w.Header().Set("Www-Authenticate", "Bearer service=registry") 287 } else { 288 w.Header().Set("Www-Authenticate", "Basic service=registry") 289 } 290 writeError(w, fmt.Errorf("%w: no credentials", ociregistry.ErrUnauthorized)) 291 return 292 } 293 if h.auth.BearerToken != "" { 294 token, ok := strings.CutPrefix(auth, "Bearer ") 295 if !ok || token != h.auth.BearerToken { 296 writeError(w, fmt.Errorf("%w: invalid bearer credentials", ociregistry.ErrUnauthorized)) 297 return 298 } 299 } else { 300 username, password, ok := req.BasicAuth() 301 if !ok || username != h.auth.Username || password != h.auth.Password { 302 writeError(w, fmt.Errorf("%w: invalid user-password credentials", ociregistry.ErrUnauthorized)) 303 return 304 } 305 } 306 h.registry.ServeHTTP(w, req) 307 } 308 309 func tokenHandler(*AuthConfig) http.Handler { 310 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 311 if req.Method != "POST" { 312 http.Error(w, "only POST supported", http.StatusMethodNotAllowed) 313 return 314 } 315 req.ParseForm() 316 if req.Form.Get("service") != "registrytest" { 317 http.Error(w, "invalid service", http.StatusBadRequest) 318 return 319 } 320 if req.Form.Get("grant_type") != "refresh_token" { 321 http.Error(w, "invalid grant type", http.StatusBadRequest) 322 return 323 } 324 refreshToken := req.Form.Get("refresh_token") 325 if refreshToken != "registrytest-refresh" { 326 http.Error(w, fmt.Sprintf("invalid refresh token %q", refreshToken), http.StatusForbidden) 327 return 328 } 329 // See ociauth.wireToken for the full JSON format. 330 data, _ := json.Marshal(map[string]string{ 331 "token": registryAuthToken, 332 }) 333 w.Header().Set("Content-Type", "application/json") 334 w.Write(data) 335 }) 336 } 337 338 func writeError(w http.ResponseWriter, err error) { 339 data, httpStatus := ociregistry.MarshalError(err) 340 w.Header().Set("Content-Type", "application/json") 341 w.WriteHeader(httpStatus) 342 w.Write(data) 343 } 344 345 func pushContent(ctx context.Context, client *modregistry.Client, mods map[module.Version]*moduleContent) error { 346 pushed := make(map[module.Version]bool) 347 // Iterate over modules in deterministic order. 348 for _, v := range slices.SortedFunc(maps.Keys(mods), module.Version.Compare) { 349 err := visitDepthFirst(mods, v, func(v module.Version, m *moduleContent) error { 350 if pushed[v] { 351 return nil 352 } 353 var zipContent bytes.Buffer 354 if err := m.writeZip(&zipContent); err != nil { 355 return err 356 } 357 if err := client.PutModule(ctx, v, bytes.NewReader(zipContent.Bytes()), int64(zipContent.Len())); err != nil { 358 return err 359 } 360 pushed[v] = true 361 return nil 362 }) 363 if err != nil { 364 return err 365 } 366 } 367 return nil 368 } 369 370 func visitDepthFirst(mods map[module.Version]*moduleContent, v module.Version, f func(module.Version, *moduleContent) error) error { 371 m := mods[v] 372 if m == nil { 373 return nil 374 } 375 for _, depv := range m.modFile.DepVersions() { 376 if err := visitDepthFirst(mods, depv, f); err != nil { 377 return err 378 } 379 } 380 return f(v, m) 381 } 382 383 type Registry struct { 384 srv *httptest.Server 385 tokenSrv *httptest.Server 386 host string 387 } 388 389 func (r *Registry) Close() { 390 r.srv.Close() 391 if r.tokenSrv != nil { 392 r.tokenSrv.Close() 393 } 394 } 395 396 type authHeaderWriter struct { 397 wwwAuth string 398 http.ResponseWriter 399 } 400 401 func (w *authHeaderWriter) WriteHeader(code int) { 402 if code == http.StatusUnauthorized && w.Header().Get("Www-Authenticate") == "" { 403 w.Header().Set("Www-Authenticate", w.wwwAuth) 404 } 405 w.ResponseWriter.WriteHeader(code) 406 } 407 408 // Host returns the hostname for the registry server; 409 // for example localhost:13455. 410 // 411 // The connection can be assumed to be insecure. 412 func (r *Registry) Host() string { 413 return r.host 414 } 415 416 func getModules(fsys fs.FS) (map[module.Version]*moduleContent, []byte, error) { 417 var authConfig []byte 418 modules := make(map[string]*moduleContent) 419 if err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error { 420 if err != nil { 421 // If a filesystem has no entries at all, 422 // return zero modules without an error. 423 if path == "." && errors.Is(err, fs.ErrNotExist) { 424 return fs.SkipAll 425 } 426 return err 427 } 428 if d.IsDir() { 429 return nil // we're only interested in regular files, not their parent directories 430 } 431 if path == "auth.json" { 432 authConfig, err = fs.ReadFile(fsys, path) 433 if err != nil { 434 return err 435 } 436 return nil 437 } 438 modver, rest, ok := strings.Cut(path, "/") 439 if !ok { 440 return fmt.Errorf("registry should only contain directories, but found regular file %q", path) 441 } 442 content := modules[modver] 443 if content == nil { 444 content = &moduleContent{} 445 modules[modver] = content 446 } 447 data, err := fs.ReadFile(fsys, path) 448 if err != nil { 449 return err 450 } 451 content.files = append(content.files, txtar.File{ 452 Name: rest, 453 Data: data, 454 }) 455 return nil 456 }); err != nil { 457 return nil, nil, err 458 } 459 for modver, content := range modules { 460 if err := content.init(modver); err != nil { 461 return nil, nil, fmt.Errorf("cannot initialize module %q: %v", modver, err) 462 } 463 } 464 byVer := map[module.Version]*moduleContent{} 465 for _, m := range modules { 466 byVer[m.version] = m 467 } 468 return byVer, authConfig, nil 469 } 470 471 type moduleContent struct { 472 version module.Version 473 files []txtar.File 474 modFile *modfile.File 475 } 476 477 func (c *moduleContent) writeZip(w io.Writer) error { 478 return modzip.Create(w, c.version, c.files, txtarFileIO{}) 479 } 480 481 func (c *moduleContent) init(versDir string) error { 482 found := false 483 for _, f := range c.files { 484 if f.Name != "cue.mod/module.cue" { 485 continue 486 } 487 modf, err := modfile.Parse(f.Data, f.Name) 488 if err != nil { 489 return err 490 } 491 if found { 492 return fmt.Errorf("multiple module.cue files") 493 } 494 mod := strings.ReplaceAll(modf.ModulePath(), "/", "_") + "_" 495 vers := strings.TrimPrefix(versDir, mod) 496 if len(vers) == len(versDir) { 497 return fmt.Errorf("module path %q in module.cue does not match directory %q", modf.QualifiedModule(), versDir) 498 } 499 v, err := module.NewVersion(modf.QualifiedModule(), vers) 500 if err != nil { 501 return fmt.Errorf("cannot make module version: %v", err) 502 } 503 c.version = v 504 c.modFile = modf 505 found = true 506 } 507 if !found { 508 return fmt.Errorf("no module.cue file found in %q", versDir) 509 } 510 return nil 511 }