golang.org/x/oauth2@v0.18.0/google/externalaccount/aws.go (about) 1 // Copyright 2021 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package externalaccount 6 7 import ( 8 "bytes" 9 "context" 10 "crypto/hmac" 11 "crypto/sha256" 12 "encoding/hex" 13 "encoding/json" 14 "errors" 15 "fmt" 16 "io" 17 "io/ioutil" 18 "net/http" 19 "net/url" 20 "os" 21 "path" 22 "sort" 23 "strings" 24 "time" 25 26 "golang.org/x/oauth2" 27 ) 28 29 // AwsSecurityCredentials models AWS security credentials. 30 type AwsSecurityCredentials struct { 31 // AccessKeyId is the AWS Access Key ID - Required. 32 AccessKeyID string `json:"AccessKeyID"` 33 // SecretAccessKey is the AWS Secret Access Key - Required. 34 SecretAccessKey string `json:"SecretAccessKey"` 35 // SessionToken is the AWS Session token. This should be provided for temporary AWS security credentials - Optional. 36 SessionToken string `json:"Token"` 37 } 38 39 // awsRequestSigner is a utility class to sign http requests using a AWS V4 signature. 40 type awsRequestSigner struct { 41 RegionName string 42 AwsSecurityCredentials *AwsSecurityCredentials 43 } 44 45 // getenv aliases os.Getenv for testing 46 var getenv = os.Getenv 47 48 const ( 49 defaultRegionalCredentialVerificationUrl = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" 50 51 // AWS Signature Version 4 signing algorithm identifier. 52 awsAlgorithm = "AWS4-HMAC-SHA256" 53 54 // The termination string for the AWS credential scope value as defined in 55 // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html 56 awsRequestType = "aws4_request" 57 58 // The AWS authorization header name for the security session token if available. 59 awsSecurityTokenHeader = "x-amz-security-token" 60 61 // The name of the header containing the session token for metadata endpoint calls 62 awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token" 63 64 awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds" 65 66 awsIMDSv2SessionTtl = "300" 67 68 // The AWS authorization header name for the auto-generated date. 69 awsDateHeader = "x-amz-date" 70 71 // Supported AWS configuration environment variables. 72 awsAccessKeyId = "AWS_ACCESS_KEY_ID" 73 awsDefaultRegion = "AWS_DEFAULT_REGION" 74 awsRegion = "AWS_REGION" 75 awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY" 76 awsSessionToken = "AWS_SESSION_TOKEN" 77 78 awsTimeFormatLong = "20060102T150405Z" 79 awsTimeFormatShort = "20060102" 80 ) 81 82 func getSha256(input []byte) (string, error) { 83 hash := sha256.New() 84 if _, err := hash.Write(input); err != nil { 85 return "", err 86 } 87 return hex.EncodeToString(hash.Sum(nil)), nil 88 } 89 90 func getHmacSha256(key, input []byte) ([]byte, error) { 91 hash := hmac.New(sha256.New, key) 92 if _, err := hash.Write(input); err != nil { 93 return nil, err 94 } 95 return hash.Sum(nil), nil 96 } 97 98 func cloneRequest(r *http.Request) *http.Request { 99 r2 := new(http.Request) 100 *r2 = *r 101 if r.Header != nil { 102 r2.Header = make(http.Header, len(r.Header)) 103 104 // Find total number of values. 105 headerCount := 0 106 for _, headerValues := range r.Header { 107 headerCount += len(headerValues) 108 } 109 copiedHeaders := make([]string, headerCount) // shared backing array for headers' values 110 111 for headerKey, headerValues := range r.Header { 112 headerCount = copy(copiedHeaders, headerValues) 113 r2.Header[headerKey] = copiedHeaders[:headerCount:headerCount] 114 copiedHeaders = copiedHeaders[headerCount:] 115 } 116 } 117 return r2 118 } 119 120 func canonicalPath(req *http.Request) string { 121 result := req.URL.EscapedPath() 122 if result == "" { 123 return "/" 124 } 125 return path.Clean(result) 126 } 127 128 func canonicalQuery(req *http.Request) string { 129 queryValues := req.URL.Query() 130 for queryKey := range queryValues { 131 sort.Strings(queryValues[queryKey]) 132 } 133 return queryValues.Encode() 134 } 135 136 func canonicalHeaders(req *http.Request) (string, string) { 137 // Header keys need to be sorted alphabetically. 138 var headers []string 139 lowerCaseHeaders := make(http.Header) 140 for k, v := range req.Header { 141 k := strings.ToLower(k) 142 if _, ok := lowerCaseHeaders[k]; ok { 143 // include additional values 144 lowerCaseHeaders[k] = append(lowerCaseHeaders[k], v...) 145 } else { 146 headers = append(headers, k) 147 lowerCaseHeaders[k] = v 148 } 149 } 150 sort.Strings(headers) 151 152 var fullHeaders bytes.Buffer 153 for _, header := range headers { 154 headerValue := strings.Join(lowerCaseHeaders[header], ",") 155 fullHeaders.WriteString(header) 156 fullHeaders.WriteRune(':') 157 fullHeaders.WriteString(headerValue) 158 fullHeaders.WriteRune('\n') 159 } 160 161 return strings.Join(headers, ";"), fullHeaders.String() 162 } 163 164 func requestDataHash(req *http.Request) (string, error) { 165 var requestData []byte 166 if req.Body != nil { 167 requestBody, err := req.GetBody() 168 if err != nil { 169 return "", err 170 } 171 defer requestBody.Close() 172 173 requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20)) 174 if err != nil { 175 return "", err 176 } 177 } 178 179 return getSha256(requestData) 180 } 181 182 func requestHost(req *http.Request) string { 183 if req.Host != "" { 184 return req.Host 185 } 186 return req.URL.Host 187 } 188 189 func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) { 190 dataHash, err := requestDataHash(req) 191 if err != nil { 192 return "", err 193 } 194 195 return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil 196 } 197 198 // SignRequest adds the appropriate headers to an http.Request 199 // or returns an error if something prevented this. 200 func (rs *awsRequestSigner) SignRequest(req *http.Request) error { 201 signedRequest := cloneRequest(req) 202 timestamp := now() 203 204 signedRequest.Header.Add("host", requestHost(req)) 205 206 if rs.AwsSecurityCredentials.SessionToken != "" { 207 signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SessionToken) 208 } 209 210 if signedRequest.Header.Get("date") == "" { 211 signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong)) 212 } 213 214 authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp) 215 if err != nil { 216 return err 217 } 218 signedRequest.Header.Set("Authorization", authorizationCode) 219 220 req.Header = signedRequest.Header 221 return nil 222 } 223 224 func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) { 225 canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req) 226 227 dateStamp := timestamp.Format(awsTimeFormatShort) 228 serviceName := "" 229 if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 { 230 serviceName = splitHost[0] 231 } 232 233 credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType) 234 235 requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData) 236 if err != nil { 237 return "", err 238 } 239 requestHash, err := getSha256([]byte(requestString)) 240 if err != nil { 241 return "", err 242 } 243 244 stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash) 245 246 signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey) 247 for _, signingInput := range []string{ 248 dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign, 249 } { 250 signingKey, err = getHmacSha256(signingKey, []byte(signingInput)) 251 if err != nil { 252 return "", err 253 } 254 } 255 256 return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil 257 } 258 259 type awsCredentialSource struct { 260 environmentID string 261 regionURL string 262 regionalCredVerificationURL string 263 credVerificationURL string 264 imdsv2SessionTokenURL string 265 targetResource string 266 requestSigner *awsRequestSigner 267 region string 268 ctx context.Context 269 client *http.Client 270 awsSecurityCredentialsSupplier AwsSecurityCredentialsSupplier 271 supplierOptions SupplierOptions 272 } 273 274 type awsRequestHeader struct { 275 Key string `json:"key"` 276 Value string `json:"value"` 277 } 278 279 type awsRequest struct { 280 URL string `json:"url"` 281 Method string `json:"method"` 282 Headers []awsRequestHeader `json:"headers"` 283 } 284 285 func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) { 286 if cs.client == nil { 287 cs.client = oauth2.NewClient(cs.ctx, nil) 288 } 289 return cs.client.Do(req.WithContext(cs.ctx)) 290 } 291 292 func canRetrieveRegionFromEnvironment() bool { 293 // The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is 294 // required. 295 return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != "" 296 } 297 298 func canRetrieveSecurityCredentialFromEnvironment() bool { 299 // Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available. 300 return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != "" 301 } 302 303 func (cs awsCredentialSource) shouldUseMetadataServer() bool { 304 return cs.awsSecurityCredentialsSupplier == nil && (!canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment()) 305 } 306 307 func (cs awsCredentialSource) credentialSourceType() string { 308 if cs.awsSecurityCredentialsSupplier != nil { 309 return "programmatic" 310 } 311 return "aws" 312 } 313 314 func (cs awsCredentialSource) subjectToken() (string, error) { 315 // Set Defaults 316 if cs.regionalCredVerificationURL == "" { 317 cs.regionalCredVerificationURL = defaultRegionalCredentialVerificationUrl 318 } 319 if cs.requestSigner == nil { 320 headers := make(map[string]string) 321 if cs.shouldUseMetadataServer() { 322 awsSessionToken, err := cs.getAWSSessionToken() 323 if err != nil { 324 return "", err 325 } 326 327 if awsSessionToken != "" { 328 headers[awsIMDSv2SessionTokenHeader] = awsSessionToken 329 } 330 } 331 332 awsSecurityCredentials, err := cs.getSecurityCredentials(headers) 333 if err != nil { 334 return "", err 335 } 336 cs.region, err = cs.getRegion(headers) 337 if err != nil { 338 return "", err 339 } 340 341 cs.requestSigner = &awsRequestSigner{ 342 RegionName: cs.region, 343 AwsSecurityCredentials: awsSecurityCredentials, 344 } 345 } 346 347 // Generate the signed request to AWS STS GetCallerIdentity API. 348 // Use the required regional endpoint. Otherwise, the request will fail. 349 req, err := http.NewRequest("POST", strings.Replace(cs.regionalCredVerificationURL, "{region}", cs.region, 1), nil) 350 if err != nil { 351 return "", err 352 } 353 // The full, canonical resource name of the workload identity pool 354 // provider, with or without the HTTPS prefix. 355 // Including this header as part of the signature is recommended to 356 // ensure data integrity. 357 if cs.targetResource != "" { 358 req.Header.Add("x-goog-cloud-target-resource", cs.targetResource) 359 } 360 cs.requestSigner.SignRequest(req) 361 362 /* 363 The GCP STS endpoint expects the headers to be formatted as: 364 # [ 365 # {key: 'x-amz-date', value: '...'}, 366 # {key: 'Authorization', value: '...'}, 367 # ... 368 # ] 369 # And then serialized as: 370 # quote(json.dumps({ 371 # url: '...', 372 # method: 'POST', 373 # headers: [{key: 'x-amz-date', value: '...'}, ...] 374 # })) 375 */ 376 377 awsSignedReq := awsRequest{ 378 URL: req.URL.String(), 379 Method: "POST", 380 } 381 for headerKey, headerList := range req.Header { 382 for _, headerValue := range headerList { 383 awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{ 384 Key: headerKey, 385 Value: headerValue, 386 }) 387 } 388 } 389 sort.Slice(awsSignedReq.Headers, func(i, j int) bool { 390 headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key) 391 if headerCompare == 0 { 392 return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0 393 } 394 return headerCompare < 0 395 }) 396 397 result, err := json.Marshal(awsSignedReq) 398 if err != nil { 399 return "", err 400 } 401 return url.QueryEscape(string(result)), nil 402 } 403 404 func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { 405 if cs.imdsv2SessionTokenURL == "" { 406 return "", nil 407 } 408 409 req, err := http.NewRequest("PUT", cs.imdsv2SessionTokenURL, nil) 410 if err != nil { 411 return "", err 412 } 413 414 req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl) 415 416 resp, err := cs.doRequest(req) 417 if err != nil { 418 return "", err 419 } 420 defer resp.Body.Close() 421 422 respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) 423 if err != nil { 424 return "", err 425 } 426 427 if resp.StatusCode != 200 { 428 return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS session token - %s", string(respBody)) 429 } 430 431 return string(respBody), nil 432 } 433 434 func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) { 435 if cs.awsSecurityCredentialsSupplier != nil { 436 return cs.awsSecurityCredentialsSupplier.AwsRegion(cs.ctx, cs.supplierOptions) 437 } 438 if canRetrieveRegionFromEnvironment() { 439 if envAwsRegion := getenv(awsRegion); envAwsRegion != "" { 440 cs.region = envAwsRegion 441 return envAwsRegion, nil 442 } 443 return getenv("AWS_DEFAULT_REGION"), nil 444 } 445 446 if cs.regionURL == "" { 447 return "", errors.New("oauth2/google/externalaccount: unable to determine AWS region") 448 } 449 450 req, err := http.NewRequest("GET", cs.regionURL, nil) 451 if err != nil { 452 return "", err 453 } 454 455 for name, value := range headers { 456 req.Header.Add(name, value) 457 } 458 459 resp, err := cs.doRequest(req) 460 if err != nil { 461 return "", err 462 } 463 defer resp.Body.Close() 464 465 respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) 466 if err != nil { 467 return "", err 468 } 469 470 if resp.StatusCode != 200 { 471 return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS region - %s", string(respBody)) 472 } 473 474 // This endpoint will return the region in format: us-east-2b. 475 // Only the us-east-2 part should be used. 476 respBodyEnd := 0 477 if len(respBody) > 1 { 478 respBodyEnd = len(respBody) - 1 479 } 480 return string(respBody[:respBodyEnd]), nil 481 } 482 483 func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result *AwsSecurityCredentials, err error) { 484 if cs.awsSecurityCredentialsSupplier != nil { 485 return cs.awsSecurityCredentialsSupplier.AwsSecurityCredentials(cs.ctx, cs.supplierOptions) 486 } 487 if canRetrieveSecurityCredentialFromEnvironment() { 488 return &AwsSecurityCredentials{ 489 AccessKeyID: getenv(awsAccessKeyId), 490 SecretAccessKey: getenv(awsSecretAccessKey), 491 SessionToken: getenv(awsSessionToken), 492 }, nil 493 } 494 495 roleName, err := cs.getMetadataRoleName(headers) 496 if err != nil { 497 return 498 } 499 500 credentials, err := cs.getMetadataSecurityCredentials(roleName, headers) 501 if err != nil { 502 return 503 } 504 505 if credentials.AccessKeyID == "" { 506 return result, errors.New("oauth2/google/externalaccount: missing AccessKeyId credential") 507 } 508 509 if credentials.SecretAccessKey == "" { 510 return result, errors.New("oauth2/google/externalaccount: missing SecretAccessKey credential") 511 } 512 513 return &credentials, nil 514 } 515 516 func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (AwsSecurityCredentials, error) { 517 var result AwsSecurityCredentials 518 519 req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.credVerificationURL, roleName), nil) 520 if err != nil { 521 return result, err 522 } 523 req.Header.Add("Content-Type", "application/json") 524 525 for name, value := range headers { 526 req.Header.Add(name, value) 527 } 528 529 resp, err := cs.doRequest(req) 530 if err != nil { 531 return result, err 532 } 533 defer resp.Body.Close() 534 535 respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) 536 if err != nil { 537 return result, err 538 } 539 540 if resp.StatusCode != 200 { 541 return result, fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS security credentials - %s", string(respBody)) 542 } 543 544 err = json.Unmarshal(respBody, &result) 545 return result, err 546 } 547 548 func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) { 549 if cs.credVerificationURL == "" { 550 return "", errors.New("oauth2/google/externalaccount: unable to determine the AWS metadata server security credentials endpoint") 551 } 552 553 req, err := http.NewRequest("GET", cs.credVerificationURL, nil) 554 if err != nil { 555 return "", err 556 } 557 558 for name, value := range headers { 559 req.Header.Add(name, value) 560 } 561 562 resp, err := cs.doRequest(req) 563 if err != nil { 564 return "", err 565 } 566 defer resp.Body.Close() 567 568 respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) 569 if err != nil { 570 return "", err 571 } 572 573 if resp.StatusCode != 200 { 574 return "", fmt.Errorf("oauth2/google/externalaccount: unable to retrieve AWS role name - %s", string(respBody)) 575 } 576 577 return string(respBody), nil 578 }