github.com/aavshr/aws-sdk-go@v1.41.3/aws/corehandlers/handlers_test.go (about)

     1  package corehandlers_test
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"net/url"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/aavshr/aws-sdk-go/aws"
    15  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    16  	"github.com/aavshr/aws-sdk-go/aws/client"
    17  	"github.com/aavshr/aws-sdk-go/aws/client/metadata"
    18  	"github.com/aavshr/aws-sdk-go/aws/corehandlers"
    19  	"github.com/aavshr/aws-sdk-go/aws/credentials"
    20  	"github.com/aavshr/aws-sdk-go/aws/request"
    21  	"github.com/aavshr/aws-sdk-go/awstesting"
    22  	"github.com/aavshr/aws-sdk-go/awstesting/unit"
    23  	"github.com/aavshr/aws-sdk-go/internal/sdktesting"
    24  	"github.com/aavshr/aws-sdk-go/service/s3"
    25  )
    26  
    27  func TestValidateEndpointHandler(t *testing.T) {
    28  	restoreEnvFn := sdktesting.StashEnv()
    29  	defer restoreEnvFn()
    30  	svc := awstesting.NewClient(aws.NewConfig().WithRegion("us-west-2"))
    31  	svc.Handlers.Clear()
    32  	svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
    33  
    34  	req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
    35  	err := req.Build()
    36  
    37  	if err != nil {
    38  		t.Errorf("expect no error, got %v", err)
    39  	}
    40  }
    41  
    42  func TestValidateEndpointHandlerErrorRegion(t *testing.T) {
    43  	restoreEnvFn := sdktesting.StashEnv()
    44  	defer restoreEnvFn()
    45  	svc := awstesting.NewClient()
    46  	svc.Handlers.Clear()
    47  	svc.Handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
    48  
    49  	req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
    50  	err := req.Build()
    51  
    52  	if err == nil {
    53  		t.Errorf("expect error, got none")
    54  	}
    55  	if e, a := aws.ErrMissingRegion, err; e != a {
    56  		t.Errorf("expect %v to be %v", e, a)
    57  	}
    58  }
    59  
    60  type mockCredsProvider struct {
    61  	expired        bool
    62  	retrieveCalled bool
    63  }
    64  
    65  func (m *mockCredsProvider) Retrieve() (credentials.Value, error) {
    66  	m.retrieveCalled = true
    67  	return credentials.Value{
    68  		AccessKeyID:     "AKID",
    69  		SecretAccessKey: "SECRET",
    70  		ProviderName:    "mockCredsProvider",
    71  	}, nil
    72  }
    73  
    74  func (m *mockCredsProvider) IsExpired() bool {
    75  	return m.expired
    76  }
    77  
    78  func TestAfterRetryRefreshCreds(t *testing.T) {
    79  	restoreEnvFn := sdktesting.StashEnv()
    80  	defer restoreEnvFn()
    81  
    82  	credProvider := &mockCredsProvider{}
    83  
    84  	sess := unit.Session.Copy(&aws.Config{
    85  		Credentials: credentials.NewCredentials(credProvider),
    86  		MaxRetries:  aws.Int(2),
    87  	})
    88  	clientInfo := metadata.ClientInfo{
    89  		Endpoint:    "http://endpoint",
    90  		SigningName: "",
    91  	}
    92  	svc := client.New(*sess.Config, clientInfo, sess.Handlers)
    93  
    94  	svc.Handlers.Sign.PushBack(func(r *request.Request) {
    95  		if !svc.Config.Credentials.IsExpired() {
    96  			t.Errorf("expect credentials of of been expired before request attempt")
    97  		}
    98  		_, err := svc.Config.Credentials.Get()
    99  		r.Error = err
   100  	})
   101  
   102  	var respID int
   103  	resps := []struct {
   104  		Resp *http.Response
   105  		Err  error
   106  	}{
   107  		{
   108  			Resp: &http.Response{
   109  				StatusCode: 403,
   110  				Header:     http.Header{},
   111  				Body:       ioutil.NopCloser(bytes.NewBuffer([]byte{})),
   112  			},
   113  			Err: awserr.New("ExpiredToken", "", nil),
   114  		},
   115  		{
   116  			Resp: &http.Response{
   117  				StatusCode: 403,
   118  				Header:     http.Header{},
   119  				Body:       ioutil.NopCloser(bytes.NewBuffer([]byte{})),
   120  			},
   121  			Err: awserr.New("ExpiredToken", "", nil),
   122  		},
   123  		{
   124  			Resp: &http.Response{
   125  				StatusCode: 200,
   126  				Header:     http.Header{},
   127  				Body:       ioutil.NopCloser(bytes.NewBuffer([]byte{})),
   128  			},
   129  		},
   130  	}
   131  	svc.Handlers.Send.Clear()
   132  	svc.Handlers.Send.PushBack(func(r *request.Request) {
   133  		r.HTTPResponse = resps[respID].Resp
   134  	})
   135  	svc.Handlers.UnmarshalError.PushBack(func(r *request.Request) {
   136  		r.Error = resps[respID].Err
   137  	})
   138  	svc.Handlers.CompleteAttempt.PushBack(func(r *request.Request) {
   139  		respID++
   140  	})
   141  
   142  	if !svc.Config.Credentials.IsExpired() {
   143  		t.Fatalf("expect to start out expired")
   144  	}
   145  	if credProvider.retrieveCalled {
   146  		t.Fatalf("expect retrieve not yet called")
   147  	}
   148  
   149  	req := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
   150  	if err := req.Send(); err != nil {
   151  		t.Fatalf("expect no error, got %v", err)
   152  	}
   153  	if e, a := len(resps)-1, req.RetryCount; e != a {
   154  		t.Errorf("expect %v retries, got %v", e, a)
   155  	}
   156  	if svc.Config.Credentials.IsExpired() {
   157  		t.Errorf("expect credentials not to be expired")
   158  	}
   159  	if !credProvider.retrieveCalled {
   160  		t.Errorf("expect retrieve to be called")
   161  	}
   162  }
   163  
   164  func TestAfterRetryWithContextCanceled(t *testing.T) {
   165  	c := awstesting.NewClient()
   166  
   167  	req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
   168  
   169  	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
   170  	req.SetContext(ctx)
   171  
   172  	req.Error = fmt.Errorf("some error")
   173  	req.Retryable = aws.Bool(true)
   174  	req.HTTPResponse = &http.Response{
   175  		StatusCode: 500,
   176  	}
   177  
   178  	close(ctx.DoneCh)
   179  	ctx.Error = fmt.Errorf("context canceled")
   180  
   181  	corehandlers.AfterRetryHandler.Fn(req)
   182  
   183  	if req.Error == nil {
   184  		t.Fatalf("expect error but didn't receive one")
   185  	}
   186  
   187  	aerr := req.Error.(awserr.Error)
   188  
   189  	if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
   190  		t.Errorf("expect %q, error code got %q", e, a)
   191  	}
   192  }
   193  
   194  func TestAfterRetryWithContext(t *testing.T) {
   195  	c := awstesting.NewClient()
   196  
   197  	req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
   198  
   199  	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
   200  	req.SetContext(ctx)
   201  
   202  	req.Error = fmt.Errorf("some error")
   203  	req.Retryable = aws.Bool(true)
   204  	req.HTTPResponse = &http.Response{
   205  		StatusCode: 500,
   206  	}
   207  
   208  	corehandlers.AfterRetryHandler.Fn(req)
   209  
   210  	if req.Error != nil {
   211  		t.Fatalf("expect no error, got %v", req.Error)
   212  	}
   213  	if e, a := 1, req.RetryCount; e != a {
   214  		t.Errorf("expect retry count to be %d, got %d", e, a)
   215  	}
   216  }
   217  
   218  func TestSendWithContextCanceled(t *testing.T) {
   219  	c := awstesting.NewClient(&aws.Config{
   220  		SleepDelay: func(dur time.Duration) {
   221  			t.Errorf("SleepDelay should not be called")
   222  		},
   223  	})
   224  
   225  	req := c.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
   226  
   227  	ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
   228  	req.SetContext(ctx)
   229  
   230  	req.Error = fmt.Errorf("some error")
   231  	req.Retryable = aws.Bool(true)
   232  	req.HTTPResponse = &http.Response{
   233  		StatusCode: 500,
   234  	}
   235  
   236  	close(ctx.DoneCh)
   237  	ctx.Error = fmt.Errorf("context canceled")
   238  
   239  	corehandlers.SendHandler.Fn(req)
   240  
   241  	if req.Error == nil {
   242  		t.Fatalf("expect error but didn't receive one")
   243  	}
   244  
   245  	aerr := req.Error.(awserr.Error)
   246  
   247  	if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
   248  		t.Errorf("expect %q, error code got %q", e, a)
   249  	}
   250  }
   251  
   252  type testSendHandlerTransport struct{}
   253  
   254  func (t *testSendHandlerTransport) RoundTrip(r *http.Request) (*http.Response, error) {
   255  	return nil, fmt.Errorf("mock error")
   256  }
   257  
   258  func TestSendHandlerError(t *testing.T) {
   259  	svc := awstesting.NewClient(&aws.Config{
   260  		HTTPClient: &http.Client{
   261  			Transport: &testSendHandlerTransport{},
   262  		},
   263  	})
   264  	svc.Handlers.Clear()
   265  	svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
   266  	r := svc.NewRequest(&request.Operation{Name: "Operation"}, nil, nil)
   267  
   268  	r.Send()
   269  
   270  	if r.Error == nil {
   271  		t.Errorf("expect error, got none")
   272  	}
   273  	if r.HTTPResponse == nil {
   274  		t.Errorf("expect response, got none")
   275  	}
   276  }
   277  
   278  func TestSendWithoutFollowRedirects(t *testing.T) {
   279  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   280  		switch r.URL.Path {
   281  		case "/original":
   282  			w.Header().Set("Location", "/redirected")
   283  			w.WriteHeader(301)
   284  		case "/redirected":
   285  			t.Fatalf("expect not to redirect, but was")
   286  		}
   287  	}))
   288  	defer server.Close()
   289  
   290  	svc := awstesting.NewClient(&aws.Config{
   291  		DisableSSL: aws.Bool(true),
   292  		Endpoint:   aws.String(server.URL),
   293  	})
   294  	svc.Handlers.Clear()
   295  	svc.Handlers.Send.PushBackNamed(corehandlers.SendHandler)
   296  
   297  	r := svc.NewRequest(&request.Operation{
   298  		Name:     "Operation",
   299  		HTTPPath: "/original",
   300  	}, nil, nil)
   301  	r.DisableFollowRedirects = true
   302  
   303  	err := r.Send()
   304  	if err != nil {
   305  		t.Errorf("expect no error, got %v", err)
   306  	}
   307  	if e, a := 301, r.HTTPResponse.StatusCode; e != a {
   308  		t.Errorf("expect %d status code, got %d", e, a)
   309  	}
   310  }
   311  
   312  func TestValidateReqSigHandler(t *testing.T) {
   313  	cases := []struct {
   314  		Req    *request.Request
   315  		Resign bool
   316  	}{
   317  		{
   318  			Req: &request.Request{
   319  				Config: aws.Config{Credentials: credentials.AnonymousCredentials},
   320  				Time:   time.Now().Add(-15 * time.Minute),
   321  			},
   322  			Resign: false,
   323  		},
   324  		{
   325  			Req: &request.Request{
   326  				Time: time.Now().Add(-15 * time.Minute),
   327  			},
   328  			Resign: true,
   329  		},
   330  		{
   331  			Req: &request.Request{
   332  				Time: time.Now().Add(-1 * time.Minute),
   333  			},
   334  			Resign: false,
   335  		},
   336  	}
   337  
   338  	for i, c := range cases {
   339  		c.Req.HTTPRequest = &http.Request{URL: &url.URL{}}
   340  
   341  		resigned := false
   342  		c.Req.Handlers.Sign.PushBack(func(r *request.Request) {
   343  			resigned = true
   344  		})
   345  
   346  		corehandlers.ValidateReqSigHandler.Fn(c.Req)
   347  
   348  		if c.Req.Error != nil {
   349  			t.Errorf("expect no error, got %v", c.Req.Error)
   350  		}
   351  		if e, a := c.Resign, resigned; e != a {
   352  			t.Errorf("%d, expect %v to be %v", i, e, a)
   353  		}
   354  	}
   355  }
   356  
   357  func setupContentLengthTestServer(t *testing.T, hasContentLength bool, contentLength int64) *httptest.Server {
   358  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   359  		_, ok := r.Header["Content-Length"]
   360  		if e, a := hasContentLength, ok; e != a {
   361  			t.Errorf("expect %v to be %v", e, a)
   362  		}
   363  		if hasContentLength {
   364  			if e, a := contentLength, r.ContentLength; e != a {
   365  				t.Errorf("expect %v to be %v", e, a)
   366  			}
   367  		}
   368  
   369  		b, err := ioutil.ReadAll(r.Body)
   370  		if err != nil {
   371  			t.Errorf("expect no error, got %v", err)
   372  		}
   373  		r.Body.Close()
   374  
   375  		authHeader := r.Header.Get("Authorization")
   376  		if hasContentLength {
   377  			if e, a := "content-length", authHeader; !strings.Contains(a, e) {
   378  				t.Errorf("expect %v to be in %v", e, a)
   379  			}
   380  		} else {
   381  			if e, a := "content-length", authHeader; strings.Contains(a, e) {
   382  				t.Errorf("expect %v to not be in %v", e, a)
   383  			}
   384  		}
   385  
   386  		if e, a := contentLength, int64(len(b)); e != a {
   387  			t.Errorf("expect %v to be %v", e, a)
   388  		}
   389  	}))
   390  
   391  	return server
   392  }
   393  
   394  func TestBuildContentLength_ZeroBody(t *testing.T) {
   395  	server := setupContentLengthTestServer(t, false, 0)
   396  	defer server.Close()
   397  
   398  	svc := s3.New(unit.Session, &aws.Config{
   399  		Endpoint:         aws.String(server.URL),
   400  		S3ForcePathStyle: aws.Bool(true),
   401  		DisableSSL:       aws.Bool(true),
   402  	})
   403  	_, err := svc.GetObject(&s3.GetObjectInput{
   404  		Bucket: aws.String("bucketname"),
   405  		Key:    aws.String("keyname"),
   406  	})
   407  
   408  	if err != nil {
   409  		t.Errorf("expect no error, got %v", err)
   410  	}
   411  }
   412  
   413  func TestBuildContentLength_NegativeBody(t *testing.T) {
   414  	server := setupContentLengthTestServer(t, false, 0)
   415  	defer server.Close()
   416  
   417  	svc := s3.New(unit.Session, &aws.Config{
   418  		Endpoint:         aws.String(server.URL),
   419  		S3ForcePathStyle: aws.Bool(true),
   420  		DisableSSL:       aws.Bool(true),
   421  	})
   422  	req, _ := svc.GetObjectRequest(&s3.GetObjectInput{
   423  		Bucket: aws.String("bucketname"),
   424  		Key:    aws.String("keyname"),
   425  	})
   426  
   427  	req.HTTPRequest.Header.Set("Content-Length", "-1")
   428  
   429  	if req.Error != nil {
   430  		t.Errorf("expect no error, got %v", req.Error)
   431  	}
   432  }
   433  
   434  func TestBuildContentLength_WithBody(t *testing.T) {
   435  	server := setupContentLengthTestServer(t, true, 1024)
   436  	defer server.Close()
   437  
   438  	svc := s3.New(unit.Session, &aws.Config{
   439  		Endpoint:         aws.String(server.URL),
   440  		S3ForcePathStyle: aws.Bool(true),
   441  		DisableSSL:       aws.Bool(true),
   442  	})
   443  	_, err := svc.PutObject(&s3.PutObjectInput{
   444  		Bucket: aws.String("bucketname"),
   445  		Key:    aws.String("keyname"),
   446  		Body:   bytes.NewReader(make([]byte, 1024)),
   447  	})
   448  
   449  	if err != nil {
   450  		t.Errorf("expect no error, got %v", err)
   451  	}
   452  }