github.com/trustbloc/kms-go@v1.1.2/kms/localkms/localkms_writer.go (about)

     1  /*
     2   Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4   SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package localkms
     8  
     9  import (
    10  	"encoding/base64"
    11  	"errors"
    12  	"fmt"
    13  
    14  	"github.com/google/tink/go/subtle/random"
    15  
    16  	kmsapi "github.com/trustbloc/kms-go/spi/kms"
    17  
    18  	"github.com/trustbloc/kms-go/kms"
    19  )
    20  
    21  const maxKeyIDLen = 50
    22  
    23  // newWriter creates a new instance of local storage key storeWriter in the given store and for primaryKeyURI.
    24  func newWriter(kmsStore kmsapi.Store, opts ...kmsapi.PrivateKeyOpts) *storeWriter {
    25  	pOpts := kmsapi.NewOpt()
    26  
    27  	for _, opt := range opts {
    28  		opt(pOpts)
    29  	}
    30  
    31  	return &storeWriter{
    32  		storage:           kmsStore,
    33  		requestedKeysetID: pOpts.KsID(),
    34  	}
    35  }
    36  
// storeWriter writes a serialized keyset into a kmsapi.Store under a keyset ID.
// Its Write method has the io.Writer signature; the ID the keyset ended up
// stored under is exposed through KeysetID after a successful Write.
type storeWriter struct {
	// storage is the store the keyset bytes are written to (via Put) and
	// checked against (via Get) for ID availability.
	storage kmsapi.Store
	// requestedKeysetID is the caller-requested keyset ID taken from the
	// kmsapi options (KsID). Empty means Write generates a random ID instead.
	requestedKeysetID string
	// KeysetID is the ID under which Write stored the keyset; it is assigned
	// only after the store's Put succeeds.
	KeysetID string
}
    45  
    46  // Write a marshaled keyset p in localstore with primaryKeyURI prefix + randomly generated KeysetID.
    47  func (l *storeWriter) Write(p []byte) (int, error) {
    48  	var err error
    49  
    50  	var ksID string
    51  
    52  	if l.requestedKeysetID != "" {
    53  		ksID, err = l.verifyRequestedID()
    54  		if err != nil {
    55  			return 0, err
    56  		}
    57  	} else {
    58  		ksID, err = l.newKeysetID()
    59  		if err != nil {
    60  			return 0, err
    61  		}
    62  	}
    63  
    64  	err = l.storage.Put(ksID, p)
    65  	if err != nil {
    66  		return 0, err
    67  	}
    68  
    69  	l.KeysetID = ksID
    70  
    71  	return len(p), nil
    72  }
    73  
    74  func (l *storeWriter) verifyRequestedID() (string, error) {
    75  	_, err := l.storage.Get(l.requestedKeysetID)
    76  	if errors.Is(err, kms.ErrKeyNotFound) {
    77  		return l.requestedKeysetID, nil
    78  	}
    79  
    80  	if err != nil {
    81  		return "", fmt.Errorf("got error while verifying requested ID: %w", err)
    82  	}
    83  
    84  	return "", fmt.Errorf("requested ID '%s' already exists, cannot write keyset", l.requestedKeysetID)
    85  }
    86  
    87  func (l *storeWriter) newKeysetID() (string, error) {
    88  	keySetIDLength := base64.RawURLEncoding.DecodedLen(maxKeyIDLen)
    89  
    90  	var ksID string
    91  
    92  	for {
    93  		// generate random ID
    94  		ksID = base64.RawURLEncoding.EncodeToString(random.GetRandomBytes(uint32(keySetIDLength)))
    95  
    96  		// ensure ksID is not already used
    97  		_, err := l.storage.Get(ksID)
    98  		if err != nil {
    99  			if errors.Is(err, kms.ErrKeyNotFound) {
   100  				break
   101  			}
   102  
   103  			return "", err
   104  		}
   105  	}
   106  
   107  	return ksID, nil
   108  }