github.com/trustbloc/kms-go@v1.1.2/kms/localkms/internal/keywrapper/kms_aead_test.go (about) 1 /* 2 Copyright SecureKey Technologies Inc. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package keywrapper 8 9 import ( 10 "encoding/base64" 11 "fmt" 12 "testing" 13 14 "github.com/google/tink/go/tink" 15 "github.com/stretchr/testify/require" 16 17 "github.com/trustbloc/kms-go/mock/secretlock" 18 ) 19 20 func TestLocalKMS_New_AEAD(t *testing.T) { 21 // verify LocalAEAD implements tink.AEAD 22 require.Implements(t, (*tink.AEAD)(nil), (*LocalAEAD)(nil)) 23 24 mockSecLck := &secretlock.MockSecretLock{ 25 ValEncrypt: "successEncryption", 26 ValDecrypt: "successDecryption", 27 } 28 29 aeadKW, err := New(mockSecLck, "") 30 require.Error(t, err) 31 require.Empty(t, aeadKW) 32 33 aeadKW, err = New(mockSecLck, LocalKeyURIPrefix) 34 require.Error(t, err) 35 require.Empty(t, aeadKW) 36 37 validURIs := []string{ 38 LocalKeyURIPrefix + "master/key", 39 "aws-kms://arn:aws:kms:ca-central-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f", 40 "gcp-kms://projects/aries-test-infrastructure/aead-key", 41 } 42 invalidURIs := []string{ 43 "://master/key", 44 "master/key", 45 LocalKeyURIPrefix, 46 "aws-kms://", 47 "", 48 } 49 50 for _, invalidURI := range invalidURIs { 51 aeadKW, err = New(mockSecLck, invalidURI) 52 require.Error(t, err) 53 require.Empty(t, aeadKW) 54 } 55 56 for _, validURI := range validURIs { 57 aeadKW, err = New(mockSecLck, validURI) 58 require.NoError(t, err) 59 require.NotEmpty(t, aeadKW) 60 } 61 } 62 63 func TestLocalKMS_EncryptDecrypt(t *testing.T) { 64 flagTests := []struct { 65 tcName string 66 encVal []byte 67 errEncVal error 68 decVal []byte 69 errDecVal error 70 }{ 71 { 72 tcName: "success - valid aead, Encrypt and Decrypt", 73 encVal: []byte("loremIpsumCiphertext"), 74 decVal: []byte("loremIpsumPlainext"), 75 }, 76 { 77 tcName: "error - fail Encrypt/Decrypt", 78 errEncVal: fmt.Errorf("encryption failure"), 79 errDecVal: fmt.Errorf("decryption failure"), 80 }, 81 { 82 tcName: "error - Encrypt fail base64URL.Decode ciphertext", 83 encVal: []byte("{}ciphertext"), // {} are illegal base64URL characters 84 decVal: []byte("loremIpsumPlaintext"), 85 }, 86 { 87 tcName: "error - Decrypt fail base64URL.Decode plaintext", 88 encVal: []byte("loremIpsumCiphertext"), 89 decVal: []byte("{}plaintext"), // {} are illegal base64URL characters 90 }, 91 } 92 93 validURI := LocalKeyURIPrefix + "master/key" 94 95 for _, tt := range flagTests { 96 t.Run(tt.tcName, func(t *testing.T) { 97 mockSecLck := &secretlock.MockSecretLock{ 98 ErrEncrypt: tt.errEncVal, 99 ErrDecrypt: tt.errDecVal, 100 } 101 102 if tt.encVal != nil { 103 if tt.tcName != "error - Encrypt fail base64URL.Decode ciphertext" { 104 mockSecLck.ValEncrypt = base64.URLEncoding.EncodeToString(tt.encVal) 105 } else { 106 mockSecLck.ValEncrypt = string(tt.encVal) 107 } 108 } 109 110 if tt.decVal != nil { 111 if tt.tcName != "error - Decrypt fail base64URL.Decode plaintext" { 112 mockSecLck.ValDecrypt = base64.URLEncoding.EncodeToString(tt.decVal) 113 } else { 114 mockSecLck.ValDecrypt = string(tt.decVal) 115 } 116 } 117 118 aeadKW, err := New(mockSecLck, validURI) 119 require.NoError(t, err) 120 require.NotEmpty(t, aeadKW) 121 122 // Encrypt() calls secretLock.Encrypt() which is mocked. 123 // Only validate if aeadKW returns the mocked value 124 ct, err := aeadKW.Encrypt([]byte(""), []byte("")) 125 if tt.tcName == "error - Encrypt fail base64URL.Decode ciphertext" { 126 require.Nil(t, ct) 127 require.EqualError(t, err, 128 base64.CorruptInputError(0).Error()) // 0 for index of '{' in "{}ciphertext" test case above 129 } else { 130 require.EqualValues(t, err, tt.errEncVal) 131 require.Equal(t, tt.encVal, ct) 132 } 133 134 // same as Decrypt above 135 dec, err := aeadKW.Decrypt([]byte(""), []byte("")) 136 if tt.tcName == "error - Decrypt fail base64URL.Decode plaintext" { 137 require.Nil(t, dec) 138 require.EqualError(t, err, 139 base64.CorruptInputError(0).Error()) // 0 for index of '{' in "{}plaintext" test case above 140 } else { 141 require.Equal(t, tt.decVal, dec) 142 require.EqualValues(t, err, tt.errDecVal) 143 } 144 }) 145 } 146 }