github.com/aavshr/aws-sdk-go@v1.41.3/aws/request/request_test.go (about) 1 package request_test 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "net" 11 "net/http" 12 "net/http/httptest" 13 "net/url" 14 "reflect" 15 "runtime" 16 "strconv" 17 "strings" 18 "testing" 19 "time" 20 21 "github.com/aavshr/aws-sdk-go/aws" 22 "github.com/aavshr/aws-sdk-go/aws/awserr" 23 "github.com/aavshr/aws-sdk-go/aws/corehandlers" 24 "github.com/aavshr/aws-sdk-go/aws/credentials" 25 "github.com/aavshr/aws-sdk-go/aws/request" 26 "github.com/aavshr/aws-sdk-go/awstesting" 27 "github.com/aavshr/aws-sdk-go/awstesting/unit" 28 "github.com/aavshr/aws-sdk-go/private/protocol/rest" 29 ) 30 31 type tempNetworkError struct { 32 op string 33 msg string 34 isTemp bool 35 } 36 37 func (e *tempNetworkError) Temporary() bool { return e.isTemp } 38 func (e *tempNetworkError) Error() string { 39 return fmt.Sprintf("%s: %s", e.op, e.msg) 40 } 41 42 var ( 43 // net.OpError accept, are always temporary 44 errAcceptConnectionResetStub = &tempNetworkError{ 45 isTemp: true, op: "accept", msg: "connection reset", 46 } 47 48 // net.OpError read for ECONNRESET is not temporary. 49 errReadConnectionResetStub = &tempNetworkError{ 50 isTemp: false, op: "read", msg: "connection reset", 51 } 52 53 // net.OpError write for ECONNRESET may not be temporary, but is treaded as 54 // temporary by the SDK. 55 errWriteConnectionResetStub = &tempNetworkError{ 56 isTemp: false, op: "write", msg: "connection reset", 57 } 58 59 // net.OpError write for broken pipe may not be temporary, but is treaded as 60 // temporary by the SDK. 61 errWriteBrokenPipeStub = &tempNetworkError{ 62 isTemp: false, op: "write", msg: "broken pipe", 63 } 64 65 // Generic connection reset error 66 errConnectionResetStub = errors.New("connection reset") 67 68 // use of closed network connection error 69 errUseOfClosedConnectionStub = errors.New("use of closed network connection") 70 ) 71 72 type testData struct { 73 Data string 74 } 75 76 func body(str string) io.ReadCloser { 77 return ioutil.NopCloser(bytes.NewReader([]byte(str))) 78 } 79 80 func unmarshal(req *request.Request) { 81 defer req.HTTPResponse.Body.Close() 82 if req.Data != nil { 83 json.NewDecoder(req.HTTPResponse.Body).Decode(req.Data) 84 } 85 } 86 87 func unmarshalError(req *request.Request) { 88 bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body) 89 if err != nil { 90 req.Error = awserr.New("UnmarshaleError", req.HTTPResponse.Status, err) 91 return 92 } 93 if len(bodyBytes) == 0 { 94 req.Error = awserr.NewRequestFailure( 95 awserr.New("UnmarshaleError", req.HTTPResponse.Status, fmt.Errorf("empty body")), 96 req.HTTPResponse.StatusCode, 97 "", 98 ) 99 return 100 } 101 var jsonErr jsonErrorResponse 102 if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil { 103 req.Error = awserr.New("UnmarshaleError", "JSON unmarshal", err) 104 return 105 } 106 req.Error = awserr.NewRequestFailure( 107 awserr.New(jsonErr.Code, jsonErr.Message, nil), 108 req.HTTPResponse.StatusCode, 109 "", 110 ) 111 } 112 113 type jsonErrorResponse struct { 114 Code string `json:"__type"` 115 Message string `json:"message"` 116 } 117 118 // test that retries occur for 5xx status codes 119 func TestRequestRecoverRetry5xx(t *testing.T) { 120 reqNum := 0 121 reqs := []http.Response{ 122 {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, 123 {StatusCode: 502, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, 124 {StatusCode: 200, Body: body(`{"data":"valid"}`)}, 125 } 126 127 s := awstesting.NewClient(&aws.Config{ 128 MaxRetries: aws.Int(10), 129 SleepDelay: func(time.Duration) {}, 130 }) 131 s.Handlers.Validate.Clear() 132 s.Handlers.Unmarshal.PushBack(unmarshal) 133 s.Handlers.UnmarshalError.PushBack(unmarshalError) 134 s.Handlers.Send.Clear() // mock sending 135 s.Handlers.Send.PushBack(func(r *request.Request) { 136 r.HTTPResponse = &reqs[reqNum] 137 reqNum++ 138 }) 139 out := &testData{} 140 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) 141 err := r.Send() 142 if err != nil { 143 t.Fatalf("expect no error, but got %v", err) 144 } 145 if e, a := 2, r.RetryCount; e != a { 146 t.Errorf("expect %d retry count, got %d", e, a) 147 } 148 if e, a := "valid", out.Data; e != a { 149 t.Errorf("expect %q output got %q", e, a) 150 } 151 } 152 153 // test that retries occur for 4xx status codes with a response type that can be retried - see `shouldRetry` 154 func TestRequestRecoverRetry4xxRetryable(t *testing.T) { 155 reqNum := 0 156 reqs := []http.Response{ 157 {StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)}, 158 {StatusCode: 400, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)}, 159 {StatusCode: 429, Body: body(`{"__type":"FooException","message":"Rate exceeded."}`)}, 160 {StatusCode: 200, Body: body(`{"data":"valid"}`)}, 161 } 162 163 s := awstesting.NewClient(&aws.Config{ 164 MaxRetries: aws.Int(10), 165 SleepDelay: func(time.Duration) {}, 166 }) 167 s.Handlers.Validate.Clear() 168 s.Handlers.Unmarshal.PushBack(unmarshal) 169 s.Handlers.UnmarshalError.PushBack(unmarshalError) 170 s.Handlers.Send.Clear() // mock sending 171 s.Handlers.Send.PushBack(func(r *request.Request) { 172 r.HTTPResponse = &reqs[reqNum] 173 reqNum++ 174 }) 175 out := &testData{} 176 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) 177 err := r.Send() 178 if err != nil { 179 t.Fatalf("expect no error, but got %v", err) 180 } 181 if e, a := 3, r.RetryCount; e != a { 182 t.Errorf("expect %d retry count, got %d", e, a) 183 } 184 if e, a := "valid", out.Data; e != a { 185 t.Errorf("expect %q output got %q", e, a) 186 } 187 } 188 189 // test that retries don't occur for 4xx status codes with a response type that can't be retried 190 func TestRequest4xxUnretryable(t *testing.T) { 191 s := awstesting.NewClient(&aws.Config{ 192 MaxRetries: aws.Int(1), 193 SleepDelay: func(time.Duration) {}, 194 }) 195 s.Handlers.Validate.Clear() 196 s.Handlers.Unmarshal.PushBack(unmarshal) 197 s.Handlers.UnmarshalError.PushBack(unmarshalError) 198 s.Handlers.Send.Clear() // mock sending 199 s.Handlers.Send.PushBack(func(r *request.Request) { 200 r.HTTPResponse = &http.Response{ 201 StatusCode: 401, 202 Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`), 203 } 204 }) 205 out := &testData{} 206 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) 207 err := r.Send() 208 if err == nil { 209 t.Fatalf("expect error, but did not get one") 210 } 211 aerr := err.(awserr.RequestFailure) 212 if e, a := 401, aerr.StatusCode(); e != a { 213 t.Errorf("expect %d status code, got %d", e, a) 214 } 215 if e, a := "SignatureDoesNotMatch", aerr.Code(); e != a { 216 t.Errorf("expect %q error code, got %q", e, a) 217 } 218 if e, a := "Signature does not match.", aerr.Message(); e != a { 219 t.Errorf("expect %q error message, got %q", e, a) 220 } 221 if e, a := 0, r.RetryCount; e != a { 222 t.Errorf("expect %d retry count, got %d", e, a) 223 } 224 } 225 226 func TestRequestExhaustRetries(t *testing.T) { 227 delays := []time.Duration{} 228 sleepDelay := func(delay time.Duration) { 229 delays = append(delays, delay) 230 } 231 232 reqNum := 0 233 reqs := []http.Response{ 234 {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, 235 {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, 236 {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, 237 {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, 238 } 239 240 s := awstesting.NewClient(&aws.Config{ 241 SleepDelay: sleepDelay, 242 }) 243 s.Handlers.Validate.Clear() 244 s.Handlers.Unmarshal.PushBack(unmarshal) 245 s.Handlers.UnmarshalError.PushBack(unmarshalError) 246 s.Handlers.Send.Clear() // mock sending 247 s.Handlers.Send.PushBack(func(r *request.Request) { 248 r.HTTPResponse = &reqs[reqNum] 249 reqNum++ 250 }) 251 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 252 err := r.Send() 253 if err == nil { 254 t.Fatalf("expect error, but did not get one") 255 } 256 aerr := err.(awserr.RequestFailure) 257 if e, a := 500, aerr.StatusCode(); e != a { 258 t.Errorf("expect %d status code, got %d", e, a) 259 } 260 if e, a := "UnknownError", aerr.Code(); e != a { 261 t.Errorf("expect %q error code, got %q", e, a) 262 } 263 if e, a := "An error occurred.", aerr.Message(); e != a { 264 t.Errorf("expect %q error message, got %q", e, a) 265 } 266 if e, a := 3, r.RetryCount; e != a { 267 t.Errorf("expect %d retry count, got %d", e, a) 268 } 269 270 expectDelays := []struct{ min, max time.Duration }{{30, 60}, {60, 120}, {120, 240}} 271 for i, v := range delays { 272 min := expectDelays[i].min * time.Millisecond 273 max := expectDelays[i].max * time.Millisecond 274 if !(min <= v && v <= max) { 275 t.Errorf("Expect delay to be within range, i:%d, v:%s, min:%s, max:%s", 276 i, v, min, max) 277 } 278 } 279 } 280 281 // test that the request is retried after the credentials are expired. 282 func TestRequestRecoverExpiredCreds(t *testing.T) { 283 reqNum := 0 284 reqs := []http.Response{ 285 {StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)}, 286 {StatusCode: 200, Body: body(`{"data":"valid"}`)}, 287 } 288 289 s := awstesting.NewClient(&aws.Config{ 290 MaxRetries: aws.Int(10), 291 Credentials: credentials.NewStaticCredentials("AKID", "SECRET", ""), 292 SleepDelay: func(time.Duration) {}, 293 }) 294 s.Handlers.Validate.Clear() 295 s.Handlers.Unmarshal.PushBack(unmarshal) 296 s.Handlers.UnmarshalError.PushBack(unmarshalError) 297 298 credExpiredBeforeRetry := false 299 credExpiredAfterRetry := false 300 301 s.Handlers.AfterRetry.PushBack(func(r *request.Request) { 302 credExpiredAfterRetry = r.Config.Credentials.IsExpired() 303 }) 304 305 s.Handlers.Sign.Clear() 306 s.Handlers.Sign.PushBack(func(r *request.Request) { 307 r.Config.Credentials.Get() 308 }) 309 s.Handlers.Send.Clear() // mock sending 310 s.Handlers.Send.PushBack(func(r *request.Request) { 311 r.HTTPResponse = &reqs[reqNum] 312 reqNum++ 313 }) 314 out := &testData{} 315 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) 316 err := r.Send() 317 if err != nil { 318 t.Fatalf("expect no error, got %v", err) 319 } 320 321 if credExpiredBeforeRetry { 322 t.Errorf("Expect valid creds before retry check") 323 } 324 if !credExpiredAfterRetry { 325 t.Errorf("Expect expired creds after retry check") 326 } 327 if s.Config.Credentials.IsExpired() { 328 t.Errorf("Expect valid creds after cred expired recovery") 329 } 330 331 if e, a := 1, r.RetryCount; e != a { 332 t.Errorf("expect %d retry count, got %d", e, a) 333 } 334 if e, a := "valid", out.Data; e != a { 335 t.Errorf("expect %q output got %q", e, a) 336 } 337 } 338 339 func TestMakeAddtoUserAgentHandler(t *testing.T) { 340 fn := request.MakeAddToUserAgentHandler("name", "version", "extra1", "extra2") 341 r := &request.Request{HTTPRequest: &http.Request{Header: http.Header{}}} 342 r.HTTPRequest.Header.Set("User-Agent", "foo/bar") 343 fn(r) 344 345 if e, a := "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"); !strings.HasPrefix(a, e) { 346 t.Errorf("expect %q user agent, got %q", e, a) 347 } 348 } 349 350 func TestMakeAddtoUserAgentFreeFormHandler(t *testing.T) { 351 fn := request.MakeAddToUserAgentFreeFormHandler("name/version (extra1; extra2)") 352 r := &request.Request{HTTPRequest: &http.Request{Header: http.Header{}}} 353 r.HTTPRequest.Header.Set("User-Agent", "foo/bar") 354 fn(r) 355 356 if e, a := "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"); !strings.HasPrefix(a, e) { 357 t.Errorf("expect %q user agent, got %q", e, a) 358 } 359 } 360 361 func TestRequestUserAgent(t *testing.T) { 362 s := awstesting.NewClient(&aws.Config{ 363 Region: aws.String("us-east-1"), 364 }) 365 366 req := s.NewRequest(&request.Operation{Name: "Operation"}, nil, &testData{}) 367 req.HTTPRequest.Header.Set("User-Agent", "foo/bar") 368 if err := req.Build(); err != nil { 369 t.Fatalf("expect no error, got %v", err) 370 } 371 372 expectUA := fmt.Sprintf("foo/bar %s/%s (%s; %s; %s)", 373 aws.SDKName, aws.SDKVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) 374 if e, a := expectUA, req.HTTPRequest.Header.Get("User-Agent"); !strings.HasPrefix(a, e) { 375 t.Errorf("expect %q user agent, got %q", e, a) 376 } 377 } 378 379 func TestRequestThrottleRetries(t *testing.T) { 380 var delays []time.Duration 381 sleepDelay := func(delay time.Duration) { 382 delays = append(delays, delay) 383 } 384 385 reqNum := 0 386 reqs := []http.Response{ 387 {StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)}, 388 {StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)}, 389 {StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)}, 390 {StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)}, 391 } 392 393 s := awstesting.NewClient(&aws.Config{ 394 SleepDelay: sleepDelay, 395 }) 396 s.Handlers.Validate.Clear() 397 s.Handlers.Unmarshal.PushBack(unmarshal) 398 s.Handlers.UnmarshalError.PushBack(unmarshalError) 399 s.Handlers.Send.Clear() // mock sending 400 s.Handlers.Send.PushBack(func(r *request.Request) { 401 r.HTTPResponse = &reqs[reqNum] 402 reqNum++ 403 }) 404 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) 405 err := r.Send() 406 if err == nil { 407 t.Fatalf("expect error, but did not get one") 408 } 409 aerr := err.(awserr.RequestFailure) 410 if e, a := 500, aerr.StatusCode(); e != a { 411 t.Errorf("expect %d status code, got %d", e, a) 412 } 413 if e, a := "Throttling", aerr.Code(); e != a { 414 t.Errorf("expect %q error code, got %q", e, a) 415 } 416 if e, a := "An error occurred.", aerr.Message(); e != a { 417 t.Errorf("expect %q error message, got %q", e, a) 418 } 419 if e, a := 3, r.RetryCount; e != a { 420 t.Errorf("expect %d retry count, got %d", e, a) 421 } 422 423 expectDelays := []struct{ min, max time.Duration }{{500, 1000}, {1000, 2000}, {2000, 4000}} 424 for i, v := range delays { 425 min := expectDelays[i].min * time.Millisecond 426 max := expectDelays[i].max * time.Millisecond 427 if !(min <= v && v <= max) { 428 t.Errorf("Expect delay to be within range, i:%d, v:%s, min:%s, max:%s", 429 i, v, min, max) 430 } 431 } 432 } 433 434 // test that retries occur for request timeouts when response.Body can be nil 435 func TestRequestRecoverTimeoutWithNilBody(t *testing.T) { 436 reqNum := 0 437 reqs := []*http.Response{ 438 {StatusCode: 0, Body: nil}, // body can be nil when requests time out 439 {StatusCode: 200, Body: body(`{"data":"valid"}`)}, 440 } 441 errors := []error{ 442 errTimeout, nil, 443 } 444 445 s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10)) 446 s.Handlers.Validate.Clear() 447 s.Handlers.Unmarshal.PushBack(unmarshal) 448 s.Handlers.UnmarshalError.PushBack(unmarshalError) 449 s.Handlers.AfterRetry.Clear() // force retry on all errors 450 s.Handlers.AfterRetry.PushBack(func(r *request.Request) { 451 if r.Error != nil { 452 r.Error = nil 453 r.Retryable = aws.Bool(true) 454 r.RetryCount++ 455 } 456 }) 457 s.Handlers.Send.Clear() // mock sending 458 s.Handlers.Send.PushBack(func(r *request.Request) { 459 r.HTTPResponse = reqs[reqNum] 460 r.Error = errors[reqNum] 461 reqNum++ 462 }) 463 out := &testData{} 464 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) 465 err := r.Send() 466 if err != nil { 467 t.Fatalf("expect no error, but got %v", err) 468 } 469 if e, a := 1, r.RetryCount; e != a { 470 t.Errorf("expect %d retry count, got %d", e, a) 471 } 472 if e, a := "valid", out.Data; e != a { 473 t.Errorf("expect %q output got %q", e, a) 474 } 475 } 476 477 func TestRequestRecoverTimeoutWithNilResponse(t *testing.T) { 478 reqNum := 0 479 reqs := []*http.Response{ 480 nil, 481 {StatusCode: 200, Body: body(`{"data":"valid"}`)}, 482 } 483 errors := []error{ 484 errTimeout, 485 nil, 486 } 487 488 s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10)) 489 s.Handlers.Validate.Clear() 490 s.Handlers.Unmarshal.PushBack(unmarshal) 491 s.Handlers.UnmarshalError.PushBack(unmarshalError) 492 s.Handlers.AfterRetry.Clear() // force retry on all errors 493 s.Handlers.AfterRetry.PushBack(func(r *request.Request) { 494 if r.Error != nil { 495 r.Error = nil 496 r.Retryable = aws.Bool(true) 497 r.RetryCount++ 498 } 499 }) 500 s.Handlers.Send.Clear() // mock sending 501 s.Handlers.Send.PushBack(func(r *request.Request) { 502 r.HTTPResponse = reqs[reqNum] 503 r.Error = errors[reqNum] 504 reqNum++ 505 }) 506 out := &testData{} 507 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) 508 err := r.Send() 509 if err != nil { 510 t.Fatalf("expect no error, but got %v", err) 511 } 512 if e, a := 1, r.RetryCount; e != a { 513 t.Errorf("expect %d retry count, got %d", e, a) 514 } 515 if e, a := "valid", out.Data; e != a { 516 t.Errorf("expect %q output got %q", e, a) 517 } 518 } 519 520 func TestRequest_NoBody(t *testing.T) { 521 cases := []string{ 522 "GET", "HEAD", "DELETE", 523 "PUT", "POST", "PATCH", 524 } 525 526 for i, c := range cases { 527 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 528 if v := r.TransferEncoding; len(v) > 0 { 529 t.Errorf("%d, expect no body sent with Transfer-Encoding, %v", i, v) 530 } 531 532 outMsg := []byte(`{"Value": "abc"}`) 533 534 if b, err := ioutil.ReadAll(r.Body); err != nil { 535 t.Fatalf("%d, expect no error reading request body, got %v", i, err) 536 } else if n := len(b); n > 0 { 537 t.Errorf("%d, expect no request body, got %d bytes", i, n) 538 } 539 540 w.Header().Set("Content-Length", strconv.Itoa(len(outMsg))) 541 if _, err := w.Write(outMsg); err != nil { 542 t.Fatalf("%d, expect no error writing server response, got %v", i, err) 543 } 544 })) 545 546 s := awstesting.NewClient(&aws.Config{ 547 Region: aws.String("mock-region"), 548 MaxRetries: aws.Int(0), 549 Endpoint: aws.String(server.URL), 550 DisableSSL: aws.Bool(true), 551 }) 552 s.Handlers.Build.PushBack(rest.Build) 553 s.Handlers.Validate.Clear() 554 s.Handlers.Unmarshal.PushBack(unmarshal) 555 s.Handlers.UnmarshalError.PushBack(unmarshalError) 556 557 in := struct { 558 Bucket *string `location:"uri" locationName:"bucket"` 559 Key *string `location:"uri" locationName:"key"` 560 }{ 561 Bucket: aws.String("mybucket"), Key: aws.String("myKey"), 562 } 563 564 out := struct { 565 Value *string 566 }{} 567 568 r := s.NewRequest(&request.Operation{ 569 Name: "OpName", HTTPMethod: c, HTTPPath: "/{bucket}/{key+}", 570 }, &in, &out) 571 572 err := r.Send() 573 server.Close() 574 if err != nil { 575 t.Fatalf("%d, expect no error sending request, got %v", i, err) 576 } 577 } 578 } 579 580 func TestIsSerializationErrorRetryable(t *testing.T) { 581 testCases := []struct { 582 err error 583 expected bool 584 }{ 585 { 586 err: awserr.New(request.ErrCodeSerialization, "foo error", nil), 587 expected: false, 588 }, 589 { 590 err: awserr.New("ErrFoo", "foo error", nil), 591 expected: false, 592 }, 593 { 594 err: nil, 595 expected: false, 596 }, 597 { 598 err: awserr.New(request.ErrCodeSerialization, "foo error", errAcceptConnectionResetStub), 599 expected: true, 600 }, 601 } 602 603 for i, c := range testCases { 604 r := &request.Request{ 605 Error: c.err, 606 } 607 if r.IsErrorRetryable() != c.expected { 608 t.Errorf("Case %d: Expected %v, but received %v", i, c.expected, !c.expected) 609 } 610 } 611 } 612 613 func TestWithLogLevel(t *testing.T) { 614 r := &request.Request{} 615 616 opt := request.WithLogLevel(aws.LogDebugWithHTTPBody) 617 r.ApplyOptions(opt) 618 619 if !r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) { 620 t.Errorf("expect log level to be set, but was not, %v", 621 r.Config.LogLevel.Value()) 622 } 623 } 624 625 func TestWithGetResponseHeader(t *testing.T) { 626 r := &request.Request{} 627 628 var val, val2 string 629 r.ApplyOptions( 630 request.WithGetResponseHeader("x-a-header", &val), 631 request.WithGetResponseHeader("x-second-header", &val2), 632 ) 633 634 r.HTTPResponse = &http.Response{ 635 Header: func() http.Header { 636 h := http.Header{} 637 h.Set("x-a-header", "first") 638 h.Set("x-second-header", "second") 639 return h 640 }(), 641 } 642 r.Handlers.Complete.Run(r) 643 644 if e, a := "first", val; e != a { 645 t.Errorf("expect %q header value got %q", e, a) 646 } 647 if e, a := "second", val2; e != a { 648 t.Errorf("expect %q header value got %q", e, a) 649 } 650 } 651 652 func TestWithGetResponseHeaders(t *testing.T) { 653 r := &request.Request{} 654 655 var headers http.Header 656 opt := request.WithGetResponseHeaders(&headers) 657 658 r.ApplyOptions(opt) 659 660 r.HTTPResponse = &http.Response{ 661 Header: func() http.Header { 662 h := http.Header{} 663 h.Set("x-a-header", "headerValue") 664 return h 665 }(), 666 } 667 r.Handlers.Complete.Run(r) 668 669 if e, a := "headerValue", headers.Get("x-a-header"); e != a { 670 t.Errorf("expect %q header value got %q", e, a) 671 } 672 } 673 674 type testRetryer struct { 675 shouldRetry bool 676 maxRetries int 677 } 678 679 func (d *testRetryer) MaxRetries() int { 680 return d.maxRetries 681 } 682 683 // RetryRules returns the delay duration before retrying this request again 684 func (d *testRetryer) RetryRules(r *request.Request) time.Duration { 685 return 0 686 } 687 688 func (d *testRetryer) ShouldRetry(r *request.Request) bool { 689 return d.shouldRetry 690 } 691 692 func TestEnforceShouldRetryCheck(t *testing.T) { 693 694 retryer := &testRetryer{ 695 shouldRetry: true, maxRetries: 3, 696 } 697 s := awstesting.NewClient(&aws.Config{ 698 Region: aws.String("mock-region"), 699 MaxRetries: aws.Int(0), 700 Retryer: retryer, 701 EnforceShouldRetryCheck: aws.Bool(true), 702 SleepDelay: func(time.Duration) {}, 703 }) 704 705 s.Handlers.Validate.Clear() 706 s.Handlers.Send.Swap(corehandlers.SendHandler.Name, request.NamedHandler{ 707 Name: "TestEnforceShouldRetryCheck", 708 Fn: func(r *request.Request) { 709 r.HTTPResponse = &http.Response{ 710 Header: http.Header{}, 711 Body: ioutil.NopCloser(bytes.NewBuffer(nil)), 712 } 713 r.Retryable = aws.Bool(false) 714 }, 715 }) 716 717 s.Handlers.Unmarshal.PushBack(unmarshal) 718 s.Handlers.UnmarshalError.PushBack(unmarshalError) 719 720 out := &testData{} 721 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) 722 err := r.Send() 723 if err == nil { 724 t.Fatalf("expect error, but got nil") 725 } 726 if e, a := 3, r.RetryCount; e != a { 727 t.Errorf("expect %d retry count, got %d", e, a) 728 } 729 if !retryer.shouldRetry { 730 t.Errorf("expect 'true' for ShouldRetry, but got %v", retryer.shouldRetry) 731 } 732 } 733 734 type errReader struct { 735 err error 736 } 737 738 func (reader *errReader) Read(b []byte) (int, error) { 739 return 0, reader.err 740 } 741 742 func (reader *errReader) Close() error { 743 return nil 744 } 745 746 func TestIsNoBodyReader(t *testing.T) { 747 cases := []struct { 748 reader io.ReadCloser 749 expect bool 750 }{ 751 {ioutil.NopCloser(bytes.NewReader([]byte("abc"))), false}, 752 {ioutil.NopCloser(bytes.NewReader(nil)), false}, 753 {nil, false}, 754 {request.NoBody, true}, 755 } 756 757 for i, c := range cases { 758 if e, a := c.expect, request.NoBody == c.reader; e != a { 759 t.Errorf("%d, expect %t match, but was %t", i, e, a) 760 } 761 } 762 } 763 764 func TestRequest_TemporaryRetry(t *testing.T) { 765 done := make(chan struct{}) 766 767 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 768 w.Header().Set("Content-Length", "1024") 769 w.WriteHeader(http.StatusOK) 770 771 w.Write(make([]byte, 100)) 772 773 f := w.(http.Flusher) 774 f.Flush() 775 776 <-done 777 })) 778 defer server.Close() 779 780 client := &http.Client{ 781 Timeout: 100 * time.Millisecond, 782 } 783 784 svc := awstesting.NewClient(&aws.Config{ 785 Region: unit.Session.Config.Region, 786 MaxRetries: aws.Int(1), 787 HTTPClient: client, 788 DisableSSL: aws.Bool(true), 789 Endpoint: aws.String(server.URL), 790 }) 791 792 req := svc.NewRequest(&request.Operation{ 793 Name: "name", HTTPMethod: "GET", HTTPPath: "/path", 794 }, &struct{}{}, &struct{}{}) 795 796 req.Handlers.Unmarshal.PushBack(func(r *request.Request) { 797 defer req.HTTPResponse.Body.Close() 798 _, err := io.Copy(ioutil.Discard, req.HTTPResponse.Body) 799 r.Error = awserr.New(request.ErrCodeSerialization, "error", err) 800 }) 801 802 err := req.Send() 803 if err == nil { 804 t.Errorf("expect error, got none") 805 } 806 close(done) 807 808 aerr := err.(awserr.Error) 809 if e, a := request.ErrCodeSerialization, aerr.Code(); e != a { 810 t.Errorf("expect %q error code, got %q", e, a) 811 } 812 813 if e, a := 1, req.RetryCount; e != a { 814 t.Errorf("expect %d retries, got %d", e, a) 815 } 816 817 type temporary interface { 818 Temporary() bool 819 } 820 821 terr := aerr.OrigErr().(temporary) 822 if !terr.Temporary() { 823 t.Errorf("expect temporary error, was not") 824 } 825 } 826 827 func TestRequest_Presign(t *testing.T) { 828 presign := func(r *request.Request, expire time.Duration) (string, http.Header, error) { 829 u, err := r.Presign(expire) 830 return u, nil, err 831 } 832 presignRequest := func(r *request.Request, expire time.Duration) (string, http.Header, error) { 833 return r.PresignRequest(expire) 834 } 835 mustParseURL := func(v string) *url.URL { 836 u, err := url.Parse(v) 837 if err != nil { 838 panic(err) 839 } 840 return u 841 } 842 843 cases := []struct { 844 Expire time.Duration 845 PresignFn func(*request.Request, time.Duration) (string, http.Header, error) 846 SignerFn func(*request.Request) 847 URL string 848 Header http.Header 849 Err string 850 }{ 851 { 852 PresignFn: presign, 853 Err: request.ErrCodeInvalidPresignExpire, 854 }, 855 { 856 PresignFn: presignRequest, 857 Err: request.ErrCodeInvalidPresignExpire, 858 }, 859 { 860 Expire: -1, 861 PresignFn: presign, 862 Err: request.ErrCodeInvalidPresignExpire, 863 }, 864 { 865 // Presign clear NotHoist 866 Expire: 1 * time.Minute, 867 PresignFn: func(r *request.Request, dur time.Duration) (string, http.Header, error) { 868 r.NotHoist = true 869 return presign(r, dur) 870 }, 871 SignerFn: func(r *request.Request) { 872 r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL") 873 if r.NotHoist { 874 r.Error = fmt.Errorf("expect NotHoist to be cleared") 875 } 876 }, 877 URL: "https://endpoint/presignedURL", 878 }, 879 { 880 // PresignRequest does not clear NotHoist 881 Expire: 1 * time.Minute, 882 PresignFn: func(r *request.Request, dur time.Duration) (string, http.Header, error) { 883 r.NotHoist = true 884 return presignRequest(r, dur) 885 }, 886 SignerFn: func(r *request.Request) { 887 r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL") 888 if !r.NotHoist { 889 r.Error = fmt.Errorf("expect NotHoist not to be cleared") 890 } 891 }, 892 URL: "https://endpoint/presignedURL", 893 }, 894 { 895 // PresignRequest returns signed headers 896 Expire: 1 * time.Minute, 897 PresignFn: presignRequest, 898 SignerFn: func(r *request.Request) { 899 r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL") 900 r.HTTPRequest.Header.Set("UnsigndHeader", "abc") 901 r.SignedHeaderVals = http.Header{ 902 "X-Amzn-Header": []string{"abc", "123"}, 903 "X-Amzn-Header2": []string{"efg", "456"}, 904 } 905 }, 906 URL: "https://endpoint/presignedURL", 907 Header: http.Header{ 908 "X-Amzn-Header": []string{"abc", "123"}, 909 "X-Amzn-Header2": []string{"efg", "456"}, 910 }, 911 }, 912 } 913 914 svc := awstesting.NewClient() 915 svc.Handlers.Clear() 916 for i, c := range cases { 917 req := svc.NewRequest(&request.Operation{ 918 Name: "name", HTTPMethod: "GET", HTTPPath: "/path", 919 }, &struct{}{}, &struct{}{}) 920 req.Handlers.Sign.PushBack(c.SignerFn) 921 922 u, h, err := c.PresignFn(req, c.Expire) 923 if len(c.Err) != 0 { 924 if e, a := c.Err, err.Error(); !strings.Contains(a, e) { 925 t.Errorf("%d, expect %v to be in %v", i, e, a) 926 } 927 continue 928 } else if err != nil { 929 t.Errorf("%d, expect no error, got %v", i, err) 930 continue 931 } 932 if e, a := c.URL, u; e != a { 933 t.Errorf("%d, expect %v URL, got %v", i, e, a) 934 } 935 if e, a := c.Header, h; !reflect.DeepEqual(e, a) { 936 t.Errorf("%d, expect %v header got %v", i, e, a) 937 } 938 } 939 } 940 941 func TestSanitizeHostForHeader(t *testing.T) { 942 cases := []struct { 943 url string 944 expectedRequestHost string 945 }{ 946 {"https://estest.us-east-1.es.amazonaws.com:443", "estest.us-east-1.es.amazonaws.com"}, 947 {"https://estest.us-east-1.es.amazonaws.com", "estest.us-east-1.es.amazonaws.com"}, 948 {"https://localhost:9200", "localhost:9200"}, 949 {"http://localhost:80", "localhost"}, 950 {"http://localhost:8080", "localhost:8080"}, 951 } 952 953 for _, c := range cases { 954 r, _ := http.NewRequest("GET", c.url, nil) 955 request.SanitizeHostForHeader(r) 956 957 if h := r.Host; h != c.expectedRequestHost { 958 t.Errorf("expect %v host, got %q", c.expectedRequestHost, h) 959 } 960 } 961 } 962 963 func TestRequestWillRetry_ByBody(t *testing.T) { 964 svc := awstesting.NewClient() 965 966 cases := []struct { 967 WillRetry bool 968 HTTPMethod string 969 Body io.ReadSeeker 970 IsReqNoBody bool 971 }{ 972 { 973 WillRetry: true, 974 HTTPMethod: "GET", 975 Body: bytes.NewReader([]byte{}), 976 IsReqNoBody: true, 977 }, 978 { 979 WillRetry: true, 980 HTTPMethod: "GET", 981 Body: bytes.NewReader(nil), 982 IsReqNoBody: true, 983 }, 984 { 985 WillRetry: true, 986 HTTPMethod: "POST", 987 Body: bytes.NewReader([]byte("abc123")), 988 }, 989 { 990 WillRetry: true, 991 HTTPMethod: "POST", 992 Body: aws.ReadSeekCloser(bytes.NewReader([]byte("abc123"))), 993 }, 994 { 995 WillRetry: true, 996 HTTPMethod: "GET", 997 Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)), 998 IsReqNoBody: true, 999 }, 1000 { 1001 WillRetry: true, 1002 HTTPMethod: "POST", 1003 Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)), 1004 IsReqNoBody: true, 1005 }, 1006 { 1007 WillRetry: false, 1008 HTTPMethod: "POST", 1009 Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc123"))), 1010 }, 1011 } 1012 1013 for i, c := range cases { 1014 req := svc.NewRequest(&request.Operation{ 1015 Name: "Operation", 1016 HTTPMethod: c.HTTPMethod, 1017 HTTPPath: "/", 1018 }, nil, nil) 1019 req.SetReaderBody(c.Body) 1020 req.Build() 1021 1022 req.Error = fmt.Errorf("some error") 1023 req.Retryable = aws.Bool(true) 1024 req.HTTPResponse = &http.Response{ 1025 StatusCode: 500, 1026 } 1027 1028 if e, a := c.IsReqNoBody, request.NoBody == req.HTTPRequest.Body; e != a { 1029 t.Errorf("%d, expect request to be no body, %t, got %t, %T", i, e, a, req.HTTPRequest.Body) 1030 } 1031 1032 if e, a := c.WillRetry, req.WillRetry(); e != a { 1033 t.Errorf("%d, expect %t willRetry, got %t", i, e, a) 1034 } 1035 1036 if req.Error == nil { 1037 t.Fatalf("%d, expect error, got none", i) 1038 } 1039 if e, a := "some error", req.Error.Error(); !strings.Contains(a, e) { 1040 t.Errorf("%d, expect %q error in %q", i, e, a) 1041 } 1042 if e, a := 0, req.RetryCount; e != a { 1043 t.Errorf("%d, expect retry count to be %d, got %d", i, e, a) 1044 } 1045 } 1046 } 1047 1048 func Test501NotRetrying(t *testing.T) { 1049 reqNum := 0 1050 reqs := []http.Response{ 1051 {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, 1052 {StatusCode: 501, Body: body(`{"__type":"NotImplemented","message":"An error occurred."}`)}, 1053 {StatusCode: 200, Body: body(`{"data":"valid"}`)}, 1054 } 1055 1056 s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10)) 1057 s.Handlers.Validate.Clear() 1058 s.Handlers.Unmarshal.PushBack(unmarshal) 1059 s.Handlers.UnmarshalError.PushBack(unmarshalError) 1060 s.Handlers.Send.Clear() // mock sending 1061 s.Handlers.Send.PushBack(func(r *request.Request) { 1062 r.HTTPResponse = &reqs[reqNum] 1063 reqNum++ 1064 }) 1065 out := &testData{} 1066 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) 1067 err := r.Send() 1068 if err == nil { 1069 t.Fatal("expect error, but got none") 1070 } 1071 1072 aerr := err.(awserr.Error) 1073 if e, a := "NotImplemented", aerr.Code(); e != a { 1074 t.Errorf("expected error code %q, but received %q", e, a) 1075 } 1076 if e, a := 1, r.RetryCount; e != a { 1077 t.Errorf("expect %d retry count, got %d", e, a) 1078 } 1079 } 1080 1081 func TestRequestNoConnection(t *testing.T) { 1082 port, err := getFreePort() 1083 if err != nil { 1084 t.Fatalf("failed to get free port for test") 1085 } 1086 s := awstesting.NewClient(aws.NewConfig(). 1087 WithMaxRetries(10). 1088 WithEndpoint("https://localhost:" + strconv.Itoa(port)). 1089 WithSleepDelay(func(time.Duration) {}), 1090 ) 1091 s.Handlers.Validate.Clear() 1092 s.Handlers.Unmarshal.PushBack(unmarshal) 1093 s.Handlers.UnmarshalError.PushBack(unmarshalError) 1094 1095 out := &testData{} 1096 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) 1097 1098 if err = r.Send(); err == nil { 1099 t.Fatal("expect error, but got none") 1100 } 1101 1102 t.Logf("Error, %v", err) 1103 awsError := err.(awserr.Error) 1104 origError := awsError.OrigErr() 1105 t.Logf("Orig Error: %#v of type %T", origError, origError) 1106 1107 if e, a := 10, r.RetryCount; e != a { 1108 t.Errorf("expect %v retry count, got %v", e, a) 1109 } 1110 } 1111 1112 func TestRequestBodySeekFails(t *testing.T) { 1113 s := awstesting.NewClient() 1114 s.Handlers.Validate.Clear() 1115 s.Handlers.Build.Clear() 1116 1117 out := &testData{} 1118 r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) 1119 r.SetReaderBody(&stubSeekFail{ 1120 Err: fmt.Errorf("failed to seek reader"), 1121 }) 1122 err := r.Send() 1123 if err == nil { 1124 t.Fatal("expect error, but got none") 1125 } 1126 1127 aerr := err.(awserr.Error) 1128 if e, a := request.ErrCodeSerialization, aerr.Code(); e != a { 1129 t.Errorf("expect %v error code, got %v", e, a) 1130 } 1131 1132 } 1133 1134 func TestRequestEndpointWithDefaultPort(t *testing.T) { 1135 s := awstesting.NewClient(&aws.Config{ 1136 Endpoint: aws.String("https://example.test:443"), 1137 }) 1138 r := s.NewRequest(&request.Operation{ 1139 Name: "FooBar", 1140 HTTPMethod: "GET", 1141 HTTPPath: "/", 1142 }, nil, nil) 1143 r.Handlers.Validate.Clear() 1144 r.Handlers.ValidateResponse.Clear() 1145 r.Handlers.Send.Clear() 1146 r.Handlers.Send.PushFront(func(r *request.Request) { 1147 req := r.HTTPRequest 1148 1149 if e, a := "example.test", req.Host; e != a { 1150 t.Errorf("expected %v, got %v", e, a) 1151 } 1152 1153 if e, a := "https://example.test:443/", req.URL.String(); e != a { 1154 t.Errorf("expected %v, got %v", e, a) 1155 } 1156 }) 1157 err := r.Send() 1158 if err != nil { 1159 t.Fatalf("expected no error, got %v", err) 1160 } 1161 } 1162 1163 func TestRequestEndpointWithNonDefaultPort(t *testing.T) { 1164 s := awstesting.NewClient(&aws.Config{ 1165 Endpoint: aws.String("https://example.test:8443"), 1166 }) 1167 r := s.NewRequest(&request.Operation{ 1168 Name: "FooBar", 1169 HTTPMethod: "GET", 1170 HTTPPath: "/", 1171 }, nil, nil) 1172 r.Handlers.Validate.Clear() 1173 r.Handlers.ValidateResponse.Clear() 1174 r.Handlers.Send.Clear() 1175 r.Handlers.Send.PushFront(func(r *request.Request) { 1176 req := r.HTTPRequest 1177 1178 // http.Request.Host should not be set for non-default ports 1179 if e, a := "", req.Host; e != a { 1180 t.Errorf("expected %v, got %v", e, a) 1181 } 1182 1183 if e, a := "https://example.test:8443/", req.URL.String(); e != a { 1184 t.Errorf("expected %v, got %v", e, a) 1185 } 1186 }) 1187 err := r.Send() 1188 if err != nil { 1189 t.Fatalf("expected no error, got %v", err) 1190 } 1191 } 1192 1193 func TestRequestMarshaledEndpointWithDefaultPort(t *testing.T) { 1194 s := awstesting.NewClient(&aws.Config{ 1195 Endpoint: aws.String("https://example.test:443"), 1196 }) 1197 r := s.NewRequest(&request.Operation{ 1198 Name: "FooBar", 1199 HTTPMethod: "GET", 1200 HTTPPath: "/", 1201 }, nil, nil) 1202 r.Handlers.Validate.Clear() 1203 r.Handlers.ValidateResponse.Clear() 1204 r.Handlers.Build.PushBack(func(r *request.Request) { 1205 req := r.HTTPRequest 1206 req.URL.Host = "foo." + req.URL.Host 1207 }) 1208 r.Handlers.Send.Clear() 1209 r.Handlers.Send.PushFront(func(r *request.Request) { 1210 req := r.HTTPRequest 1211 1212 if e, a := "foo.example.test", req.Host; e != a { 1213 t.Errorf("expected %v, got %v", e, a) 1214 } 1215 1216 if e, a := "https://foo.example.test:443/", req.URL.String(); e != a { 1217 t.Errorf("expected %v, got %v", e, a) 1218 } 1219 }) 1220 err := r.Send() 1221 if err != nil { 1222 t.Fatalf("expected no error, got %v", err) 1223 } 1224 } 1225 1226 func TestRequestMarshaledEndpointWithNonDefaultPort(t *testing.T) { 1227 s := awstesting.NewClient(&aws.Config{ 1228 Endpoint: aws.String("https://example.test:8443"), 1229 }) 1230 r := s.NewRequest(&request.Operation{ 1231 Name: "FooBar", 1232 HTTPMethod: "GET", 1233 HTTPPath: "/", 1234 }, nil, nil) 1235 r.Handlers.Validate.Clear() 1236 r.Handlers.ValidateResponse.Clear() 1237 r.Handlers.Build.PushBack(func(r *request.Request) { 1238 req := r.HTTPRequest 1239 req.URL.Host = "foo." + req.URL.Host 1240 }) 1241 r.Handlers.Send.Clear() 1242 r.Handlers.Send.PushFront(func(r *request.Request) { 1243 req := r.HTTPRequest 1244 1245 // http.Request.Host should not be set for non-default ports 1246 if e, a := "", req.Host; e != a { 1247 t.Errorf("expected %v, got %v", e, a) 1248 } 1249 1250 if e, a := "https://foo.example.test:8443/", req.URL.String(); e != a { 1251 t.Errorf("expected %v, got %v", e, a) 1252 } 1253 }) 1254 err := r.Send() 1255 if err != nil { 1256 t.Fatalf("expected no error, got %v", err) 1257 } 1258 } 1259 1260 type stubSeekFail struct { 1261 Err error 1262 } 1263 1264 func (f *stubSeekFail) Read(b []byte) (int, error) { 1265 return len(b), nil 1266 } 1267 func (f *stubSeekFail) ReadAt(b []byte, offset int64) (int, error) { 1268 return len(b), nil 1269 } 1270 func (f *stubSeekFail) Seek(offset int64, mode int) (int64, error) { 1271 return 0, f.Err 1272 } 1273 1274 func getFreePort() (int, error) { 1275 l, err := net.Listen("tcp", ":0") 1276 if err != nil { 1277 return 0, err 1278 } 1279 defer l.Close() 1280 1281 strAddr := l.Addr().String() 1282 parts := strings.Split(strAddr, ":") 1283 strPort := parts[len(parts)-1] 1284 port, err := strconv.ParseInt(strPort, 10, 32) 1285 if err != nil { 1286 return 0, err 1287 } 1288 return int(port), nil 1289 }