github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/dispatcher/inbound/inbound_message_handler_test.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package inbound
     8  
     9  import (
    10  	"fmt"
    11  	"testing"
    12  
    13  	"github.com/btcsuite/btcutil/base58"
    14  	"github.com/golang/mock/gomock"
    15  	"github.com/stretchr/testify/require"
    16  
    17  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/middleware"
    18  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    19  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/didexchange"
    20  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
    21  	"github.com/hyperledger/aries-framework-go/pkg/doc/did"
    22  	mocks "github.com/hyperledger/aries-framework-go/pkg/internal/gomocks/didcomm/common/service"
    23  	"github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/msghandler"
    24  	mockdidexchange "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol/didexchange"
    25  	"github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol/generic"
    26  	mockprovider "github.com/hyperledger/aries-framework-go/pkg/mock/provider"
    27  	mockstore "github.com/hyperledger/aries-framework-go/pkg/mock/storage"
    28  	"github.com/hyperledger/aries-framework-go/pkg/store/connection"
    29  	didstore "github.com/hyperledger/aries-framework-go/pkg/store/did"
    30  )
    31  
    32  func TestNewInboundMessageHandler(t *testing.T) {
    33  	t.Run("success", func(t *testing.T) {
    34  		_ = NewInboundMessageHandler(emptyProvider())
    35  	})
    36  }
    37  
    38  func TestMessageHandler_HandlerFunc(t *testing.T) {
    39  	handler := NewInboundMessageHandler(emptyProvider())
    40  
    41  	handleFunc := handler.HandlerFunc()
    42  
    43  	err := handleFunc(&transport.Envelope{
    44  		Message: []byte(`{
    45  	"@id":"12345",
    46  	"@type":"message-type"
    47  }`),
    48  	})
    49  	require.NoError(t, err)
    50  }
    51  
    52  func TestMessageHandler_HandleInboundEnvelope(t *testing.T) {
    53  	ctrl := gomock.NewController(t)
    54  	defer ctrl.Finish()
    55  
    56  	testCases := []struct {
    57  		testName        string
    58  		svcAccept       string
    59  		svcName         string
    60  		svcHandleErr    error
    61  		getDIDsErr      error
    62  		msgSvcAccept    bool
    63  		msgSvcHandleErr error
    64  		messengerErr    error
    65  		message         string
    66  		expectErr       string
    67  	}{
    68  		{
    69  			testName:  "success: without getDIDs",
    70  			svcAccept: "message-type",
    71  			svcName:   didexchange.DIDExchange,
    72  			message: `{
    73  	"@id":"12345",
    74  	"@type":"message-type"
    75  }`,
    76  		},
    77  		{
    78  			testName:  "success: with getDIDs",
    79  			svcAccept: "message-type",
    80  			svcName:   "service-name",
    81  			message: `{
    82  	"@id":"12345",
    83  	"@type":"message-type"
    84  }`,
    85  		},
    86  		{
    87  			testName:  "success: didcomm v2",
    88  			svcAccept: "message-type",
    89  			message: `{
    90  	"id":"12345",
    91  	"type":"message-type",
    92  	"body":{}
    93  }`,
    94  		},
    95  		{
    96  			testName:  "fail: parsing message",
    97  			message:   `{`,
    98  			expectErr: "invalid payload data format",
    99  		},
   100  		{
   101  			testName: "fail: can't determine if didcomm v1 or v2",
   102  			message: `{
   103  	"body":{},
   104  	"~thread":"12345"
   105  }`,
   106  			expectErr: "not a valid didcomm v1 or v2 message",
   107  		},
   108  		{
   109  			testName:   "fail: getDIDs error",
   110  			svcAccept:  "message-type",
   111  			svcName:    "service-name",
   112  			getDIDsErr: fmt.Errorf("get DIDs error"),
   113  			message: `{
   114  	"@id":"12345",
   115  	"@type":"message-type"
   116  }`,
   117  			expectErr: "get DIDs error",
   118  		},
   119  		{
   120  			testName:   "fail: getDIDs error for didcomm v2",
   121  			svcAccept:  "message-type",
   122  			svcName:    "service-name",
   123  			getDIDsErr: fmt.Errorf("get DIDs error"),
   124  			message: `{
   125  	"id":"12345",
   126  	"type":"message-type",
   127  	"body":{}
   128  }`,
   129  			expectErr: "get DIDs error",
   130  		},
   131  		{
   132  			testName:  "fail: didcomm v2 did rotation error",
   133  			svcAccept: "message-type",
   134  			message: `{
   135  	"id":"12345",
   136  	"type":"message-type",
   137  	"from_prior":{},
   138  	"body":{}
   139  }`,
   140  			expectErr: "field should be a string",
   141  		},
   142  		{
   143  			testName:  "success: messenger service",
   144  			svcAccept: "message-type",
   145  			message: `{
   146  	"@id":"12345",
   147  	"@type":"different-type"
   148  }`,
   149  			msgSvcAccept: true,
   150  		},
   151  		{
   152  			testName:  "fail: bad message purpose field",
   153  			svcAccept: "message-type",
   154  			message: `{
   155  	"@id":"12345",
   156  	"@type":"different-type",
   157  	"~purpose": {"aaaaa":"bbbbb"}
   158  }`,
   159  			expectErr: "expected type 'string'",
   160  		},
   161  		{
   162  			testName:  "fail: error in getDIDs for message service",
   163  			svcAccept: "message-type",
   164  			message: `{
   165  	"@id":"12345",
   166  	"@type":"different-type"
   167  }`,
   168  			msgSvcAccept: true,
   169  			getDIDsErr:   fmt.Errorf("get DIDs error"),
   170  			expectErr:    "get DIDs error",
   171  		},
   172  		{
   173  			testName:  "fail: error in messenger HandleInbound",
   174  			svcAccept: "message-type",
   175  			message: `{
   176  	"@id":"12345",
   177  	"@type":"different-type"
   178  }`,
   179  			msgSvcAccept: true,
   180  			messengerErr: fmt.Errorf("messenger error"),
   181  			expectErr:    "messenger error",
   182  		},
   183  		{
   184  			testName:  "fail: error in message service HandleInbound",
   185  			svcAccept: "message-type",
   186  			message: `{
   187  	"@id":"12345",
   188  	"@type":"different-type"
   189  }`,
   190  			msgSvcAccept:    true,
   191  			msgSvcHandleErr: fmt.Errorf("message svc error"),
   192  			expectErr:       "message svc error",
   193  		},
   194  		{
   195  			testName:  "fail: no handler for given message",
   196  			svcAccept: "message-type",
   197  			message: `{
   198  	"@id":"12345",
   199  	"@type":"different-type"
   200  }`,
   201  			expectErr: "no message handlers found",
   202  		},
   203  		{
   204  			testName:  "fail: no handler for given didcomm v2 message",
   205  			svcAccept: "message-type",
   206  			message: `{
   207  	"id":"12345",
   208  	"type":"different-type",
   209  	"body":{}
   210  }`,
   211  			msgSvcAccept: true,
   212  			expectErr:    "no message handlers found",
   213  		},
   214  	}
   215  
   216  	store := mockstore.NewMockStoreProvider()
   217  	psStore := mockstore.NewMockStoreProvider()
   218  
   219  	p := mockprovider.Provider{
   220  		StorageProviderValue:              store,
   221  		ProtocolStateStorageProviderValue: psStore,
   222  	}
   223  
   224  	connectionRecorder, err := connection.NewRecorder(&p)
   225  	require.NoError(t, err)
   226  
   227  	myDID := "did:test:my-did"
   228  	theirDID := "did:test:their-did"
   229  
   230  	err = connectionRecorder.SaveConnectionRecord(&connection.Record{
   231  		ConnectionID:  "12345",
   232  		MyDID:         myDID,
   233  		TheirDID:      theirDID,
   234  		State:         connection.StateNameCompleted,
   235  		MyDIDRotation: nil,
   236  	})
   237  	require.NoError(t, err)
   238  
   239  	didRotator, err := middleware.New(&p)
   240  	require.NoError(t, err)
   241  
   242  	t.Parallel()
   243  
   244  	for _, tc := range testCases {
   245  		t.Run(tc.testName, func(t *testing.T) {
   246  			msgSvcProvider := msghandler.MockMsgSvcProvider{}
   247  
   248  			require.NoError(t, msgSvcProvider.Register(&generic.MockMessageSvc{
   249  				AcceptFunc: func(msgType string, purpose []string) bool {
   250  					return tc.msgSvcAccept
   251  				},
   252  				HandleFunc: func(msg *service.DIDCommMsg) (string, error) {
   253  					return "", tc.msgSvcHandleErr
   254  				},
   255  			}))
   256  
   257  			messengerHandler := mocks.NewMockMessengerHandler(ctrl)
   258  			messengerHandler.EXPECT().HandleInbound(gomock.Any(), gomock.Any()).AnyTimes().Return(tc.messengerErr)
   259  
   260  			didex := mockdidexchange.MockDIDExchangeSvc{
   261  				ProtocolName: tc.svcName,
   262  				AcceptFunc: func(s string) bool {
   263  					return s == tc.svcAccept
   264  				},
   265  				HandleFunc: func(msg service.DIDCommMsg) (string, error) {
   266  					return "", tc.svcHandleErr
   267  				},
   268  			}
   269  
   270  			prov := mockprovider.Provider{
   271  				DIDConnectionStoreValue: &mockDIDStore{getDIDErr: tc.getDIDsErr, results: map[string]mockDIDResult{
   272  					base58.Encode([]byte("my_key")):    {did: myDID},
   273  					base58.Encode([]byte("their_key")): {did: theirDID},
   274  				}},
   275  				MessageServiceProviderValue: &msgSvcProvider,
   276  				InboundMessengerValue:       messengerHandler,
   277  				ServiceValue:                &didex,
   278  				DIDRotatorValue:             *didRotator,
   279  			}
   280  
   281  			h := NewInboundMessageHandler(&prov)
   282  
   283  			err = h.HandleInboundEnvelope(&transport.Envelope{
   284  				Message: []byte(tc.message),
   285  				ToKey:   []byte("my_key"),
   286  				FromKey: []byte("their_key"),
   287  			})
   288  			if tc.expectErr == "" {
   289  				require.NoError(t, err)
   290  			} else {
   291  				require.Error(t, err)
   292  				require.Contains(t, err.Error(), tc.expectErr)
   293  			}
   294  		})
   295  	}
   296  }
   297  
   298  func TestMessageHandler_Initialize(t *testing.T) {
   299  	p := emptyProvider()
   300  
   301  	// second Initialize is no-op
   302  	h := &MessageHandler{}
   303  	h.Initialize(p)
   304  	h.Initialize(p)
   305  
   306  	// first Initialize is in New, second is no-op
   307  	h = NewInboundMessageHandler(p)
   308  	h.Initialize(p)
   309  }
   310  
   311  func TestMessageHandler_getDIDs(t *testing.T) {
   312  	t.Run("success", func(t *testing.T) {
   313  		p := emptyProvider()
   314  
   315  		h := NewInboundMessageHandler(p)
   316  
   317  		myDID, theirDID, err := h.getDIDs(&transport.Envelope{
   318  			ToKey:   []byte("abcd"),
   319  			FromKey: []byte("abcd"),
   320  		}, nil)
   321  
   322  		require.NoError(t, err)
   323  		require.Equal(t, "", myDID)
   324  		require.Equal(t, "", theirDID)
   325  	})
   326  
   327  	t.Run("success: dids from key refs", func(t *testing.T) {
   328  		p := emptyProvider()
   329  
   330  		h := NewInboundMessageHandler(p)
   331  
   332  		myDID, theirDID, err := h.getDIDs(&transport.Envelope{
   333  			ToKey:   []byte(`{"kid":"did:peer:alice#key-1"}`),
   334  			FromKey: []byte(`{"kid":"did:peer:bob#key-1"}`),
   335  		}, nil)
   336  
   337  		require.NoError(t, err)
   338  		require.Equal(t, "did:peer:alice", myDID)
   339  		require.Equal(t, "did:peer:bob", theirDID)
   340  	})
   341  
   342  	t.Run("success: their DID from message", func(t *testing.T) {
   343  		p := emptyProvider()
   344  
   345  		h := NewInboundMessageHandler(p)
   346  
   347  		myDID, theirDID, err := h.getDIDs(&transport.Envelope{
   348  			ToKey:   []byte(`{"kid":"did:peer:alice#key-1"}`),
   349  			FromKey: nil,
   350  		}, service.DIDCommMsgMap{
   351  			"from": "did:peer:bob",
   352  		})
   353  
   354  		require.NoError(t, err)
   355  		require.Equal(t, "did:peer:alice", myDID)
   356  		require.Equal(t, "did:peer:bob", theirDID)
   357  	})
   358  
   359  	t.Run("fail: bad did key", func(t *testing.T) {
   360  		p := emptyProvider()
   361  
   362  		h := NewInboundMessageHandler(p)
   363  
   364  		_, _, err := h.getDIDs(&transport.Envelope{
   365  			ToKey: []byte(`abcd # abcd "kid":"did:`), // matches string matching, but is not JSON
   366  		}, nil)
   367  
   368  		require.Error(t, err)
   369  		require.Contains(t, err.Error(), "pubKeyToDID: unmarshal key")
   370  
   371  		_, _, err = h.getDIDs(&transport.Envelope{
   372  			FromKey: []byte(`abcd # abcd "kid":"did:`), // matches string matching, but is not JSON
   373  		}, nil)
   374  
   375  		require.Error(t, err)
   376  		require.Contains(t, err.Error(), "pubKeyToDID: unmarshal key")
   377  	})
   378  
   379  	t.Run("fail: can't get my did", func(t *testing.T) {
   380  		p := emptyProvider()
   381  		p.DIDConnectionStoreValue = &mockDIDStore{
   382  			results: map[string]mockDIDResult{
   383  				base58.Encode([]byte("my_key")):    {did: "bbb", err: fmt.Errorf("mock did store error")},
   384  				base58.Encode([]byte("their_key")): {did: "aaa", err: nil},
   385  			},
   386  		}
   387  		p.GetDIDsMaxRetriesValue = 1
   388  
   389  		h := NewInboundMessageHandler(p)
   390  
   391  		_, _, err := h.getDIDs(&transport.Envelope{
   392  			ToKey:   []byte("my_key"),
   393  			FromKey: []byte("their_key"),
   394  		}, nil)
   395  
   396  		require.Error(t, err)
   397  		require.Contains(t, err.Error(), "mock did store error")
   398  	})
   399  
   400  	t.Run("fail: can't get their did", func(t *testing.T) {
   401  		p := emptyProvider()
   402  		p.DIDConnectionStoreValue = &mockDIDStore{
   403  			results: map[string]mockDIDResult{
   404  				base58.Encode([]byte("my_key")):    {did: "aaa", err: nil},
   405  				base58.Encode([]byte("their_key")): {did: "bbb", err: fmt.Errorf("mock did store error")},
   406  			},
   407  		}
   408  		p.GetDIDsMaxRetriesValue = 1
   409  
   410  		h := NewInboundMessageHandler(p)
   411  
   412  		_, _, err := h.getDIDs(&transport.Envelope{
   413  			ToKey:   []byte("my_key"),
   414  			FromKey: []byte("their_key"),
   415  		}, nil)
   416  
   417  		require.Error(t, err)
   418  		require.Contains(t, err.Error(), "mock did store error")
   419  	})
   420  
   421  	t.Run("not found", func(t *testing.T) {
   422  		p := emptyProvider()
   423  		p.DIDConnectionStoreValue = &mockDIDStore{
   424  			getDIDErr: didstore.ErrNotFound,
   425  		}
   426  		p.GetDIDsMaxRetriesValue = 1
   427  
   428  		h := NewInboundMessageHandler(p)
   429  
   430  		myDID, theirDID, err := h.getDIDs(&transport.Envelope{
   431  			ToKey:   []byte("my_key"),
   432  			FromKey: []byte("their_key"),
   433  		}, nil)
   434  
   435  		require.NoError(t, err)
   436  		require.Equal(t, "", myDID)
   437  		require.Equal(t, "", theirDID)
   438  	})
   439  
   440  	t.Run("success: theirDID needs retry", func(t *testing.T) {
   441  		p := emptyProvider()
   442  		p.DIDConnectionStoreValue = &mockDIDStore{
   443  			results: map[string]mockDIDResult{
   444  				base58.Encode([]byte("my_key")):    {did: "aaa"},
   445  				base58.Encode([]byte("their_key")): {did: "bbb"},
   446  			},
   447  			temps: map[string]mockDIDResult{
   448  				base58.Encode([]byte("their_key")): {err: didstore.ErrNotFound},
   449  			},
   450  			countDown: 1,
   451  		}
   452  		p.GetDIDsMaxRetriesValue = 3
   453  
   454  		h := NewInboundMessageHandler(p)
   455  
   456  		myDID, theirDID, err := h.getDIDs(&transport.Envelope{
   457  			ToKey:   []byte("my_key"),
   458  			FromKey: []byte("their_key"),
   459  		}, nil)
   460  
   461  		require.NoError(t, err)
   462  		require.Equal(t, "aaa", myDID)
   463  		require.Equal(t, "bbb", theirDID)
   464  	})
   465  }
   466  
   467  func emptyProvider() *mockprovider.Provider {
   468  	return &mockprovider.Provider{
   469  		DIDConnectionStoreValue:     &mockDIDStore{},
   470  		MessageServiceProviderValue: &msghandler.MockMsgSvcProvider{},
   471  		InboundMessengerValue:       &mocks.MockMessengerHandler{},
   472  		ServiceValue: &mockdidexchange.MockDIDExchangeSvc{
   473  			AcceptFunc: func(_ string) bool {
   474  				return true
   475  			},
   476  			HandleFunc: func(msg service.DIDCommMsg) (string, error) {
   477  				return "", nil
   478  			},
   479  		},
   480  	}
   481  }
   482  
   483  type mockDIDResult struct {
   484  	did string
   485  	err error
   486  }
   487  
   488  type mockDIDStore struct {
   489  	getDIDErr error
   490  	results   map[string]mockDIDResult
   491  	temps     map[string]mockDIDResult
   492  	countDown uint
   493  }
   494  
   495  // GetDID returns DID associated with key.
   496  func (m *mockDIDStore) GetDID(key string) (string, error) {
   497  	if m.getDIDErr != nil {
   498  		return "", m.getDIDErr
   499  	}
   500  
   501  	if m.countDown > 0 {
   502  		m.countDown--
   503  
   504  		// note: fallthrough to trying m.results[key] if temps is missing entry
   505  		if res, ok := m.temps[key]; ok {
   506  			return res.did, res.err
   507  		}
   508  	}
   509  
   510  	if res, ok := m.results[key]; ok {
   511  		return res.did, res.err
   512  	}
   513  
   514  	return "", nil
   515  }
   516  
   517  // SaveDID saves DID to the underlying storage.
   518  func (m *mockDIDStore) SaveDID(string, ...string) error {
   519  	return nil
   520  }
   521  
   522  // SaveDIDFromDoc saves DID from did.Doc to the underlying storage.
   523  func (m *mockDIDStore) SaveDIDFromDoc(*did.Doc) error {
   524  	return nil
   525  }
   526  
   527  // SaveDIDByResolving saves DID resolved by VDR to the underlying storage.
   528  func (m *mockDIDStore) SaveDIDByResolving(string, ...string) error {
   529  	return nil
   530  }