github.com/opcr-io/oras-go/v2@v2.0.0-20231122155130-eb4260d8a0ae/registry/remote/auth/client.go (about) 1 /* 2 Copyright The ORAS Authors. 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 */ 15 16 // Package auth provides authentication for a client to a remote registry. 17 package auth 18 19 import ( 20 "context" 21 "encoding/base64" 22 "encoding/json" 23 "errors" 24 "fmt" 25 "io" 26 "net/http" 27 "net/url" 28 "strings" 29 30 "github.com/opcr-io/oras-go/v2/registry/remote/internal/errutil" 31 "github.com/opcr-io/oras-go/v2/registry/remote/retry" 32 ) 33 34 // DefaultClient is the default auth-decorated client. 35 var DefaultClient = &Client{ 36 Client: retry.DefaultClient, 37 Header: http.Header{ 38 "User-Agent": {"oras-go"}, 39 }, 40 Cache: DefaultCache, 41 } 42 43 // maxResponseBytes specifies the default limit on how many response bytes are 44 // allowed in the server's response from authorization service servers. 45 // A typical response message from authorization service servers is around 1 to 46 // 4 KiB. Since the size of a token must be smaller than the HTTP header size 47 // limit, which is usually 16 KiB. As specified by the distribution, the 48 // response may contain 2 identical tokens, that is, 16 x 2 = 32 KiB. 49 // Hence, 128 KiB should be sufficient. 50 // References: https://docs.docker.com/registry/spec/auth/token/ 51 var maxResponseBytes int64 = 128 * 1024 // 128 KiB 52 53 // defaultClientID specifies the default client ID used in OAuth2. 54 // See also ClientID. 55 var defaultClientID = "oras-go" 56 57 // StaticCredential specifies static credentials for the given host. 58 func StaticCredential(registry string, cred Credential) func(context.Context, string) (Credential, error) { 59 return func(_ context.Context, target string) (Credential, error) { 60 if target == registry { 61 return cred, nil 62 } 63 return EmptyCredential, nil 64 } 65 } 66 67 // Client is an auth-decorated HTTP client. 68 // Its zero value is a usable client that uses http.DefaultClient with no cache. 69 type Client struct { 70 // Client is the underlying HTTP client used to access the remote 71 // server. 72 // If nil, http.DefaultClient is used. 73 // It is possible to use the default retry client from the package 74 // `github.com/opcr-io/oras-go/v2/registry/remote/retry`. That client is already available 75 // in the DefaultClient. 76 // It is also possible to use a custom client. For example, github.com/hashicorp/go-retryablehttp 77 // is a popular HTTP client that supports retries. 78 Client *http.Client 79 80 // Header contains the custom headers to be added to each request. 81 Header http.Header 82 83 // Credential specifies the function for resolving the credential for the 84 // given registry (i.e. host:port). 85 // `EmptyCredential` is a valid return value and should not be considered as 86 // an error. 87 // If nil, the credential is always resolved to `EmptyCredential`. 88 Credential func(context.Context, string) (Credential, error) 89 90 // Cache caches credentials for direct accessing the remote registry. 91 // If nil, no cache is used. 92 Cache Cache 93 94 // ClientID used in fetching OAuth2 token as a required field. 95 // If empty, a default client ID is used. 96 // Reference: https://docs.docker.com/registry/spec/auth/oauth/#getting-a-token 97 ClientID string 98 99 // ForceAttemptOAuth2 controls whether to follow OAuth2 with password grant 100 // instead the distribution spec when authenticating using username and 101 // password. 102 // References: 103 // - https://docs.docker.com/registry/spec/auth/jwt/ 104 // - https://docs.docker.com/registry/spec/auth/oauth/ 105 ForceAttemptOAuth2 bool 106 } 107 108 // client returns an HTTP client used to access the remote registry. 109 // http.DefaultClient is return if the client is not configured. 110 func (c *Client) client() *http.Client { 111 if c.Client == nil { 112 return http.DefaultClient 113 } 114 return c.Client 115 } 116 117 // send adds headers to the request and sends the request to the remote server. 118 func (c *Client) send(req *http.Request) (*http.Response, error) { 119 for key, values := range c.Header { 120 req.Header[key] = append(req.Header[key], values...) 121 } 122 return c.client().Do(req) 123 } 124 125 // credential resolves the credential for the given registry. 126 func (c *Client) credential(ctx context.Context, reg string) (Credential, error) { 127 if c.Credential == nil { 128 return EmptyCredential, nil 129 } 130 return c.Credential(ctx, reg) 131 } 132 133 // cache resolves the cache. 134 // noCache is return if the cache is not configured. 135 func (c *Client) cache() Cache { 136 if c.Cache == nil { 137 return noCache{} 138 } 139 return c.Cache 140 } 141 142 // SetUserAgent sets the user agent for all out-going requests. 143 func (c *Client) SetUserAgent(userAgent string) { 144 if c.Header == nil { 145 c.Header = http.Header{} 146 } 147 c.Header.Set("User-Agent", userAgent) 148 } 149 150 // Do sends the request to the remote server, attempting to resolve 151 // authentication if 'Authorization' header is not set. 152 // 153 // On authentication failure due to bad credential, 154 // - Do returns error if it fails to fetch token for bearer auth. 155 // - Do returns the registry response without error for basic auth. 156 func (c *Client) Do(originalReq *http.Request) (*http.Response, error) { 157 if auth := originalReq.Header.Get("Authorization"); auth != "" { 158 return c.send(originalReq) 159 } 160 161 ctx := originalReq.Context() 162 req := originalReq.Clone(ctx) 163 164 // attempt cached auth token 165 var attemptedKey string 166 cache := c.cache() 167 registry := originalReq.Host 168 scheme, err := cache.GetScheme(ctx, registry) 169 if err == nil { 170 switch scheme { 171 case SchemeBasic: 172 token, err := cache.GetToken(ctx, registry, SchemeBasic, "") 173 if err == nil { 174 req.Header.Set("Authorization", "Basic "+token) 175 } 176 case SchemeBearer: 177 scopes := GetScopes(ctx) 178 attemptedKey = strings.Join(scopes, " ") 179 token, err := cache.GetToken(ctx, registry, SchemeBearer, attemptedKey) 180 if err == nil { 181 req.Header.Set("Authorization", "Bearer "+token) 182 } 183 } 184 } 185 186 resp, err := c.send(req) 187 if err != nil { 188 return nil, err 189 } 190 if resp.StatusCode != http.StatusUnauthorized { 191 return resp, nil 192 } 193 194 // attempt again with credentials for recognized schemes 195 challenge := resp.Header.Get("Www-Authenticate") 196 scheme, params := parseChallenge(challenge) 197 switch scheme { 198 case SchemeBasic: 199 resp.Body.Close() 200 201 token, err := cache.Set(ctx, registry, SchemeBasic, "", func(ctx context.Context) (string, error) { 202 return c.fetchBasicAuth(ctx, registry) 203 }) 204 if err != nil { 205 return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err) 206 } 207 208 req = originalReq.Clone(ctx) 209 req.Header.Set("Authorization", "Basic "+token) 210 case SchemeBearer: 211 resp.Body.Close() 212 213 // merge hinted scopes with challenged scopes 214 scopes := GetScopes(ctx) 215 if scope := params["scope"]; scope != "" { 216 scopes = append(scopes, strings.Split(scope, " ")...) 217 scopes = CleanScopes(scopes) 218 } 219 key := strings.Join(scopes, " ") 220 221 // attempt the cache again if there is a scope change 222 if key != attemptedKey { 223 if token, err := cache.GetToken(ctx, registry, SchemeBearer, key); err == nil { 224 req = originalReq.Clone(ctx) 225 req.Header.Set("Authorization", "Bearer "+token) 226 if err := rewindRequestBody(req); err != nil { 227 return nil, err 228 } 229 230 resp, err := c.send(req) 231 if err != nil { 232 return nil, err 233 } 234 if resp.StatusCode != http.StatusUnauthorized { 235 return resp, nil 236 } 237 resp.Body.Close() 238 } 239 } 240 241 // attempt with credentials 242 realm := params["realm"] 243 service := params["service"] 244 token, err := cache.Set(ctx, registry, SchemeBearer, key, func(ctx context.Context) (string, error) { 245 return c.fetchBearerToken(ctx, registry, realm, service, scopes) 246 }) 247 if err != nil { 248 return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err) 249 } 250 251 req = originalReq.Clone(ctx) 252 req.Header.Set("Authorization", "Bearer "+token) 253 default: 254 return resp, nil 255 } 256 if err := rewindRequestBody(req); err != nil { 257 return nil, err 258 } 259 260 return c.send(req) 261 } 262 263 // fetchBasicAuth fetches a basic auth token for the basic challenge. 264 func (c *Client) fetchBasicAuth(ctx context.Context, registry string) (string, error) { 265 cred, err := c.credential(ctx, registry) 266 if err != nil { 267 return "", fmt.Errorf("failed to resolve credential: %w", err) 268 } 269 if cred == EmptyCredential { 270 return "", errors.New("credential required for basic auth") 271 } 272 if cred.Username == "" || cred.Password == "" { 273 return "", errors.New("missing username or password for basic auth") 274 } 275 auth := cred.Username + ":" + cred.Password 276 return base64.StdEncoding.EncodeToString([]byte(auth)), nil 277 } 278 279 // fetchBearerToken fetches an access token for the bearer challenge. 280 func (c *Client) fetchBearerToken(ctx context.Context, registry, realm, service string, scopes []string) (string, error) { 281 cred, err := c.credential(ctx, registry) 282 if err != nil { 283 return "", err 284 } 285 if cred.AccessToken != "" { 286 return cred.AccessToken, nil 287 } 288 if cred == EmptyCredential || (cred.RefreshToken == "" && !c.ForceAttemptOAuth2) { 289 return c.fetchDistributionToken(ctx, realm, service, scopes, cred.Username, cred.Password) 290 } 291 return c.fetchOAuth2Token(ctx, realm, service, scopes, cred) 292 } 293 294 // fetchDistributionToken fetches an access token as defined by the distribution 295 // specification. 296 // It fetches anonymous tokens if no credential is provided. 297 // References: 298 // - https://docs.docker.com/registry/spec/auth/jwt/ 299 // - https://docs.docker.com/registry/spec/auth/token/ 300 func (c *Client) fetchDistributionToken(ctx context.Context, realm, service string, scopes []string, username, password string) (string, error) { 301 req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm, nil) 302 if err != nil { 303 return "", err 304 } 305 if username != "" || password != "" { 306 req.SetBasicAuth(username, password) 307 } 308 q := req.URL.Query() 309 if service != "" { 310 q.Set("service", service) 311 } 312 for _, scope := range scopes { 313 q.Add("scope", scope) 314 } 315 req.URL.RawQuery = q.Encode() 316 317 resp, err := c.send(req) 318 if err != nil { 319 return "", err 320 } 321 defer resp.Body.Close() 322 if resp.StatusCode != http.StatusOK { 323 return "", errutil.ParseErrorResponse(resp) 324 } 325 326 // As specified in https://docs.docker.com/registry/spec/auth/token/ section 327 // "Token Response Fields", the token is either in `token` or 328 // `access_token`. If both present, they are identical. 329 var result struct { 330 Token string `json:"token"` 331 AccessToken string `json:"access_token"` 332 } 333 lr := io.LimitReader(resp.Body, maxResponseBytes) 334 if err := json.NewDecoder(lr).Decode(&result); err != nil { 335 return "", fmt.Errorf("%s %q: failed to decode response: %w", resp.Request.Method, resp.Request.URL, err) 336 } 337 if result.AccessToken != "" { 338 return result.AccessToken, nil 339 } 340 if result.Token != "" { 341 return result.Token, nil 342 } 343 return "", fmt.Errorf("%s %q: empty token returned", resp.Request.Method, resp.Request.URL) 344 } 345 346 // fetchOAuth2Token fetches an OAuth2 access token. 347 // Reference: https://docs.docker.com/registry/spec/auth/oauth/ 348 func (c *Client) fetchOAuth2Token(ctx context.Context, realm, service string, scopes []string, cred Credential) (string, error) { 349 form := url.Values{} 350 if cred.RefreshToken != "" { 351 form.Set("grant_type", "refresh_token") 352 form.Set("refresh_token", cred.RefreshToken) 353 } else if cred.Username != "" && cred.Password != "" { 354 form.Set("grant_type", "password") 355 form.Set("username", cred.Username) 356 form.Set("password", cred.Password) 357 } else { 358 return "", errors.New("missing username or password for bearer auth") 359 } 360 form.Set("service", service) 361 clientID := c.ClientID 362 if clientID == "" { 363 clientID = defaultClientID 364 } 365 form.Set("client_id", clientID) 366 if len(scopes) != 0 { 367 form.Set("scope", strings.Join(scopes, " ")) 368 } 369 body := strings.NewReader(form.Encode()) 370 371 req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm, body) 372 if err != nil { 373 return "", err 374 } 375 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 376 377 resp, err := c.send(req) 378 if err != nil { 379 return "", err 380 } 381 defer resp.Body.Close() 382 if resp.StatusCode != http.StatusOK { 383 return "", errutil.ParseErrorResponse(resp) 384 } 385 386 var result struct { 387 AccessToken string `json:"access_token"` 388 } 389 lr := io.LimitReader(resp.Body, maxResponseBytes) 390 if err := json.NewDecoder(lr).Decode(&result); err != nil { 391 return "", fmt.Errorf("%s %q: failed to decode response: %w", resp.Request.Method, resp.Request.URL, err) 392 } 393 if result.AccessToken != "" { 394 return result.AccessToken, nil 395 } 396 return "", fmt.Errorf("%s %q: empty token returned", resp.Request.Method, resp.Request.URL) 397 } 398 399 // rewindRequestBody tries to rewind the request body if exists. 400 func rewindRequestBody(req *http.Request) error { 401 if req.Body == nil || req.Body == http.NoBody { 402 return nil 403 } 404 if req.GetBody == nil { 405 return fmt.Errorf("%s %q: request body is not rewindable", req.Method, req.URL) 406 } 407 body, err := req.GetBody() 408 if err != nil { 409 return fmt.Errorf("%s %q: failed to get request body: %w", req.Method, req.URL, err) 410 } 411 req.Body = body 412 return nil 413 }