cuelabs.dev/go/oci/ociregistry@v0.0.0-20240906074133-82eb438dd565/ocifilter/select.go (about)

     1  // Copyright 2023 CUE Labs AG
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ocifilter
    16  
    17  import (
    18  	"context"
    19  	"io"
    20  
    21  	"cuelabs.dev/go/oci/ociregistry"
    22  )
    23  
    24  // AccessKind
    25  type AccessKind int
    26  
    27  const (
    28  	// [ociregistry.Reader] methods.
    29  	AccessRead AccessKind = iota
    30  
    31  	// [ociregistry.Writer] methods.
    32  	AccessWrite
    33  
    34  	// [ociregistry.Deleter] methods.
    35  	AccessDelete
    36  
    37  	// [ociregistry.Lister] methods.
    38  	AccessList
    39  )
    40  
    41  // AccessChecker returns a wrapper for r that invokes check
    42  // to check access before calling an underlying method. Only if check succeeds will
    43  // the underlying method be called.
    44  //
    45  // The check function is invoked with the name of the repository being
    46  // accessed (or "*" for Repositories), and the kind of access required.
    47  // For some methods (e.g. Mount), check might be invoked more than
    48  // once for a given repository.
    49  //
    50  // When invoking the Repositories method, check is invoked for each repository in
    51  // the iteration - the repository will be omitted if check returns an error.
    52  func AccessChecker(r ociregistry.Interface, check func(repoName string, access AccessKind) error) ociregistry.Interface {
    53  	return &accessCheckerRegistry{
    54  		check: check,
    55  		r:     r,
    56  	}
    57  }
    58  
    59  type accessCheckerRegistry struct {
    60  	// Embed Funcs rather than the interface directly so that
    61  	// if new methods are added and selectRegistry isn't updated,
    62  	// we fall back to returning an error rather than passing through the method.
    63  	*ociregistry.Funcs
    64  	check func(repoName string, kind AccessKind) error
    65  	r     ociregistry.Interface
    66  }
    67  
    68  // Select returns a wrapper for r that provides only
    69  // repositories for which allow returns true.
    70  //
    71  // Requests for disallowed repositories will return ErrNameUnknown
    72  // errors on read and ErrDenied on write.
    73  func Select(r ociregistry.Interface, allow func(repoName string) bool) ociregistry.Interface {
    74  	return AccessChecker(r, func(repoName string, access AccessKind) error {
    75  		if allow(repoName) {
    76  			return nil
    77  		}
    78  		if access == AccessWrite {
    79  			return ociregistry.ErrDenied
    80  		}
    81  		if access == AccessList && repoName == "*" {
    82  			return nil
    83  		}
    84  		return ociregistry.ErrNameUnknown
    85  	})
    86  }
    87  
    88  func (r *accessCheckerRegistry) GetBlob(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.BlobReader, error) {
    89  	if err := r.check(repo, AccessRead); err != nil {
    90  		return nil, err
    91  	}
    92  	return r.r.GetBlob(ctx, repo, digest)
    93  }
    94  
    95  func (r *accessCheckerRegistry) GetBlobRange(ctx context.Context, repo string, digest ociregistry.Digest, offset0, offset1 int64) (ociregistry.BlobReader, error) {
    96  	if err := r.check(repo, AccessRead); err != nil {
    97  		return nil, err
    98  	}
    99  	return r.r.GetBlobRange(ctx, repo, digest, offset0, offset1)
   100  }
   101  
   102  func (r *accessCheckerRegistry) GetManifest(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.BlobReader, error) {
   103  	if err := r.check(repo, AccessRead); err != nil {
   104  		return nil, err
   105  	}
   106  	return r.r.GetManifest(ctx, repo, digest)
   107  }
   108  
   109  func (r *accessCheckerRegistry) GetTag(ctx context.Context, repo string, tagName string) (ociregistry.BlobReader, error) {
   110  	if err := r.check(repo, AccessRead); err != nil {
   111  		return nil, err
   112  	}
   113  	return r.r.GetTag(ctx, repo, tagName)
   114  }
   115  
   116  func (r *accessCheckerRegistry) ResolveBlob(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) {
   117  	if err := r.check(repo, AccessRead); err != nil {
   118  		return ociregistry.Descriptor{}, err
   119  	}
   120  	return r.r.ResolveBlob(ctx, repo, digest)
   121  }
   122  
   123  func (r *accessCheckerRegistry) ResolveManifest(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) {
   124  	if err := r.check(repo, AccessRead); err != nil {
   125  		return ociregistry.Descriptor{}, err
   126  	}
   127  	return r.r.ResolveManifest(ctx, repo, digest)
   128  }
   129  
   130  func (r *accessCheckerRegistry) ResolveTag(ctx context.Context, repo string, tagName string) (ociregistry.Descriptor, error) {
   131  	if err := r.check(repo, AccessRead); err != nil {
   132  		return ociregistry.Descriptor{}, err
   133  	}
   134  	return r.r.ResolveTag(ctx, repo, tagName)
   135  }
   136  
   137  func (r *accessCheckerRegistry) PushBlob(ctx context.Context, repo string, desc ociregistry.Descriptor, rd io.Reader) (ociregistry.Descriptor, error) {
   138  	if err := r.check(repo, AccessWrite); err != nil {
   139  		return ociregistry.Descriptor{}, err
   140  	}
   141  	return r.r.PushBlob(ctx, repo, desc, rd)
   142  }
   143  
   144  func (r *accessCheckerRegistry) PushBlobChunked(ctx context.Context, repo string, chunkSize int) (ociregistry.BlobWriter, error) {
   145  	if err := r.check(repo, AccessWrite); err != nil {
   146  		return nil, err
   147  	}
   148  	return r.r.PushBlobChunked(ctx, repo, chunkSize)
   149  }
   150  
   151  func (r *accessCheckerRegistry) PushBlobChunkedResume(ctx context.Context, repo, id string, offset int64, chunkSize int) (ociregistry.BlobWriter, error) {
   152  	if err := r.check(repo, AccessWrite); err != nil {
   153  		return nil, err
   154  	}
   155  	return r.r.PushBlobChunkedResume(ctx, repo, id, offset, chunkSize)
   156  }
   157  
   158  func (r *accessCheckerRegistry) MountBlob(ctx context.Context, fromRepo, toRepo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) {
   159  	if err := r.check(fromRepo, AccessRead); err != nil {
   160  		return ociregistry.Descriptor{}, err
   161  	}
   162  	if err := r.check(toRepo, AccessWrite); err != nil {
   163  		return ociregistry.Descriptor{}, err
   164  	}
   165  	return r.r.MountBlob(ctx, fromRepo, toRepo, digest)
   166  }
   167  
   168  func (r *accessCheckerRegistry) PushManifest(ctx context.Context, repo string, tag string, contents []byte, mediaType string) (ociregistry.Descriptor, error) {
   169  	if err := r.check(repo, AccessWrite); err != nil {
   170  		return ociregistry.Descriptor{}, err
   171  	}
   172  	return r.r.PushManifest(ctx, repo, tag, contents, mediaType)
   173  }
   174  
   175  func (r *accessCheckerRegistry) DeleteBlob(ctx context.Context, repo string, digest ociregistry.Digest) error {
   176  	if err := r.check(repo, AccessDelete); err != nil {
   177  		return err
   178  	}
   179  	return r.r.DeleteBlob(ctx, repo, digest)
   180  }
   181  
   182  func (r *accessCheckerRegistry) DeleteManifest(ctx context.Context, repo string, digest ociregistry.Digest) error {
   183  	if err := r.check(repo, AccessDelete); err != nil {
   184  		return err
   185  	}
   186  	return r.r.DeleteManifest(ctx, repo, digest)
   187  }
   188  
   189  func (r *accessCheckerRegistry) DeleteTag(ctx context.Context, repo string, name string) error {
   190  	if err := r.check(repo, AccessDelete); err != nil {
   191  		return err
   192  	}
   193  	return r.r.DeleteTag(ctx, repo, name)
   194  }
   195  
   196  func (r *accessCheckerRegistry) Repositories(ctx context.Context, startAfter string) ociregistry.Seq[string] {
   197  	if err := r.check("*", AccessList); err != nil {
   198  		return ociregistry.ErrorSeq[string](err)
   199  	}
   200  	return func(yield func(string, error) bool) {
   201  		// TODO(go1.23): for name, err := range r.r.Repositories(ctx)
   202  		r.r.Repositories(ctx, startAfter)(func(repo string, err error) bool {
   203  			if err != nil {
   204  				yield("", err)
   205  				return false
   206  			}
   207  			if r.check(repo, AccessRead) != nil {
   208  				return true
   209  			}
   210  			return yield(repo, nil)
   211  		})
   212  	}
   213  }
   214  
   215  func (r *accessCheckerRegistry) Tags(ctx context.Context, repo, startAfter string) ociregistry.Seq[string] {
   216  	if err := r.check(repo, AccessList); err != nil {
   217  		return ociregistry.ErrorSeq[string](err)
   218  	}
   219  	return r.r.Tags(ctx, repo, startAfter)
   220  }
   221  
   222  func (r *accessCheckerRegistry) Referrers(ctx context.Context, repo string, digest ociregistry.Digest, artifactType string) ociregistry.Seq[ociregistry.Descriptor] {
   223  	if err := r.check(repo, AccessList); err != nil {
   224  		return ociregistry.ErrorSeq[ociregistry.Descriptor](err)
   225  	}
   226  	return r.r.Referrers(ctx, repo, digest, artifactType)
   227  }