github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/docker/registry/internal/base_client.go (about)

     1  // Copyright 2021 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package internal
     5  
     6  import (
     7  	"encoding/base64"
     8  	"encoding/json"
     9  	"fmt"
    10  	"net/http"
    11  	"net/url"
    12  	"path"
    13  	"regexp"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/docker/distribution/reference"
    18  	"github.com/juju/errors"
    19  	"github.com/juju/loggo"
    20  
    21  	"github.com/juju/juju/docker"
    22  )
    23  
// logger is the package-level logger, registered under the
// "juju.docker.registry.internal" module name.
var logger = loggo.GetLogger("juju.docker.registry.internal")

const (
	// defaultTimeout is the request timeout set on the http.Client that
	// newBase creates.
	defaultTimeout = 15 * time.Second
)
    29  
    30  // APIVersion is the API version type.
    31  type APIVersion string
    32  
    33  const (
    34  	// APIVersionV1 is the API version v1.
    35  	APIVersionV1 APIVersion = "v1"
    36  	// APIVersionV2 is the API version v2.
    37  	APIVersionV2 APIVersion = "v2"
    38  )
    39  
    40  func (v APIVersion) String() string {
    41  	return string(v)
    42  }
    43  
// baseClient holds the state shared by the registry client methods in this
// file:
//   - baseURL is the registry API URL, written by DecideBaseURL;
//   - client is the HTTP client used for requests to the registry;
//   - repoDetails points at the image repository details the client was
//     created with.
type baseClient struct {
	baseURL     *url.URL
	client      *http.Client
	repoDetails *docker.ImageRepoDetails
}
    49  
    50  func newBase(
    51  	repoDetails docker.ImageRepoDetails, transport http.RoundTripper,
    52  	normalizeRepoDetails func(repoDetails *docker.ImageRepoDetails) error,
    53  ) (*baseClient, error) {
    54  	c := &baseClient{
    55  		baseURL:     &url.URL{},
    56  		repoDetails: &repoDetails,
    57  		client: &http.Client{
    58  			Transport: transport,
    59  			Timeout:   defaultTimeout,
    60  		},
    61  	}
    62  	err := normalizeRepoDetails(c.repoDetails)
    63  	if err != nil {
    64  		return nil, errors.Trace(err)
    65  	}
    66  	return c, nil
    67  }
    68  
    69  // normalizeRepoDetailsCommon pre-processes ImageRepoDetails before Match().
    70  func normalizeRepoDetailsCommon(repoDetails *docker.ImageRepoDetails) error {
    71  	if repoDetails.ServerAddress != "" {
    72  		return nil
    73  	}
    74  	// We have validated the repository in top level.
    75  	// It should not raise errors here.
    76  	named, _ := reference.ParseNormalizedNamed(repoDetails.Repository)
    77  	domain := reference.Domain(named)
    78  	if domain == "docker.io" && !strings.HasPrefix(strings.ToLower(repoDetails.Repository), "docker.io") {
    79  		return fmt.Errorf("oci reference %q must have a domain", repoDetails.Repository)
    80  	}
    81  	if domain != "" {
    82  		repoDetails.ServerAddress = domain
    83  	}
    84  	return nil
    85  }
    86  
// String returns the fixed name "generic" for the base client.
func (c *baseClient) String() string {
	return "generic"
}
    90  
    91  // ShouldRefreshAuth checks if the repoDetails should be refreshed.
    92  func (c *baseClient) ShouldRefreshAuth() (bool, time.Duration) {
    93  	return false, time.Duration(0)
    94  }
    95  
// RefreshAuth refreshes the repoDetails.
// The base client has nothing to refresh: it does no work and returns nil.
func (c *baseClient) RefreshAuth() error {
	return nil
}
   100  
// Match checks if the repository details matches current provider format.
// The base client always returns false.
func (c *baseClient) Match() bool {
	return false
}
   105  
// APIVersion returns the registry API version to use.
// The base client always returns APIVersionV2.
func (c *baseClient) APIVersion() APIVersion {
	return APIVersionV2
}
   110  
// TransportWrapper wraps RoundTripper. It receives the transport to wrap and
// the image repository details, and returns the wrapping RoundTripper or an
// error.
type TransportWrapper func(http.RoundTripper, *docker.ImageRepoDetails) (http.RoundTripper, error)
   113  
   114  func transportCommon(transport http.RoundTripper, repoDetails *docker.ImageRepoDetails) (http.RoundTripper, error) {
   115  	if !repoDetails.TokenAuthConfig.Empty() {
   116  		return nil, errors.NewNotValid(nil,
   117  			fmt.Sprintf(
   118  				`only {"username", "password"} or {"auth"} authorization is supported for registry %q`,
   119  				repoDetails.ServerAddress,
   120  			),
   121  		)
   122  	}
   123  	return newChallengeTransport(
   124  		transport, repoDetails.Username, repoDetails.Password, repoDetails.Auth.Content(),
   125  	), nil
   126  }
   127  
   128  func mergeTransportWrappers(
   129  	transport http.RoundTripper,
   130  	repoDetails *docker.ImageRepoDetails,
   131  	wrappers ...TransportWrapper,
   132  ) (http.RoundTripper, error) {
   133  	var err error
   134  	for _, wrap := range wrappers {
   135  		if transport, err = wrap(transport, repoDetails); err != nil {
   136  			return nil, errors.Trace(err)
   137  		}
   138  	}
   139  	return transport, nil
   140  }
   141  
// wrapErrorTransport is a TransportWrapper that wraps transport with
// newErrorTransport. The repository details are not used and the returned
// error is always nil.
func wrapErrorTransport(transport http.RoundTripper, repoDetails *docker.ImageRepoDetails) (http.RoundTripper, error) {
	return newErrorTransport(transport), nil
}
   145  
   146  func (c *baseClient) WrapTransport(wrappers ...TransportWrapper) (err error) {
   147  	wrappers = append(wrappers, transportCommon, wrapErrorTransport)
   148  	if c.client.Transport, err = mergeTransportWrappers(c.client.Transport, c.repoDetails, wrappers...); err != nil {
   149  		return errors.Trace(err)
   150  	}
   151  	return nil
   152  }
   153  
   154  func decideBaseURLCommon(version APIVersion, repoDetails *docker.ImageRepoDetails, baseURL *url.URL) error {
   155  	addr := repoDetails.ServerAddress
   156  	if addr == "" {
   157  		return errors.NotValidf("empty server address for %q", repoDetails.Repository)
   158  	}
   159  	url, err := url.Parse(addr)
   160  	if err != nil {
   161  		return errors.Annotatef(err, "parsing server address %q", addr)
   162  	}
   163  	serverAddressURL := *url
   164  	apiVersion := version.String()
   165  	if !strings.Contains(url.Path, "/"+apiVersion) {
   166  		url.Path = path.Join(url.Path, apiVersion)
   167  	}
   168  	if url.Scheme == "" {
   169  		url.Scheme = "https"
   170  	}
   171  	*baseURL = *url
   172  
   173  	serverAddressURL.Scheme = ""
   174  	repoDetails.ServerAddress = serverAddressURL.String()
   175  	logger.Tracef("baseClient repoDetails %s", repoDetails)
   176  	return nil
   177  }
   178  
   179  // DecideBaseURL decides the API url to use.
   180  func (c *baseClient) DecideBaseURL() error {
   181  	return errors.Trace(decideBaseURLCommon(c.APIVersion(), c.repoDetails, c.baseURL))
   182  }
   183  
   184  func commonURLGetter(version APIVersion, url url.URL, pathTemplate string, args ...interface{}) string {
   185  	pathSuffix := fmt.Sprintf(pathTemplate, args...)
   186  	ver := version.String()
   187  	if !strings.HasSuffix(strings.TrimRight(url.Path, "/"), ver) {
   188  		url.Path = path.Join(url.Path, ver)
   189  	}
   190  	if url.Scheme == "" {
   191  		url.Scheme = "https"
   192  	}
   193  	url.Path = path.Join(url.Path, pathSuffix)
   194  	return url.String()
   195  }
   196  
   197  func (c baseClient) url(pathTemplate string, args ...interface{}) string {
   198  	return commonURLGetter(c.APIVersion(), *c.baseURL, pathTemplate, args...)
   199  }
   200  
   201  // Ping pings the baseClient endpoint.
   202  func (c baseClient) Ping() error {
   203  	url := c.url("/")
   204  	logger.Debugf("baseClient ping %q", url)
   205  	resp, err := c.client.Get(url)
   206  	if resp != nil {
   207  		defer resp.Body.Close()
   208  	}
   209  	return errors.Trace(unwrapNetError(err))
   210  }
   211  
   212  func (c baseClient) ImageRepoDetails() (o docker.ImageRepoDetails) {
   213  	if c.repoDetails != nil {
   214  		return *c.repoDetails
   215  	}
   216  	return o
   217  }
   218  
   219  // Close closes the transport used by the client.
   220  func (c *baseClient) Close() error {
   221  	if t, ok := c.client.Transport.(*http.Transport); ok {
   222  		t.CloseIdleConnections()
   223  	}
   224  	return nil
   225  }
   226  
   227  func (c baseClient) getPaginatedJSON(url string, response interface{}) (string, error) {
   228  	resp, err := c.client.Get(url)
   229  	logger.Tracef("getPaginatedJSON for %q, err %v", url, err)
   230  	if err != nil {
   231  		return "", errors.Trace(unwrapNetError(err))
   232  	}
   233  	defer resp.Body.Close()
   234  
   235  	decoder := json.NewDecoder(resp.Body)
   236  	err = decoder.Decode(response)
   237  	if err != nil {
   238  		return "", errors.Trace(err)
   239  	}
   240  	return getNextLink(resp)
   241  }
   242  
   243  var (
   244  	nextLinkRE     = regexp.MustCompile(`^ *<?([^;>]+)>? *(?:;[^;]*)*; *rel="?next"?(?:;.*)?`)
   245  	errNoMorePages = errors.New("no more pages")
   246  )
   247  
   248  func getNextLink(resp *http.Response) (string, error) {
   249  	for _, link := range resp.Header[http.CanonicalHeaderKey("Link")] {
   250  		parts := nextLinkRE.FindStringSubmatch(link)
   251  		if parts != nil {
   252  			return parts[1], nil
   253  		}
   254  	}
   255  	return "", errNoMorePages
   256  }
   257  
   258  // unpackAuthToken returns the unpacked username and password.
   259  func unpackAuthToken(auth string) (username string, password string, err error) {
   260  	content, err := base64.StdEncoding.DecodeString(auth)
   261  	if err != nil {
   262  		return "", "", errors.Annotate(err, "doing base64 decode on the auth token")
   263  	}
   264  	parts := strings.Split(string(content), ":")
   265  	if len(parts) < 2 {
   266  		return "", "", errors.NotValidf("registry auth token")
   267  	}
   268  	return parts[0], parts[1], nil
   269  }