github.com/hyperledger/aries-framework-go@v0.3.2/pkg/store/connection/connection_recorder.go (about)

     1  /*
     2   *
     3   * Copyright SecureKey Technologies Inc. All Rights Reserved.
     4   *
     5   * SPDX-License-Identifier: Apache-2.0
     6   * /
     7   *
     8   */
     9  
    10  package connection
    11  
    12  import (
    13  	"crypto"
    14  	"encoding/json"
    15  	"errors"
    16  	"fmt"
    17  
    18  	"github.com/hyperledger/aries-framework-go/spi/storage"
    19  )
    20  
    21  const (
    22  	// StateNameCompleted completed state.
    23  	StateNameCompleted = "completed"
    24  	// MyNSPrefix namespace val my.
    25  	MyNSPrefix = "my"
    26  	// TheirNSPrefix namespace val their
    27  	// TODO: https://github.com/hyperledger/aries-framework-go/issues/556 It will not be constant, this namespace
    28  	//  will need to be figured with verification key
    29  	TheirNSPrefix    = "their"
    30  	errMsgInvalidKey = "invalid key"
    31  )
    32  
    33  // NewRecorder returns new connection recorder.
    34  // Recorder is read-write connection store which provides
    35  // write features on top query features from Lookup.
    36  func NewRecorder(p provider) (*Recorder, error) {
    37  	lookup, err := NewLookup(p)
    38  	if err != nil {
    39  		return nil, fmt.Errorf("failed to create new connection recorder : %w", err)
    40  	}
    41  
    42  	return &Recorder{lookup}, nil
    43  }
    44  
    45  // Recorder is read-write connection store.
    46  type Recorder struct {
    47  	*Lookup
    48  }
    49  
    50  // SaveInvitation saves invitation in permanent store for given key.
    51  // TODO should avoid using target of type `interface{}` [Issue #1030].
    52  func (c *Recorder) SaveInvitation(id string, invitation interface{}) error {
    53  	if id == "" {
    54  		return fmt.Errorf(errMsgInvalidKey)
    55  	}
    56  
    57  	return marshalAndSave(getInvitationKeyPrefix()(id), invitation, c.store)
    58  }
    59  
    60  // SaveOOBv2Invitation saves OOBv2 invitation in permanent store under given ID.
    61  // TODO should avoid using target of type `interface{}` [Issue #1030].
    62  func (c *Recorder) SaveOOBv2Invitation(myDID string, invitation interface{}) error {
    63  	if myDID == "" {
    64  		return fmt.Errorf(errMsgInvalidKey)
    65  	}
    66  
    67  	return marshalAndSave(getOOBInvitationV2KeyPrefix()(tagValueFromDIDs(myDID)), invitation, c.store)
    68  }
    69  
    70  // SaveConnectionRecord saves given connection records in underlying store.
    71  func (c *Recorder) SaveConnectionRecord(record *Record) error {
    72  	if err := marshalAndSave(getConnectionKeyPrefix()(record.ConnectionID),
    73  		record, c.protocolStateStore, storage.Tag{
    74  			Name:  getConnectionKeyPrefix()(""),
    75  			Value: getConnectionKeyPrefix()(record.ConnectionID),
    76  		}); err != nil {
    77  		return fmt.Errorf("save connection record in protocol state store: %w", err)
    78  	}
    79  
    80  	if record.State != "" {
    81  		err := marshalAndSave(getConnectionStateKeyPrefix()(record.ConnectionID, record.State),
    82  			record, c.protocolStateStore, storage.Tag{
    83  				Name:  connStateKeyPrefix,
    84  				Value: getConnectionStateKeyPrefix()(record.ConnectionID),
    85  			})
    86  		if err != nil {
    87  			return fmt.Errorf("save connection record with state in protocol state store: %w", err)
    88  		}
    89  	}
    90  
    91  	if record.State == StateNameCompleted {
    92  		if err := marshalAndSave(getConnectionKeyPrefix()(record.ConnectionID),
    93  			record, c.store, storage.Tag{
    94  				Name:  getConnectionKeyPrefix()(""),
    95  				Value: getConnectionKeyPrefix()(record.ConnectionID),
    96  			},
    97  			storage.Tag{
    98  				Name:  bothDIDsTagName,
    99  				Value: tagValueFromDIDs(record.MyDID, record.TheirDID),
   100  			},
   101  			storage.Tag{
   102  				Name:  theirDIDTagName,
   103  				Value: tagValueFromDIDs(record.TheirDID),
   104  			}); err != nil {
   105  			return fmt.Errorf("save connection record in permanent store: %w", err)
   106  		}
   107  	}
   108  
   109  	return nil
   110  }
   111  
   112  // SaveConnectionRecordWithMappings saves newly created connection record against the connection id in the store
   113  // and it creates mapping from namespaced ThreadID to connection ID.
   114  func (c *Recorder) SaveConnectionRecordWithMappings(record *Record) error {
   115  	err := isValidConnection(record)
   116  	if err != nil {
   117  		return fmt.Errorf("validation failed while saving connection record with mapping: %w", err)
   118  	}
   119  
   120  	err = c.SaveConnectionRecord(record)
   121  	if err != nil {
   122  		return fmt.Errorf("failed to save connection record with mappings: %w", err)
   123  	}
   124  
   125  	err = c.SaveNamespaceThreadID(record.ThreadID, record.Namespace, record.ConnectionID)
   126  	if err != nil {
   127  		return fmt.Errorf("failed to save connection record with namespace mappings: %w", err)
   128  	}
   129  
   130  	return nil
   131  }
   132  
   133  // SaveEvent saves event related data for given connection ID
   134  // TODO connection event data shouldn't be transient [Issues #1029].
   135  func (c *Recorder) SaveEvent(connectionID string, data []byte) error {
   136  	return c.protocolStateStore.Put(getEventDataKeyPrefix()(connectionID), data)
   137  }
   138  
   139  // SaveNamespaceThreadID saves given namespace, threadID and connection ID mapping in protocol state store.
   140  func (c *Recorder) SaveNamespaceThreadID(threadID, namespace, connectionID string) error {
   141  	if namespace != MyNSPrefix && namespace != TheirNSPrefix {
   142  		return fmt.Errorf("namespace not supported")
   143  	}
   144  
   145  	prefix := MyNSPrefix
   146  	if namespace == TheirNSPrefix {
   147  		prefix = TheirNSPrefix
   148  	}
   149  
   150  	key, err := computeHash([]byte(threadID))
   151  	if err != nil {
   152  		return err
   153  	}
   154  
   155  	return c.protocolStateStore.Put(getNamespaceKeyPrefix(prefix)(key), []byte(connectionID))
   156  }
   157  
   158  // RemoveConnection removes connection record from the store for given id.
   159  func (c *Recorder) RemoveConnection(connectionID string) error {
   160  	record, err := c.GetConnectionRecord(connectionID)
   161  	if err != nil {
   162  		return fmt.Errorf("unable to get connection record: connectionid=%s err=%w", connectionID, err)
   163  	}
   164  
   165  	if err = c.protocolStateStore.Delete(getConnectionKeyPrefix()(connectionID)); err != nil {
   166  		return fmt.Errorf("unable to delete connection record from the protocol state store: connectionid=%s err=%w",
   167  			connectionID, err)
   168  	}
   169  
   170  	// remove connection records for different states from protocol state store
   171  	err = removeConnectionsForStates(c, connectionID)
   172  	if err != nil {
   173  		return fmt.Errorf("remove records for different connections states error: %w", err)
   174  	}
   175  
   176  	err = c.store.Delete(getConnectionKeyPrefix()(connectionID))
   177  	if err != nil {
   178  		return fmt.Errorf("unable to delete connection record from the store: connectionid=%s err=%w", connectionID, err)
   179  	}
   180  
   181  	// remove namespace, threadID and connection ID mapping from protocol state store
   182  	err = removeMappings(c, record)
   183  	if err != nil {
   184  		return fmt.Errorf("unable to delete connection record with namespace mappings: %w", err)
   185  	}
   186  
   187  	return nil
   188  }
   189  
   190  func marshalAndSave(k string, v interface{}, store storage.Store, tags ...storage.Tag) error {
   191  	bytes, err := json.Marshal(v)
   192  	if err != nil {
   193  		return fmt.Errorf("save connection record: %w", err)
   194  	}
   195  
   196  	return store.Put(k, bytes, tags...)
   197  }
   198  
   199  // isValidConnection validates connection record.
   200  func isValidConnection(r *Record) error {
   201  	if r.ThreadID == "" || r.ConnectionID == "" || r.Namespace == "" {
   202  		return fmt.Errorf("input parameters thid : %s and connectionId : %s namespace : %s cannot be empty",
   203  			r.ThreadID, r.ConnectionID, r.Namespace)
   204  	}
   205  
   206  	return nil
   207  }
   208  
   209  // computeHash will compute the hash for the supplied bytes.
   210  func computeHash(bytes []byte) (string, error) {
   211  	if len(bytes) == 0 {
   212  		return "", errors.New("unable to compute hash, empty bytes")
   213  	}
   214  
   215  	h := crypto.SHA256.New()
   216  	hash := h.Sum(bytes)
   217  
   218  	return fmt.Sprintf("%x", hash), nil
   219  }
   220  
   221  func removeConnectionsForStates(c *Recorder, connectionID string) error {
   222  	itr, err := c.protocolStateStore.Query(fmt.Sprintf("%s:%s", connStateKeyPrefix,
   223  		getConnectionStateKeyPrefix()(connectionID)))
   224  	if err != nil {
   225  		return fmt.Errorf("failed to query protocol state store: %w", err)
   226  	}
   227  
   228  	defer func() {
   229  		errClose := itr.Close()
   230  		if errClose != nil {
   231  			logger.Errorf("failed to close iterator: %s", errClose.Error())
   232  		}
   233  	}()
   234  
   235  	more, err := itr.Next()
   236  	if err != nil {
   237  		return fmt.Errorf("failed to get next set of data from iterator: %w", err)
   238  	}
   239  
   240  	for more {
   241  		key, err := itr.Key()
   242  		if err != nil {
   243  			return fmt.Errorf("failed to get key from iterator: %w", err)
   244  		}
   245  
   246  		err = c.protocolStateStore.Delete(key)
   247  		if err != nil {
   248  			return fmt.Errorf(
   249  				"unable to delete connection state record from the protocol state store: key=%s connectionid=%s err=%w",
   250  				key, connectionID, err)
   251  		}
   252  
   253  		more, err = itr.Next()
   254  		if err != nil {
   255  			return fmt.Errorf("failed to get next set of data from iterator: %w", err)
   256  		}
   257  	}
   258  
   259  	return nil
   260  }
   261  
   262  func removeMappings(c *Recorder, record *Record) error {
   263  	key, err := computeHash([]byte(record.ThreadID))
   264  	if err != nil {
   265  		return fmt.Errorf("compute hash: %w", err)
   266  	}
   267  
   268  	return c.store.Delete(getNamespaceKeyPrefix(record.Namespace)(key))
   269  }