github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/docker/registry/internal/base_client.go (about) 1 // Copyright 2021 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package internal 5 6 import ( 7 "encoding/base64" 8 "encoding/json" 9 "fmt" 10 "net/http" 11 "net/url" 12 "path" 13 "regexp" 14 "strings" 15 "time" 16 17 "github.com/docker/distribution/reference" 18 "github.com/juju/errors" 19 "github.com/juju/loggo" 20 21 "github.com/juju/juju/docker" 22 ) 23 24 var logger = loggo.GetLogger("juju.docker.registry.internal") 25 26 const ( 27 defaultTimeout = 15 * time.Second 28 ) 29 30 // APIVersion is the API version type. 31 type APIVersion string 32 33 const ( 34 // APIVersionV1 is the API version v1. 35 APIVersionV1 APIVersion = "v1" 36 // APIVersionV2 is the API version v2. 37 APIVersionV2 APIVersion = "v2" 38 ) 39 40 func (v APIVersion) String() string { 41 return string(v) 42 } 43 44 type baseClient struct { 45 baseURL *url.URL 46 client *http.Client 47 repoDetails *docker.ImageRepoDetails 48 } 49 50 func newBase( 51 repoDetails docker.ImageRepoDetails, transport http.RoundTripper, 52 normalizeRepoDetails func(repoDetails *docker.ImageRepoDetails) error, 53 ) (*baseClient, error) { 54 c := &baseClient{ 55 baseURL: &url.URL{}, 56 repoDetails: &repoDetails, 57 client: &http.Client{ 58 Transport: transport, 59 Timeout: defaultTimeout, 60 }, 61 } 62 err := normalizeRepoDetails(c.repoDetails) 63 if err != nil { 64 return nil, errors.Trace(err) 65 } 66 return c, nil 67 } 68 69 // normalizeRepoDetailsCommon pre-processes ImageRepoDetails before Match(). 70 func normalizeRepoDetailsCommon(repoDetails *docker.ImageRepoDetails) error { 71 if repoDetails.ServerAddress != "" { 72 return nil 73 } 74 // We have validated the repository in top level. 75 // It should not raise errors here. 76 named, _ := reference.ParseNormalizedNamed(repoDetails.Repository) 77 domain := reference.Domain(named) 78 if domain == "docker.io" && !strings.HasPrefix(strings.ToLower(repoDetails.Repository), "docker.io") { 79 return fmt.Errorf("oci reference %q must have a domain", repoDetails.Repository) 80 } 81 if domain != "" { 82 repoDetails.ServerAddress = domain 83 } 84 return nil 85 } 86 87 func (c *baseClient) String() string { 88 return "generic" 89 } 90 91 // ShouldRefreshAuth checks if the repoDetails should be refreshed. 92 func (c *baseClient) ShouldRefreshAuth() (bool, time.Duration) { 93 return false, time.Duration(0) 94 } 95 96 // RefreshAuth refreshes the repoDetails. 97 func (c *baseClient) RefreshAuth() error { 98 return nil 99 } 100 101 // Match checks if the repository details matches current provider format. 102 func (c *baseClient) Match() bool { 103 return false 104 } 105 106 // APIVersion returns the registry API version to use. 107 func (c *baseClient) APIVersion() APIVersion { 108 return APIVersionV2 109 } 110 111 // TransportWrapper wraps RoundTripper. 112 type TransportWrapper func(http.RoundTripper, *docker.ImageRepoDetails) (http.RoundTripper, error) 113 114 func transportCommon(transport http.RoundTripper, repoDetails *docker.ImageRepoDetails) (http.RoundTripper, error) { 115 if !repoDetails.TokenAuthConfig.Empty() { 116 return nil, errors.NewNotValid(nil, 117 fmt.Sprintf( 118 `only {"username", "password"} or {"auth"} authorization is supported for registry %q`, 119 repoDetails.ServerAddress, 120 ), 121 ) 122 } 123 return newChallengeTransport( 124 transport, repoDetails.Username, repoDetails.Password, repoDetails.Auth.Content(), 125 ), nil 126 } 127 128 func mergeTransportWrappers( 129 transport http.RoundTripper, 130 repoDetails *docker.ImageRepoDetails, 131 wrappers ...TransportWrapper, 132 ) (http.RoundTripper, error) { 133 var err error 134 for _, wrap := range wrappers { 135 if transport, err = wrap(transport, repoDetails); err != nil { 136 return nil, errors.Trace(err) 137 } 138 } 139 return transport, nil 140 } 141 142 func wrapErrorTransport(transport http.RoundTripper, repoDetails *docker.ImageRepoDetails) (http.RoundTripper, error) { 143 return newErrorTransport(transport), nil 144 } 145 146 func (c *baseClient) WrapTransport(wrappers ...TransportWrapper) (err error) { 147 wrappers = append(wrappers, transportCommon, wrapErrorTransport) 148 if c.client.Transport, err = mergeTransportWrappers(c.client.Transport, c.repoDetails, wrappers...); err != nil { 149 return errors.Trace(err) 150 } 151 return nil 152 } 153 154 func decideBaseURLCommon(version APIVersion, repoDetails *docker.ImageRepoDetails, baseURL *url.URL) error { 155 addr := repoDetails.ServerAddress 156 if addr == "" { 157 return errors.NotValidf("empty server address for %q", repoDetails.Repository) 158 } 159 url, err := url.Parse(addr) 160 if err != nil { 161 return errors.Annotatef(err, "parsing server address %q", addr) 162 } 163 serverAddressURL := *url 164 apiVersion := version.String() 165 if !strings.Contains(url.Path, "/"+apiVersion) { 166 url.Path = path.Join(url.Path, apiVersion) 167 } 168 if url.Scheme == "" { 169 url.Scheme = "https" 170 } 171 *baseURL = *url 172 173 serverAddressURL.Scheme = "" 174 repoDetails.ServerAddress = serverAddressURL.String() 175 logger.Tracef("baseClient repoDetails %s", repoDetails) 176 return nil 177 } 178 179 // DecideBaseURL decides the API url to use. 180 func (c *baseClient) DecideBaseURL() error { 181 return errors.Trace(decideBaseURLCommon(c.APIVersion(), c.repoDetails, c.baseURL)) 182 } 183 184 func commonURLGetter(version APIVersion, url url.URL, pathTemplate string, args ...interface{}) string { 185 pathSuffix := fmt.Sprintf(pathTemplate, args...) 186 ver := version.String() 187 if !strings.HasSuffix(strings.TrimRight(url.Path, "/"), ver) { 188 url.Path = path.Join(url.Path, ver) 189 } 190 if url.Scheme == "" { 191 url.Scheme = "https" 192 } 193 url.Path = path.Join(url.Path, pathSuffix) 194 return url.String() 195 } 196 197 func (c baseClient) url(pathTemplate string, args ...interface{}) string { 198 return commonURLGetter(c.APIVersion(), *c.baseURL, pathTemplate, args...) 199 } 200 201 // Ping pings the baseClient endpoint. 202 func (c baseClient) Ping() error { 203 url := c.url("/") 204 logger.Debugf("baseClient ping %q", url) 205 resp, err := c.client.Get(url) 206 if resp != nil { 207 defer resp.Body.Close() 208 } 209 return errors.Trace(unwrapNetError(err)) 210 } 211 212 func (c baseClient) ImageRepoDetails() (o docker.ImageRepoDetails) { 213 if c.repoDetails != nil { 214 return *c.repoDetails 215 } 216 return o 217 } 218 219 // Close closes the transport used by the client. 220 func (c *baseClient) Close() error { 221 if t, ok := c.client.Transport.(*http.Transport); ok { 222 t.CloseIdleConnections() 223 } 224 return nil 225 } 226 227 func (c baseClient) getPaginatedJSON(url string, response interface{}) (string, error) { 228 resp, err := c.client.Get(url) 229 logger.Tracef("getPaginatedJSON for %q, err %v", url, err) 230 if err != nil { 231 return "", errors.Trace(unwrapNetError(err)) 232 } 233 defer resp.Body.Close() 234 235 decoder := json.NewDecoder(resp.Body) 236 err = decoder.Decode(response) 237 if err != nil { 238 return "", errors.Trace(err) 239 } 240 return getNextLink(resp) 241 } 242 243 var ( 244 nextLinkRE = regexp.MustCompile(`^ *<?([^;>]+)>? *(?:;[^;]*)*; *rel="?next"?(?:;.*)?`) 245 errNoMorePages = errors.New("no more pages") 246 ) 247 248 func getNextLink(resp *http.Response) (string, error) { 249 for _, link := range resp.Header[http.CanonicalHeaderKey("Link")] { 250 parts := nextLinkRE.FindStringSubmatch(link) 251 if parts != nil { 252 return parts[1], nil 253 } 254 } 255 return "", errNoMorePages 256 } 257 258 // unpackAuthToken returns the unpacked username and password. 259 func unpackAuthToken(auth string) (username string, password string, err error) { 260 content, err := base64.StdEncoding.DecodeString(auth) 261 if err != nil { 262 return "", "", errors.Annotate(err, "doing base64 decode on the auth token") 263 } 264 parts := strings.Split(string(content), ":") 265 if len(parts) < 2 { 266 return "", "", errors.NotValidf("registry auth token") 267 } 268 return parts[0], parts[1], nil 269 }