github.com/trustbloc/kms-go@v1.1.2/kms/localkms/localkms_writer.go (about) 1 /* 2 Copyright SecureKey Technologies Inc. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package localkms 8 9 import ( 10 "encoding/base64" 11 "errors" 12 "fmt" 13 14 "github.com/google/tink/go/subtle/random" 15 16 kmsapi "github.com/trustbloc/kms-go/spi/kms" 17 18 "github.com/trustbloc/kms-go/kms" 19 ) 20 21 const maxKeyIDLen = 50 22 23 // newWriter creates a new instance of local storage key storeWriter in the given store and for primaryKeyURI. 24 func newWriter(kmsStore kmsapi.Store, opts ...kmsapi.PrivateKeyOpts) *storeWriter { 25 pOpts := kmsapi.NewOpt() 26 27 for _, opt := range opts { 28 opt(pOpts) 29 } 30 31 return &storeWriter{ 32 storage: kmsStore, 33 requestedKeysetID: pOpts.KsID(), 34 } 35 } 36 37 // storeWriter struct to store a keyset in a local store. 38 type storeWriter struct { 39 storage kmsapi.Store 40 // 41 requestedKeysetID string 42 // KeysetID is set when Write() is called 43 KeysetID string 44 } 45 46 // Write a marshaled keyset p in localstore with primaryKeyURI prefix + randomly generated KeysetID. 47 func (l *storeWriter) Write(p []byte) (int, error) { 48 var err error 49 50 var ksID string 51 52 if l.requestedKeysetID != "" { 53 ksID, err = l.verifyRequestedID() 54 if err != nil { 55 return 0, err 56 } 57 } else { 58 ksID, err = l.newKeysetID() 59 if err != nil { 60 return 0, err 61 } 62 } 63 64 err = l.storage.Put(ksID, p) 65 if err != nil { 66 return 0, err 67 } 68 69 l.KeysetID = ksID 70 71 return len(p), nil 72 } 73 74 func (l *storeWriter) verifyRequestedID() (string, error) { 75 _, err := l.storage.Get(l.requestedKeysetID) 76 if errors.Is(err, kms.ErrKeyNotFound) { 77 return l.requestedKeysetID, nil 78 } 79 80 if err != nil { 81 return "", fmt.Errorf("got error while verifying requested ID: %w", err) 82 } 83 84 return "", fmt.Errorf("requested ID '%s' already exists, cannot write keyset", l.requestedKeysetID) 85 } 86 87 func (l *storeWriter) newKeysetID() (string, error) { 88 keySetIDLength := base64.RawURLEncoding.DecodedLen(maxKeyIDLen) 89 90 var ksID string 91 92 for { 93 // generate random ID 94 ksID = base64.RawURLEncoding.EncodeToString(random.GetRandomBytes(uint32(keySetIDLength))) 95 96 // ensure ksID is not already used 97 _, err := l.storage.Get(ksID) 98 if err != nil { 99 if errors.Is(err, kms.ErrKeyNotFound) { 100 break 101 } 102 103 return "", err 104 } 105 } 106 107 return ksID, nil 108 }