github.com/aavshr/aws-sdk-go@v1.41.3/service/sts/customizations_test.go (about)

     1  package sts_test
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"testing"
     9  
    10  	"github.com/aavshr/aws-sdk-go/aws"
    11  	"github.com/aavshr/aws-sdk-go/aws/corehandlers"
    12  	"github.com/aavshr/aws-sdk-go/aws/request"
    13  	"github.com/aavshr/aws-sdk-go/awstesting/unit"
    14  	"github.com/aavshr/aws-sdk-go/service/sts"
    15  )
    16  
    17  var svc = sts.New(unit.Session, &aws.Config{
    18  	Region: aws.String("mock-region"),
    19  })
    20  
    21  func TestUnsignedRequest_AssumeRoleWithSAML(t *testing.T) {
    22  	req, _ := svc.AssumeRoleWithSAMLRequest(&sts.AssumeRoleWithSAMLInput{
    23  		PrincipalArn:  aws.String("ARN01234567890123456789"),
    24  		RoleArn:       aws.String("ARN01234567890123456789"),
    25  		SAMLAssertion: aws.String("ASSERT"),
    26  	})
    27  
    28  	err := req.Sign()
    29  	if err != nil {
    30  		t.Errorf("expect no error, got %v", err)
    31  	}
    32  	if e, a := "", req.HTTPRequest.Header.Get("Authorization"); e != a {
    33  		t.Errorf("expect %v, got %v", e, a)
    34  	}
    35  }
    36  
    37  func TestUnsignedRequest_AssumeRoleWithWebIdentity(t *testing.T) {
    38  	req, _ := svc.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{
    39  		RoleArn:          aws.String("ARN01234567890123456789"),
    40  		RoleSessionName:  aws.String("SESSION"),
    41  		WebIdentityToken: aws.String("TOKEN"),
    42  	})
    43  
    44  	err := req.Sign()
    45  	if err != nil {
    46  		t.Errorf("expect no error, got %v", err)
    47  	}
    48  	if e, a := "", req.HTTPRequest.Header.Get("Authorization"); e != a {
    49  		t.Errorf("expect %v, got %v", e, a)
    50  	}
    51  }
    52  
    53  func TestSTSCustomRetryErrorCodes(t *testing.T) {
    54  	svc := sts.New(unit.Session, &aws.Config{
    55  		MaxRetries: aws.Int(1),
    56  	})
    57  	svc.Handlers.Validate.Clear()
    58  
    59  	const xmlErr = `<ErrorResponse><Error><Code>%s</Code><Message>some error message</Message></Error></ErrorResponse>`
    60  	var reqCount int
    61  	resps := []*http.Response{
    62  		{
    63  			StatusCode: 400,
    64  			Header:     http.Header{},
    65  			Body: ioutil.NopCloser(bytes.NewReader(
    66  				[]byte(fmt.Sprintf(xmlErr, sts.ErrCodeIDPCommunicationErrorException)),
    67  			)),
    68  		},
    69  		{
    70  			StatusCode: 200,
    71  			Header:     http.Header{},
    72  			Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
    73  		},
    74  	}
    75  
    76  	req, _ := svc.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{})
    77  	req.Handlers.Send.Swap(corehandlers.SendHandler.Name, request.NamedHandler{
    78  		Name: "custom send handler",
    79  		Fn: func(r *request.Request) {
    80  			r.HTTPResponse = resps[reqCount]
    81  			reqCount++
    82  		},
    83  	})
    84  
    85  	if err := req.Send(); err != nil {
    86  		t.Fatalf("expect no error, got %v", err)
    87  	}
    88  
    89  	if e, a := 2, reqCount; e != a {
    90  		t.Errorf("expect %v requests, got %v", e, a)
    91  	}
    92  }