cuelabs.dev/go/oci/ociregistry@v0.0.0-20240906074133-82eb438dd565/ociclient/client.go (about)

     1  // Copyright 2023 CUE Labs AG
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package ociclient provides an implementation of ociregistry.Interface that
    16  // uses HTTP to talk to the remote registry.
    17  package ociclient
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"fmt"
    23  	"hash"
    24  	"io"
    25  	"log"
    26  	"net/http"
    27  	"net/url"
    28  	"strconv"
    29  	"strings"
    30  	"sync/atomic"
    31  
    32  	"github.com/opencontainers/go-digest"
    33  	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
    34  
    35  	"cuelabs.dev/go/oci/ociregistry"
    36  	"cuelabs.dev/go/oci/ociregistry/internal/ocirequest"
    37  	"cuelabs.dev/go/oci/ociregistry/ociauth"
    38  	"cuelabs.dev/go/oci/ociregistry/ociref"
    39  )
    40  
// debug enables logging of each HTTP request sent by the client and
// of the response received for it (see client.do).
// TODO this should be configurable in the API.
const debug = false
    44  
// Options holds the optional parameters accepted by [New].
// The zero value selects the defaults described on each field.
type Options struct {
	// DebugID is used to prefix any log messages printed by the client.
	// If it is empty, New generates an identifier of the form "idN".
	DebugID string

	// Transport is used to make HTTP requests. The context passed
	// to its RoundTrip method will have an appropriate
	// [ociauth.RequestInfo] value added, suitable for consumption
	// by the transport created by [ociauth.NewStdTransport]. If
	// Transport is nil, [http.DefaultTransport] will be used.
	Transport http.RoundTripper

	// Insecure specifies whether an http scheme will be used to
	// address the host instead of https.
	Insecure bool

	// ListPageSize configures the maximum number of results
	// requested when making list requests. If it's <= zero, it
	// defaults to DefaultListPageSize.
	ListPageSize int
}
    65  
    66  // See https://github.com/google/go-containerregistry/issues/1091
    67  // for an early report of the issue alluded to below.
    68  
    69  // DefaultListPageSize holds the default number of results
    70  // to request when using the list endpoints.
    71  // It's not more than 1000 because AWS ECR complains
    72  // if it's more than that.
    73  const DefaultListPageSize = 1000
    74  
// debugID is a counter that New increments atomically to generate a
// default Options.DebugID ("id1", "id2", ...) when none is supplied.
var debugID int32
    76  
    77  // New returns a registry implementation that uses the OCI
    78  // HTTP API. A nil opts parameter is equivalent to a pointer
    79  // to zero Options.
    80  //
    81  // The host specifies the host name to talk to; it may
    82  // optionally be a host:port pair.
    83  func New(host string, opts0 *Options) (ociregistry.Interface, error) {
    84  	var opts Options
    85  	if opts0 != nil {
    86  		opts = *opts0
    87  	}
    88  	if opts.DebugID == "" {
    89  		opts.DebugID = fmt.Sprintf("id%d", atomic.AddInt32(&debugID, 1))
    90  	}
    91  	if opts.Transport == nil {
    92  		opts.Transport = http.DefaultTransport
    93  	}
    94  	// Check that it's a valid host by forming a URL from it and checking that it matches.
    95  	u, err := url.Parse("https://" + host + "/path")
    96  	if err != nil {
    97  		return nil, fmt.Errorf("invalid host %q", host)
    98  	}
    99  	if u.Host != host {
   100  		return nil, fmt.Errorf("invalid host %q (does not correctly form a host part of a URL)", host)
   101  	}
   102  	if opts.Insecure {
   103  		u.Scheme = "http"
   104  	}
   105  	if opts.ListPageSize == 0 {
   106  		opts.ListPageSize = DefaultListPageSize
   107  	}
   108  	return &client{
   109  		httpHost:   host,
   110  		httpScheme: u.Scheme,
   111  		httpClient: &http.Client{
   112  			Transport: opts.Transport,
   113  		},
   114  		debugID:      opts.DebugID,
   115  		listPageSize: opts.ListPageSize,
   116  	}, nil
   117  }
   118  
// client is the value returned by New as an ociregistry.Interface.
// It talks to a single registry host over HTTP.
type client struct {
	// NOTE(review): New leaves this embedded pointer nil; presumably it
	// supplies fallback implementations of interface methods that
	// client does not define itself — confirm against ociregistry.Funcs.
	*ociregistry.Funcs
	// httpScheme is the URL scheme ("https", or "http" when
	// Options.Insecure is set) used for requests with no scheme.
	httpScheme string
	// httpHost is the host (optionally host:port) used for requests
	// with no host.
	httpHost string
	// httpClient sends all requests; its Transport comes from Options.
	httpClient *http.Client
	// debugID prefixes log messages printed by logf.
	debugID string
	// listPageSize holds the resolved Options.ListPageSize.
	listPageSize int
}
   127  
// descriptorRequired is a set of bit flags telling
// descriptorFromResponse which descriptor fields it must be able
// to determine.
type descriptorRequired byte

const (
	// requireSize makes descriptorFromResponse fill in the Size field,
	// failing if the size cannot be determined from the response.
	// Without it, Size is left as zero.
	requireSize descriptorRequired = 1 << iota
	// requireDigest makes descriptorFromResponse fail when neither the
	// response nor the known digest provides a digest.
	requireDigest
)
   134  
   135  // descriptorFromResponse tries to form a descriptor from an HTTP response,
   136  // filling in the Digest field using knownDigest if it's not present.
   137  //
   138  // Note: this implies that the Digest field will be empty if there is no
   139  // digest in the response and knownDigest is empty.
   140  func descriptorFromResponse(resp *http.Response, knownDigest digest.Digest, require descriptorRequired) (ociregistry.Descriptor, error) {
   141  	contentType := resp.Header.Get("Content-Type")
   142  	if contentType == "" {
   143  		contentType = "application/octet-stream"
   144  	}
   145  	size := int64(0)
   146  	if (require & requireSize) != 0 {
   147  		if resp.StatusCode == http.StatusPartialContent {
   148  			contentRange := resp.Header.Get("Content-Range")
   149  			if contentRange == "" {
   150  				return ociregistry.Descriptor{}, fmt.Errorf("no Content-Range in partial content response")
   151  			}
   152  			i := strings.LastIndex(contentRange, "/")
   153  			if i == -1 {
   154  				return ociregistry.Descriptor{}, fmt.Errorf("malformed Content-Range %q", contentRange)
   155  			}
   156  			contentSize, err := strconv.ParseInt(contentRange[i+1:], 10, 64)
   157  			if err != nil {
   158  				return ociregistry.Descriptor{}, fmt.Errorf("malformed Content-Range %q", contentRange)
   159  			}
   160  			size = contentSize
   161  		} else {
   162  			if resp.ContentLength < 0 {
   163  				return ociregistry.Descriptor{}, fmt.Errorf("unknown content length")
   164  			}
   165  			size = resp.ContentLength
   166  		}
   167  	}
   168  	digest := digest.Digest(resp.Header.Get("Docker-Content-Digest"))
   169  	if digest != "" {
   170  		if !ociref.IsValidDigest(string(digest)) {
   171  			return ociregistry.Descriptor{}, fmt.Errorf("bad digest %q found in response", digest)
   172  		}
   173  	} else {
   174  		digest = knownDigest
   175  	}
   176  	if (require&requireDigest) != 0 && digest == "" {
   177  		return ociregistry.Descriptor{}, fmt.Errorf("no digest found in response")
   178  	}
   179  	return ociregistry.Descriptor{
   180  		Digest:    digest,
   181  		MediaType: contentType,
   182  		Size:      size,
   183  	}, nil
   184  }
   185  
   186  func newBlobReader(r io.ReadCloser, desc ociregistry.Descriptor) *blobReader {
   187  	return &blobReader{
   188  		r:        r,
   189  		digester: desc.Digest.Algorithm().Hash(),
   190  		desc:     desc,
   191  		verify:   true,
   192  	}
   193  }
   194  
   195  func newBlobReaderUnverified(r io.ReadCloser, desc ociregistry.Descriptor) *blobReader {
   196  	br := newBlobReader(r, desc)
   197  	br.verify = false
   198  	return br
   199  }
   200  
// blobReader wraps the body of a blob or manifest response, keeping
// track of how much has been read and hashing the data as it goes so
// that the content can be checked against its descriptor.
type blobReader struct {
	// r is the underlying data source.
	r io.ReadCloser
	// n counts the bytes read from r so far.
	n int64
	// digester accumulates a hash of everything read from r.
	digester hash.Hash
	// desc describes the content that r is expected to produce.
	desc ociregistry.Descriptor
	// verify reports whether the size and digest are checked
	// against desc when r reaches EOF.
	verify bool
}
   208  
// Descriptor returns the descriptor that the reader was created with.
func (r *blobReader) Descriptor() ociregistry.Descriptor {
	return r.desc
}
   212  
   213  func (r *blobReader) Read(buf []byte) (int, error) {
   214  	n, err := r.r.Read(buf)
   215  	r.n += int64(n)
   216  	r.digester.Write(buf[:n])
   217  	if err == nil {
   218  		if r.n > r.desc.Size {
   219  			// Fail early when the blob is too big; we can do that even
   220  			// when we're not verifying for other use cases.
   221  			return n, fmt.Errorf("blob size exceeds content length %d: %w", r.desc.Size, ociregistry.ErrSizeInvalid)
   222  		}
   223  		return n, nil
   224  	}
   225  	if err != io.EOF {
   226  		return n, err
   227  	}
   228  	if !r.verify {
   229  		return n, io.EOF
   230  	}
   231  	if r.n != r.desc.Size {
   232  		return n, fmt.Errorf("blob size mismatch (%d/%d): %w", r.n, r.desc.Size, ociregistry.ErrSizeInvalid)
   233  	}
   234  	gotDigest := digest.NewDigest(r.desc.Digest.Algorithm(), r.digester)
   235  	if gotDigest != r.desc.Digest {
   236  		return n, fmt.Errorf("digest mismatch when reading blob")
   237  	}
   238  	return n, io.EOF
   239  }
   240  
// Close closes the underlying reader and returns its error, if any.
func (r *blobReader) Close() error {
	return r.r.Close()
}
   244  
// knownManifestMediaTypes holds the media types sent in the Accept
// header of manifest GET and HEAD requests (see doRequest).
//
// TODO make this list configurable.
var knownManifestMediaTypes = []string{
	ocispec.MediaTypeImageManifest,
	ocispec.MediaTypeImageIndex,
	"application/vnd.oci.artifact.manifest.v1+json", // deprecated.
	"application/vnd.docker.distribution.manifest.v1+json",
	"application/vnd.docker.distribution.manifest.v2+json",
	"application/vnd.docker.distribution.manifest.list.v2+json",
	// Technically this wildcard should be sufficient, but it isn't
	// recognized by some registries.
	"*/*",
}
   257  
   258  // doRequest performs the given OCI request, sending it with the given body (which may be nil).
   259  func (c *client) doRequest(ctx context.Context, rreq *ocirequest.Request, okStatuses ...int) (*http.Response, error) {
   260  	req, err := newRequest(ctx, rreq, nil)
   261  	if err != nil {
   262  		return nil, err
   263  	}
   264  	if rreq.Kind == ocirequest.ReqManifestGet || rreq.Kind == ocirequest.ReqManifestHead {
   265  		// When getting manifests, some servers won't return
   266  		// the content unless there's an Accept header, so
   267  		// add all the manifest kinds that we know about.
   268  		req.Header["Accept"] = knownManifestMediaTypes
   269  	}
   270  	resp, err := c.do(req, okStatuses...)
   271  	if err != nil {
   272  		return nil, err
   273  	}
   274  	if resp.StatusCode/100 == 2 {
   275  		return resp, nil
   276  	}
   277  	defer resp.Body.Close()
   278  	return nil, makeError(resp)
   279  }
   280  
   281  func (c *client) do(req *http.Request, okStatuses ...int) (*http.Response, error) {
   282  	if req.URL.Scheme == "" {
   283  		req.URL.Scheme = c.httpScheme
   284  	}
   285  	if req.URL.Host == "" {
   286  		req.URL.Host = c.httpHost
   287  	}
   288  	if req.Body != nil {
   289  		// Ensure that the body isn't consumed until the
   290  		// server has responded that it will receive it.
   291  		// This means that we can retry requests even when we've
   292  		// got a consume-once-only io.Reader, such as
   293  		// when pushing blobs.
   294  		req.Header.Set("Expect", "100-continue")
   295  	}
   296  	var buf bytes.Buffer
   297  	if debug {
   298  		fmt.Fprintf(&buf, "client.Do: %s %s {{\n", req.Method, req.URL)
   299  		fmt.Fprintf(&buf, "\tBODY: %#v\n", req.Body)
   300  		for k, v := range req.Header {
   301  			fmt.Fprintf(&buf, "\t%s: %q\n", k, v)
   302  		}
   303  		c.logf("%s", buf.Bytes())
   304  	}
   305  	resp, err := c.httpClient.Do(req)
   306  	if err != nil {
   307  		return nil, fmt.Errorf("cannot do HTTP request: %w", err)
   308  	}
   309  	if debug {
   310  		buf.Reset()
   311  		fmt.Fprintf(&buf, "} -> %s {\n", resp.Status)
   312  		for k, v := range resp.Header {
   313  			fmt.Fprintf(&buf, "\t%s: %q\n", k, v)
   314  		}
   315  		data, _ := io.ReadAll(resp.Body)
   316  		if len(data) > 0 {
   317  			fmt.Fprintf(&buf, "\tBODY: %q\n", data)
   318  		}
   319  		fmt.Fprintf(&buf, "}}\n")
   320  		resp.Body.Close()
   321  		resp.Body = io.NopCloser(bytes.NewReader(data))
   322  		c.logf("%s", buf.Bytes())
   323  	}
   324  	if len(okStatuses) == 0 && resp.StatusCode == http.StatusOK {
   325  		return resp, nil
   326  	}
   327  	for _, status := range okStatuses {
   328  		if resp.StatusCode == status {
   329  			return resp, nil
   330  		}
   331  	}
   332  	defer resp.Body.Close()
   333  	if !isOKStatus(resp.StatusCode) {
   334  		return nil, makeError(resp)
   335  	}
   336  	return nil, unexpectedStatusError(resp.StatusCode)
   337  }
   338  
   339  func (c *client) logf(f string, a ...any) {
   340  	log.Printf("ociclient %s: %s", c.debugID, fmt.Sprintf(f, a...))
   341  }
   342  
   343  func locationFromResponse(resp *http.Response) (*url.URL, error) {
   344  	location := resp.Header.Get("Location")
   345  	if location == "" {
   346  		return nil, fmt.Errorf("no Location found in response")
   347  	}
   348  	u, err := url.Parse(location)
   349  	if err != nil {
   350  		return nil, fmt.Errorf("invalid Location URL found in response")
   351  	}
   352  	return resp.Request.URL.ResolveReference(u), nil
   353  }
   354  
   355  func isOKStatus(code int) bool {
   356  	return code/100 == 2
   357  }
   358  
   359  func closeOnError(err *error, r io.Closer) {
   360  	if *err != nil {
   361  		r.Close()
   362  	}
   363  }
   364  
// unexpectedStatusError returns an error reporting that the HTTP
// response status code was not one that was expected.
func unexpectedStatusError(code int) error {
	return fmt.Errorf("unexpected HTTP response code %d", code)
}
   368  
   369  func scopeForRequest(r *ocirequest.Request) ociauth.Scope {
   370  	switch r.Kind {
   371  	case ocirequest.ReqPing:
   372  		return ociauth.Scope{}
   373  	case ocirequest.ReqBlobGet,
   374  		ocirequest.ReqBlobHead,
   375  		ocirequest.ReqManifestGet,
   376  		ocirequest.ReqManifestHead,
   377  		ocirequest.ReqTagsList,
   378  		ocirequest.ReqReferrersList:
   379  		return ociauth.NewScope(ociauth.ResourceScope{
   380  			ResourceType: ociauth.TypeRepository,
   381  			Resource:     r.Repo,
   382  			Action:       ociauth.ActionPull,
   383  		})
   384  	case ocirequest.ReqBlobDelete,
   385  		ocirequest.ReqBlobStartUpload,
   386  		ocirequest.ReqBlobUploadBlob,
   387  		ocirequest.ReqBlobUploadInfo,
   388  		ocirequest.ReqBlobUploadChunk,
   389  		ocirequest.ReqBlobCompleteUpload,
   390  		ocirequest.ReqManifestPut,
   391  		ocirequest.ReqManifestDelete:
   392  		return ociauth.NewScope(ociauth.ResourceScope{
   393  			ResourceType: ociauth.TypeRepository,
   394  			Resource:     r.Repo,
   395  			Action:       ociauth.ActionPush,
   396  		})
   397  	case ocirequest.ReqBlobMount:
   398  		return ociauth.NewScope(ociauth.ResourceScope{
   399  			ResourceType: ociauth.TypeRepository,
   400  			Resource:     r.Repo,
   401  			Action:       ociauth.ActionPush,
   402  		}, ociauth.ResourceScope{
   403  			ResourceType: ociauth.TypeRepository,
   404  			Resource:     r.FromRepo,
   405  			Action:       ociauth.ActionPull,
   406  		})
   407  	case ocirequest.ReqCatalogList:
   408  		return ociauth.NewScope(ociauth.CatalogScope)
   409  	default:
   410  		panic(fmt.Errorf("unexpected request kind %v", r.Kind))
   411  	}
   412  }
   413  
   414  func newRequest(ctx context.Context, rreq *ocirequest.Request, body io.Reader) (*http.Request, error) {
   415  	method, u, err := rreq.Construct()
   416  	if err != nil {
   417  		return nil, err
   418  	}
   419  	ctx = ociauth.ContextWithRequestInfo(ctx, ociauth.RequestInfo{
   420  		RequiredScope: scopeForRequest(rreq),
   421  	})
   422  	return http.NewRequestWithContext(ctx, method, u, body)
   423  }