github.com/aavshr/aws-sdk-go@v1.41.3/aws/signer/v4/stream_test.go (about)

     1  //go:build go1.7
     2  // +build go1.7
     3  
     4  package v4
     5  
     6  import (
     7  	"encoding/hex"
     8  	"fmt"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/aavshr/aws-sdk-go/aws/credentials"
    14  )
    15  
    16  type periodicBadCredentials struct {
    17  	call        int
    18  	credentials *credentials.Credentials
    19  }
    20  
    21  func (p *periodicBadCredentials) Get() (credentials.Value, error) {
    22  	defer func() {
    23  		p.call++
    24  	}()
    25  
    26  	if p.call%2 == 0 {
    27  		return credentials.Value{}, fmt.Errorf("credentials error")
    28  	}
    29  
    30  	return p.credentials.Get()
    31  }
    32  
    33  type chunk struct {
    34  	headers, payload []byte
    35  }
    36  
    37  func mustDecodeHex(b []byte, err error) []byte {
    38  	if err != nil {
    39  		panic(err)
    40  	}
    41  
    42  	return b
    43  }
    44  
    45  func TestStreamingChunkSigner(t *testing.T) {
    46  	const (
    47  		region        = "us-east-1"
    48  		service       = "transcribe"
    49  		seedSignature = "9d9ab996c81f32c9d4e6fc166c92584f3741d1cb5ce325cd11a77d1f962c8de2"
    50  	)
    51  
    52  	staticCredentials := credentials.NewStaticCredentials("AKIDEXAMPLE", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "")
    53  	currentTime := time.Date(2019, 1, 27, 22, 37, 54, 0, time.UTC)
    54  
    55  	cases := map[string]struct {
    56  		credentials        credentialValueProvider
    57  		chunks             []chunk
    58  		expectedSignatures map[int]string
    59  		expectedErrors     map[int]string
    60  	}{
    61  		"signature calculation": {
    62  			credentials: staticCredentials,
    63  			chunks: []chunk{
    64  				{headers: []byte("headers"), payload: []byte("payload")},
    65  				{headers: []byte("more headers"), payload: []byte("more payload")},
    66  			},
    67  			expectedSignatures: map[int]string{
    68  				0: "681a7eaa82891536f24af7ec7e9219ee251ccd9bac2f1b981eab7c5ec8579115",
    69  				1: "07633d9d4ab4d81634a2164934d1f648c7cbc6839a8cf0773d818127a267e4d6",
    70  			},
    71  		},
    72  		"signature calculation errors": {
    73  			credentials: &periodicBadCredentials{credentials: staticCredentials},
    74  			chunks: []chunk{
    75  				{headers: []byte("headers"), payload: []byte("payload")},
    76  				{headers: []byte("headers"), payload: []byte("payload")},
    77  				{headers: []byte("more headers"), payload: []byte("more payload")},
    78  				{headers: []byte("more headers"), payload: []byte("more payload")},
    79  			},
    80  			expectedSignatures: map[int]string{
    81  				1: "681a7eaa82891536f24af7ec7e9219ee251ccd9bac2f1b981eab7c5ec8579115",
    82  				3: "07633d9d4ab4d81634a2164934d1f648c7cbc6839a8cf0773d818127a267e4d6",
    83  			},
    84  			expectedErrors: map[int]string{
    85  				0: "credentials error",
    86  				2: "credentials error",
    87  			},
    88  		},
    89  	}
    90  
    91  	for name, tt := range cases {
    92  		t.Run(name, func(t *testing.T) {
    93  			chunkSigner := &StreamSigner{
    94  				region:      region,
    95  				service:     service,
    96  				credentials: tt.credentials,
    97  				prevSig:     mustDecodeHex(hex.DecodeString(seedSignature)),
    98  			}
    99  
   100  			for i, chunk := range tt.chunks {
   101  				var expectedError string
   102  				if len(tt.expectedErrors) != 0 {
   103  					_, ok := tt.expectedErrors[i]
   104  					if ok {
   105  						expectedError = tt.expectedErrors[i]
   106  					}
   107  				}
   108  
   109  				signature, err := chunkSigner.GetSignature(chunk.headers, chunk.payload, currentTime)
   110  				if err == nil && len(expectedError) > 0 {
   111  					t.Errorf("expected error, but got nil")
   112  					continue
   113  				} else if err != nil && len(expectedError) == 0 {
   114  					t.Errorf("expected no error, but got %v", err)
   115  					continue
   116  				} else if err != nil && len(expectedError) > 0 && !strings.Contains(err.Error(), expectedError) {
   117  					t.Errorf("expected %v, but got %v", expectedError, err)
   118  					continue
   119  				} else if len(expectedError) > 0 {
   120  					continue
   121  				}
   122  
   123  				expectedSignature, ok := tt.expectedSignatures[i]
   124  				if !ok {
   125  					t.Fatalf("expected signature not provided for test case")
   126  				}
   127  
   128  				if e, a := expectedSignature, hex.EncodeToString(signature); e != a {
   129  					t.Errorf("expected %v, got %v", e, a)
   130  				}
   131  			}
   132  		})
   133  	}
   134  }