github.com/hyperledger/aries-framework-go@v0.3.2/pkg/store/connection/connection_lookup_test.go (about)

     1  /*
     2   *
     3   * Copyright SecureKey Technologies Inc. All Rights Reserved.
     4   *
     5   * SPDX-License-Identifier: Apache-2.0
     6   * /
     7   *
     8   */
     9  
    10  package connection
    11  
    12  import (
    13  	"encoding/json"
    14  	"fmt"
    15  	"testing"
    16  
    17  	"github.com/stretchr/testify/require"
    18  
    19  	"github.com/hyperledger/aries-framework-go/component/storageutil/mem"
    20  	mockstorage "github.com/hyperledger/aries-framework-go/pkg/mock/storage"
    21  	"github.com/hyperledger/aries-framework-go/spi/storage"
    22  )
    23  
    24  const (
    25  	threadIDFmt  = "thID-%v"
    26  	connIDFmt    = "connValue-%v"
    27  	sampleErrMsg = "sample-error-message"
    28  )
    29  
    30  func TestNewConnectionReader(t *testing.T) {
    31  	t.Run("create new connection reader", func(t *testing.T) {
    32  		lookup, err := NewLookup(&mockProvider{})
    33  		require.NoError(t, err)
    34  		require.NotNil(t, lookup)
    35  		require.NotNil(t, lookup.protocolStateStore)
    36  		require.NotNil(t, lookup.store)
    37  	})
    38  
    39  	t.Run("create new connection reader failure due to protocol state store error", func(t *testing.T) {
    40  		lookup, err := NewLookup(&mockProvider{protocolStateStoreError: fmt.Errorf(sampleErrMsg)})
    41  		require.Error(t, err)
    42  		require.Contains(t, err.Error(), sampleErrMsg)
    43  		require.Nil(t, lookup)
    44  	})
    45  
    46  	t.Run("create new connection reader failure due to protocol state store config error", func(t *testing.T) {
    47  		lookup, err := NewLookup(&mockProvider{protocolStoreConfError: fmt.Errorf(sampleErrMsg)})
    48  		require.Error(t, err)
    49  		require.Contains(t, err.Error(), sampleErrMsg)
    50  		require.Nil(t, lookup)
    51  	})
    52  
    53  	t.Run("create new connection reader failure due to store error", func(t *testing.T) {
    54  		lookup, err := NewLookup(&mockProvider{storeError: fmt.Errorf(sampleErrMsg)})
    55  		require.Error(t, err)
    56  		require.Contains(t, err.Error(), sampleErrMsg)
    57  		require.Nil(t, lookup)
    58  	})
    59  
    60  	t.Run("create new connection reader failure due to store config error", func(t *testing.T) {
    61  		lookup, err := NewLookup(&mockProvider{storeConfError: fmt.Errorf(sampleErrMsg)})
    62  		require.Error(t, err)
    63  		require.Contains(t, err.Error(), sampleErrMsg)
    64  		require.Nil(t, lookup)
    65  	})
    66  }
    67  
    68  func TestConnectionReader_GetAndQueryConnectionRecord(t *testing.T) {
    69  	const noOfItems = 12
    70  	connectionIDS := make([]string, noOfItems)
    71  
    72  	for i := 0; i < noOfItems; i++ {
    73  		connectionIDS[i] = fmt.Sprintf(connIDFmt, i)
    74  	}
    75  
    76  	saveInStore := func(store storage.Store, ids []string) {
    77  		for _, id := range ids {
    78  			connRecBytes, err := json.Marshal(&Record{
    79  				ConnectionID: id,
    80  				ThreadID:     fmt.Sprintf(threadIDFmt, id),
    81  			})
    82  			require.NoError(t, err)
    83  			err = store.Put(getConnectionKeyPrefix()(id), connRecBytes, storage.Tag{Name: "conn_"})
    84  			require.NoError(t, err)
    85  		}
    86  	}
    87  
    88  	t.Run("get connection record - from store", func(t *testing.T) {
    89  		lookup, e := NewLookup(&mockProvider{})
    90  		require.NoError(t, e)
    91  		require.NotNil(t, lookup)
    92  
    93  		for _, connectionID := range connectionIDS {
    94  			connection, err := lookup.GetConnectionRecord(connectionID)
    95  			require.Error(t, err)
    96  			require.Equal(t, err, storage.ErrDataNotFound)
    97  			require.Nil(t, connection)
    98  		}
    99  
   100  		// prepare data
   101  		saveInStore(lookup.store, connectionIDS)
   102  
   103  		for _, connectionID := range connectionIDS {
   104  			connection, err := lookup.GetConnectionRecord(connectionID)
   105  			require.NoError(t, err)
   106  			require.NotNil(t, connection)
   107  			require.Equal(t, connectionID, connection.ConnectionID)
   108  			require.Equal(t, fmt.Sprintf(threadIDFmt, connectionID), connection.ThreadID)
   109  		}
   110  
   111  		records, e := lookup.QueryConnectionRecords()
   112  		require.NoError(t, e)
   113  		require.NotEmpty(t, records)
   114  		require.Len(t, records, noOfItems)
   115  	})
   116  
   117  	t.Run("get connection record - from protocol state store", func(t *testing.T) {
   118  		lookup, e := NewLookup(&mockProvider{})
   119  		require.NoError(t, e)
   120  		require.NotNil(t, lookup)
   121  
   122  		for _, connectionID := range connectionIDS {
   123  			connection, err := lookup.GetConnectionRecord(connectionID)
   124  			require.Error(t, err)
   125  			require.Equal(t, err, storage.ErrDataNotFound)
   126  			require.Nil(t, connection)
   127  		}
   128  
   129  		// prepare data
   130  		saveInStore(lookup.protocolStateStore, connectionIDS)
   131  
   132  		for _, connectionID := range connectionIDS {
   133  			connection, err := lookup.GetConnectionRecord(connectionID)
   134  			require.NoError(t, err)
   135  			require.NotNil(t, connection)
   136  			require.Equal(t, connectionID, connection.ConnectionID)
   137  			require.Equal(t, fmt.Sprintf(threadIDFmt, connectionID), connection.ThreadID)
   138  		}
   139  
   140  		records, e := lookup.QueryConnectionRecords()
   141  		require.NoError(t, e)
   142  		require.NotEmpty(t, records)
   143  		require.Len(t, records, noOfItems)
   144  	})
   145  
   146  	t.Run("get connection record - error scenario", func(t *testing.T) {
   147  		provider := &mockProvider{}
   148  		provider.store = &mockstorage.MockStore{
   149  			ErrGet: fmt.Errorf(sampleErrMsg),
   150  			Store:  make(map[string]mockstorage.DBEntry),
   151  		}
   152  		lookup, err := NewLookup(provider)
   153  		require.NoError(t, err)
   154  		require.NotNil(t, lookup)
   155  
   156  		// prepare data
   157  		saveInStore(lookup.protocolStateStore, connectionIDS)
   158  
   159  		for _, connectionID := range connectionIDS {
   160  			connection, err := lookup.GetConnectionRecord(connectionID)
   161  			require.Error(t, err)
   162  			require.Nil(t, connection)
   163  			require.EqualError(t, err, sampleErrMsg)
   164  		}
   165  	})
   166  }
   167  
   168  func TestConnectionReader_GetConnectionRecordAtState(t *testing.T) {
   169  	const state = "requested"
   170  
   171  	const noOfItems = 12
   172  
   173  	connectionIDS := make([]string, noOfItems)
   174  
   175  	for i := 0; i < noOfItems; i++ {
   176  		connectionIDS[i] = fmt.Sprintf(connIDFmt, i)
   177  	}
   178  
   179  	saveInStore := func(store storage.Store, ids []string) {
   180  		for _, id := range ids {
   181  			connRecBytes, err := json.Marshal(&Record{
   182  				ConnectionID: id,
   183  				ThreadID:     fmt.Sprintf(threadIDFmt, id),
   184  			})
   185  			require.NoError(t, err)
   186  			err = store.Put(getConnectionStateKeyPrefix()(id, state), connRecBytes)
   187  			require.NoError(t, err)
   188  		}
   189  	}
   190  
   191  	t.Run("get connection record at state", func(t *testing.T) {
   192  		store, err := NewLookup(&mockProvider{})
   193  		require.NoError(t, err)
   194  		require.NotNil(t, store)
   195  
   196  		// should fail since data doesn't exists
   197  		for _, connectionID := range connectionIDS {
   198  			connection, err := store.GetConnectionRecordAtState(connectionID, state)
   199  			require.Error(t, err)
   200  			require.Contains(t, err.Error(), storage.ErrDataNotFound.Error())
   201  			require.Nil(t, connection)
   202  		}
   203  
   204  		// prepare data in store
   205  		saveInStore(store.store, connectionIDS)
   206  
   207  		// should fail since data doesn't exists in protocol state store
   208  		for _, connectionID := range connectionIDS {
   209  			connection, err := store.GetConnectionRecordAtState(connectionID, state)
   210  			require.Error(t, err)
   211  			require.Contains(t, err.Error(), storage.ErrDataNotFound.Error())
   212  			require.Nil(t, connection)
   213  		}
   214  
   215  		// prepare data in protocol state store
   216  		saveInStore(store.protocolStateStore, connectionIDS)
   217  
   218  		for _, connectionID := range connectionIDS {
   219  			connection, err := store.GetConnectionRecordAtState(connectionID, state)
   220  			require.NoError(t, err)
   221  			require.NotNil(t, connection)
   222  			require.Equal(t, connectionID, connection.ConnectionID)
   223  			require.Equal(t, fmt.Sprintf(threadIDFmt, connectionID), connection.ThreadID)
   224  		}
   225  	})
   226  
   227  	t.Run("get connection record at state - failure", func(t *testing.T) {
   228  		store, err := NewLookup(&mockProvider{})
   229  		require.NoError(t, err)
   230  		require.NotNil(t, store)
   231  
   232  		connection, err := store.GetConnectionRecordAtState("sampleID", "")
   233  		require.Error(t, err)
   234  		require.EqualError(t, err, stateIDEmptyErr)
   235  		require.Nil(t, connection)
   236  	})
   237  }
   238  
   239  func TestConnectionReader_GetConnectionRecordByNSThreadID(t *testing.T) {
   240  	const noOfItems = 12
   241  	nsThreadIDs := make([]string, noOfItems)
   242  
   243  	for i := 0; i < noOfItems; i++ {
   244  		nsThreadIDs[i] = fmt.Sprintf(threadIDFmt, i)
   245  	}
   246  
   247  	saveInStore := func(store storage.Store, ids []string, skipConnection bool) {
   248  		for _, id := range ids {
   249  			connID := fmt.Sprintf(connIDFmt, id)
   250  			connRecBytes, err := json.Marshal(&Record{
   251  				ConnectionID: id,
   252  				ThreadID:     id,
   253  			})
   254  			require.NoError(t, err)
   255  			err = store.Put(id, []byte(connID))
   256  			require.NoError(t, err)
   257  
   258  			if !skipConnection {
   259  				err = store.Put(getConnectionKeyPrefix()(connID), connRecBytes)
   260  				require.NoError(t, err)
   261  			}
   262  		}
   263  	}
   264  
   265  	t.Run("get connection record by NS thread ID", func(t *testing.T) {
   266  		store, err := NewLookup(&mockProvider{})
   267  		require.NoError(t, err)
   268  		require.NotNil(t, store)
   269  
   270  		// should fail since data doesn't exists
   271  		for _, nsThreadID := range nsThreadIDs {
   272  			connection, err := store.GetConnectionRecordByNSThreadID(nsThreadID)
   273  			require.Error(t, err)
   274  			require.Contains(t, err.Error(), storage.ErrDataNotFound.Error())
   275  			require.Nil(t, connection)
   276  		}
   277  
   278  		// prepare data in store
   279  		saveInStore(store.store, nsThreadIDs, false)
   280  
   281  		// should fail since data doesn't exists in protocol state store
   282  		for _, nsThreadID := range nsThreadIDs {
   283  			connection, err := store.GetConnectionRecordByNSThreadID(nsThreadID)
   284  			require.Error(t, err)
   285  			require.Contains(t, err.Error(), storage.ErrDataNotFound.Error())
   286  			require.Nil(t, connection)
   287  		}
   288  
   289  		// prepare only ns thread data in protocol state store
   290  		// skip connection
   291  		saveInStore(store.protocolStateStore, nsThreadIDs, true)
   292  
   293  		// should fail since data doesn't exists in protocol state store
   294  		for _, nsThreadID := range nsThreadIDs {
   295  			connection, err := store.GetConnectionRecordByNSThreadID(nsThreadID)
   296  			require.Error(t, err)
   297  			require.Contains(t, err.Error(), storage.ErrDataNotFound.Error())
   298  			require.Nil(t, connection)
   299  		}
   300  
   301  		// prepare data in protocol state store
   302  		saveInStore(store.protocolStateStore, nsThreadIDs, false)
   303  
   304  		// should fail since data doesn't exists in protocol state store
   305  		for _, nsThreadID := range nsThreadIDs {
   306  			connection, err := store.GetConnectionRecordByNSThreadID(nsThreadID)
   307  			require.NoError(t, err)
   308  			require.NotNil(t, connection)
   309  			require.Equal(t, nsThreadID, connection.ThreadID)
   310  		}
   311  	})
   312  }
   313  
   314  func TestConnectionRecorder_QueryConnectionRecord(t *testing.T) {
   315  	t.Run("test query connection record", func(t *testing.T) {
   316  		store := &mockstorage.MockStore{Store: make(map[string]mockstorage.DBEntry)}
   317  
   318  		protocolStateStore, err := mem.NewProvider().OpenStore(Namespace)
   319  		require.NoError(t, err)
   320  
   321  		const (
   322  			storeCount              = 5
   323  			overlap                 = 3
   324  			protocolStateStoreCount = 4
   325  		)
   326  
   327  		for i := 0; i < storeCount+overlap; i++ {
   328  			val, jsonErr := json.Marshal(&Record{
   329  				ConnectionID: fmt.Sprint(i),
   330  			})
   331  			require.NoError(t, jsonErr)
   332  
   333  			err = store.Put(fmt.Sprintf("%s_abc%d", connIDKeyPrefix, i), val, storage.Tag{Name: "conn_"})
   334  			require.NoError(t, err)
   335  		}
   336  		for i := overlap; i < protocolStateStoreCount+storeCount; i++ {
   337  			val, jsonErr := json.Marshal(&Record{
   338  				ConnectionID: fmt.Sprint(i),
   339  			})
   340  			require.NoError(t, jsonErr)
   341  
   342  			err = protocolStateStore.Put(fmt.Sprintf("%s_abc%d", connIDKeyPrefix, i), val, storage.Tag{Name: "conn_"})
   343  			require.NoError(t, err)
   344  		}
   345  
   346  		recorder, err := NewLookup(&mockProvider{store: store, protocolStateStore: protocolStateStore})
   347  		require.NoError(t, err)
   348  		require.NotNil(t, recorder)
   349  		result, err := recorder.QueryConnectionRecords()
   350  		require.NoError(t, err)
   351  		require.Len(t, result, storeCount+protocolStateStoreCount)
   352  	})
   353  
   354  	t.Run("test query connection record failure", func(t *testing.T) {
   355  		store := &mockstorage.MockStore{Store: make(map[string]mockstorage.DBEntry)}
   356  		err := store.Put(fmt.Sprintf("%s_abc123", connIDKeyPrefix), []byte("-----"), storage.Tag{Name: "conn_"})
   357  		require.NoError(t, err)
   358  
   359  		recorder, err := NewLookup(&mockProvider{store: store})
   360  		require.NoError(t, err)
   361  		require.NotNil(t, recorder)
   362  		result, err := recorder.QueryConnectionRecords()
   363  		require.Error(t, err)
   364  		require.Empty(t, result)
   365  	})
   366  
   367  	t.Run("test query connection record failure - protocol state store read", func(t *testing.T) {
   368  		expected := fmt.Errorf("query error")
   369  
   370  		recorder, err := NewRecorder(&mockProvider{
   371  			protocolStateStore: &mockstorage.MockStore{ErrQuery: expected},
   372  		})
   373  		require.NoError(t, err)
   374  		require.NotNil(t, recorder)
   375  
   376  		result, err := recorder.QueryConnectionRecords()
   377  		require.Error(t, err)
   378  		require.Empty(t, result)
   379  		require.ErrorIs(t, err, expected)
   380  	})
   381  }
   382  
   383  func TestGetConnectionIDByDIDs(t *testing.T) {
   384  	myDID := "did:mydid:123"
   385  	theirDID := "did:theirdid:789"
   386  
   387  	t.Run("get connection record by did - success", func(t *testing.T) {
   388  		recorder, err := NewRecorder(&mockProvider{})
   389  		require.NoError(t, err)
   390  
   391  		require.NotNil(t, recorder)
   392  		connRec := &Record{
   393  			ThreadID:     threadIDValue,
   394  			ConnectionID: sampleConnID,
   395  			State:        StateNameCompleted,
   396  			Namespace:    MyNSPrefix,
   397  			MyDID:        myDID,
   398  			TheirDID:     theirDID,
   399  		}
   400  		err = recorder.SaveConnectionRecord(connRec)
   401  		require.NoError(t, err)
   402  
   403  		connectionID, err := recorder.GetConnectionIDByDIDs(myDID, theirDID)
   404  		require.NoError(t, err)
   405  		require.Equal(t, sampleConnID, connectionID)
   406  
   407  		connectionRecord, err := recorder.GetConnectionRecordByDIDs(myDID, theirDID)
   408  		require.NoError(t, err)
   409  		require.Equal(t, sampleConnID, connectionRecord.ConnectionID)
   410  
   411  		connectionRecord, err = recorder.GetConnectionRecordByTheirDID(theirDID)
   412  		require.NoError(t, err)
   413  		require.Equal(t, sampleConnID, connectionRecord.ConnectionID)
   414  	})
   415  
   416  	t.Run("get connection record by did - not found", func(t *testing.T) {
   417  		recorder, err := NewRecorder(&mockProvider{})
   418  		require.NoError(t, err)
   419  
   420  		connectionID, err := recorder.GetConnectionIDByDIDs(myDID, theirDID)
   421  		require.Error(t, err)
   422  		require.Contains(t, err.Error(), "get connection record by DIDs")
   423  		require.Empty(t, connectionID)
   424  
   425  		connectionRecord, err := recorder.GetConnectionRecordByDIDs(myDID, theirDID)
   426  		require.Error(t, err)
   427  		require.ErrorIs(t, err, storage.ErrDataNotFound)
   428  		require.Nil(t, connectionRecord)
   429  	})
   430  
   431  	t.Run("get connection record by did - store query error", func(t *testing.T) {
   432  		expected := fmt.Errorf("query error")
   433  
   434  		recorder, err := NewRecorder(&mockProvider{
   435  			store: &mockstorage.MockStore{ErrQuery: expected},
   436  		})
   437  		require.NoError(t, err)
   438  
   439  		connectionRecord, err := recorder.GetConnectionRecordByDIDs(myDID, theirDID)
   440  		require.Error(t, err)
   441  		require.ErrorIs(t, err, expected)
   442  		require.Nil(t, connectionRecord)
   443  	})
   444  }
   445  
   446  // mockProvider for connection recorder.
   447  type mockProvider struct {
   448  	protocolStateStoreError error
   449  	storeError              error
   450  	protocolStoreConfError  error
   451  	storeConfError          error
   452  	store                   storage.Store
   453  	protocolStateStore      storage.Store
   454  }
   455  
   456  // ProtocolStateStorageProvider is mock protocol state storage provider for connection recorder.
   457  func (p *mockProvider) ProtocolStateStorageProvider() storage.Provider {
   458  	return mockStorageProvider(p.protocolStateStore, p.protocolStateStoreError, p.protocolStoreConfError)
   459  }
   460  
   461  // StorageProvider is mock storage provider for connection recorder.
   462  func (p *mockProvider) StorageProvider() storage.Provider {
   463  	return mockStorageProvider(p.store, p.storeError, p.storeConfError)
   464  }
   465  
   466  func mockStorageProvider(store storage.Store, errOpen, errConfig error) storage.Provider {
   467  	if errOpen != nil {
   468  		return &mockstorage.MockStoreProvider{ErrOpenStoreHandle: errOpen}
   469  	}
   470  
   471  	var m *mockstorage.MockStoreProvider
   472  
   473  	if store != nil {
   474  		m = mockstorage.NewCustomMockStoreProvider(store)
   475  	} else {
   476  		m = mockstorage.NewMockStoreProvider()
   477  	}
   478  
   479  	m.ErrSetStoreConfig = errConfig
   480  
   481  	return m
   482  }