github.com/trustbloc/kms-go@v1.1.2/kms/localkms/internal/keywrapper/kms_aead_test.go (about)

     1  /*
     2   Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4   SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package keywrapper
     8  
     9  import (
    10  	"encoding/base64"
    11  	"fmt"
    12  	"testing"
    13  
    14  	"github.com/google/tink/go/tink"
    15  	"github.com/stretchr/testify/require"
    16  
    17  	"github.com/trustbloc/kms-go/mock/secretlock"
    18  )
    19  
    20  func TestLocalKMS_New_AEAD(t *testing.T) {
    21  	// verify LocalAEAD implements tink.AEAD
    22  	require.Implements(t, (*tink.AEAD)(nil), (*LocalAEAD)(nil))
    23  
    24  	mockSecLck := &secretlock.MockSecretLock{
    25  		ValEncrypt: "successEncryption",
    26  		ValDecrypt: "successDecryption",
    27  	}
    28  
    29  	aeadKW, err := New(mockSecLck, "")
    30  	require.Error(t, err)
    31  	require.Empty(t, aeadKW)
    32  
    33  	aeadKW, err = New(mockSecLck, LocalKeyURIPrefix)
    34  	require.Error(t, err)
    35  	require.Empty(t, aeadKW)
    36  
    37  	validURIs := []string{
    38  		LocalKeyURIPrefix + "master/key",
    39  		"aws-kms://arn:aws:kms:ca-central-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f",
    40  		"gcp-kms://projects/aries-test-infrastructure/aead-key",
    41  	}
    42  	invalidURIs := []string{
    43  		"://master/key",
    44  		"master/key",
    45  		LocalKeyURIPrefix,
    46  		"aws-kms://",
    47  		"",
    48  	}
    49  
    50  	for _, invalidURI := range invalidURIs {
    51  		aeadKW, err = New(mockSecLck, invalidURI)
    52  		require.Error(t, err)
    53  		require.Empty(t, aeadKW)
    54  	}
    55  
    56  	for _, validURI := range validURIs {
    57  		aeadKW, err = New(mockSecLck, validURI)
    58  		require.NoError(t, err)
    59  		require.NotEmpty(t, aeadKW)
    60  	}
    61  }
    62  
    63  func TestLocalKMS_EncryptDecrypt(t *testing.T) {
    64  	flagTests := []struct {
    65  		tcName    string
    66  		encVal    []byte
    67  		errEncVal error
    68  		decVal    []byte
    69  		errDecVal error
    70  	}{
    71  		{
    72  			tcName: "success - valid aead, Encrypt and Decrypt",
    73  			encVal: []byte("loremIpsumCiphertext"),
    74  			decVal: []byte("loremIpsumPlainext"),
    75  		},
    76  		{
    77  			tcName:    "error - fail Encrypt/Decrypt",
    78  			errEncVal: fmt.Errorf("encryption failure"),
    79  			errDecVal: fmt.Errorf("decryption failure"),
    80  		},
    81  		{
    82  			tcName: "error - Encrypt fail base64URL.Decode ciphertext",
    83  			encVal: []byte("{}ciphertext"), // {} are illegal base64URL characters
    84  			decVal: []byte("loremIpsumPlaintext"),
    85  		},
    86  		{
    87  			tcName: "error - Decrypt fail base64URL.Decode plaintext",
    88  			encVal: []byte("loremIpsumCiphertext"),
    89  			decVal: []byte("{}plaintext"), // {} are illegal base64URL characters
    90  		},
    91  	}
    92  
    93  	validURI := LocalKeyURIPrefix + "master/key"
    94  
    95  	for _, tt := range flagTests {
    96  		t.Run(tt.tcName, func(t *testing.T) {
    97  			mockSecLck := &secretlock.MockSecretLock{
    98  				ErrEncrypt: tt.errEncVal,
    99  				ErrDecrypt: tt.errDecVal,
   100  			}
   101  
   102  			if tt.encVal != nil {
   103  				if tt.tcName != "error - Encrypt fail base64URL.Decode ciphertext" {
   104  					mockSecLck.ValEncrypt = base64.URLEncoding.EncodeToString(tt.encVal)
   105  				} else {
   106  					mockSecLck.ValEncrypt = string(tt.encVal)
   107  				}
   108  			}
   109  
   110  			if tt.decVal != nil {
   111  				if tt.tcName != "error - Decrypt fail base64URL.Decode plaintext" {
   112  					mockSecLck.ValDecrypt = base64.URLEncoding.EncodeToString(tt.decVal)
   113  				} else {
   114  					mockSecLck.ValDecrypt = string(tt.decVal)
   115  				}
   116  			}
   117  
   118  			aeadKW, err := New(mockSecLck, validURI)
   119  			require.NoError(t, err)
   120  			require.NotEmpty(t, aeadKW)
   121  
   122  			// Encrypt() calls secretLock.Encrypt() which is mocked.
   123  			// Only validate if aeadKW returns the mocked value
   124  			ct, err := aeadKW.Encrypt([]byte(""), []byte(""))
   125  			if tt.tcName == "error - Encrypt fail base64URL.Decode ciphertext" {
   126  				require.Nil(t, ct)
   127  				require.EqualError(t, err,
   128  					base64.CorruptInputError(0).Error()) // 0 for index of '{' in "{}ciphertext" test case above
   129  			} else {
   130  				require.EqualValues(t, err, tt.errEncVal)
   131  				require.Equal(t, tt.encVal, ct)
   132  			}
   133  
   134  			// same as Decrypt above
   135  			dec, err := aeadKW.Decrypt([]byte(""), []byte(""))
   136  			if tt.tcName == "error - Decrypt fail base64URL.Decode plaintext" {
   137  				require.Nil(t, dec)
   138  				require.EqualError(t, err,
   139  					base64.CorruptInputError(0).Error()) // 0 for index of '{' in "{}plaintext" test case above
   140  			} else {
   141  				require.Equal(t, tt.decVal, dec)
   142  				require.EqualValues(t, err, tt.errDecVal)
   143  			}
   144  		})
   145  	}
   146  }