github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/protocol/outofband/states_test.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package outofband
     8  
     9  import (
    10  	"errors"
    11  	"testing"
    12  
    13  	"github.com/google/uuid"
    14  	"github.com/stretchr/testify/require"
    15  
    16  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    17  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator"
    18  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/didexchange"
    19  	mockdidexchange "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol/didexchange"
    20  	mockservice "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/service"
    21  	"github.com/hyperledger/aries-framework-go/pkg/store/connection"
    22  )
    23  
    24  func TestStateFromName(t *testing.T) {
    25  	t.Run("valid state names", func(t *testing.T) {
    26  		states := []state{
    27  			&stateInitial{},
    28  			&statePrepareResponse{},
    29  			&stateAwaitResponse{},
    30  			&stateDone{},
    31  		}
    32  
    33  		for _, expected := range states {
    34  			actual, err := stateFromName(expected.Name())
    35  			require.NoError(t, err)
    36  			require.Equal(t, expected, actual)
    37  		}
    38  	})
    39  
    40  	t.Run("invalid state name", func(t *testing.T) {
    41  		_, err := stateFromName("invalid")
    42  		require.Error(t, err)
    43  	})
    44  }
    45  
    46  func TestStateInitial_Execute(t *testing.T) {
    47  	t.Run("handles inbound invitation", func(t *testing.T) {
    48  		s := &stateInitial{}
    49  		next, finish, halt, err := s.Execute(&context{Inbound: true}, nil)
    50  		require.NoError(t, err)
    51  		require.IsType(t, &statePrepareResponse{}, next)
    52  		require.NotNil(t, finish)
    53  		require.False(t, halt)
    54  	})
    55  }
    56  
    57  func TestStateAwaitResponse_Execute(t *testing.T) {
    58  	t.Run("error if not an inbound message", func(t *testing.T) {
    59  		s := &stateAwaitResponse{}
    60  
    61  		_, _, _, err := s.Execute(&context{}, nil)
    62  		require.Error(t, err)
    63  		require.Contains(t, err.Error(), "cannot execute")
    64  	})
    65  
    66  	t.Run("handshake-reuse", func(t *testing.T) {
    67  		t.Run("error if cannot fetch connection ID", func(t *testing.T) {
    68  			expected := errors.New("test")
    69  			ctx := &context{
    70  				Inbound: true,
    71  				Action:  Action{Msg: service.NewDIDCommMsgMap(&HandshakeReuse{Type: HandshakeReuseMsgType})},
    72  			}
    73  			deps := &dependencies{
    74  				connections: &mockConnRecorder{
    75  					getConnIDByDIDsErr: expected,
    76  				},
    77  			}
    78  			s := &stateAwaitResponse{}
    79  
    80  			_, _, _, err := s.Execute(ctx, deps)
    81  			require.ErrorIs(t, err, expected)
    82  		})
    83  
    84  		t.Run("error if cannot fetch connection record", func(t *testing.T) {
    85  			expected := errors.New("test")
    86  			ctx := &context{
    87  				Inbound: true,
    88  				Action:  Action{Msg: service.NewDIDCommMsgMap(&HandshakeReuse{Type: HandshakeReuseMsgType})},
    89  			}
    90  			deps := &dependencies{
    91  				connections: &mockConnRecorder{
    92  					getConnRecordErr: expected,
    93  				},
    94  			}
    95  			s := &stateAwaitResponse{}
    96  
    97  			_, _, _, err := s.Execute(ctx, deps)
    98  			require.ErrorIs(t, err, expected)
    99  		})
   100  
   101  		t.Run("error if connection is not in state 'completed'", func(t *testing.T) {
   102  			ctx := &context{
   103  				Inbound: true,
   104  				Action:  Action{Msg: service.NewDIDCommMsgMap(&HandshakeReuse{Type: HandshakeReuseMsgType})},
   105  			}
   106  			deps := &dependencies{
   107  				connections: &mockConnRecorder{
   108  					getConnRecordVal: &connection.Record{State: "initial"},
   109  				},
   110  			}
   111  			s := &stateAwaitResponse{}
   112  
   113  			_, _, _, err := s.Execute(ctx, deps)
   114  			require.Error(t, err)
   115  			require.Contains(t, err.Error(), "unexpected state for connection")
   116  		})
   117  	})
   118  }
   119  
   120  func TestStatePrepareResponse_Execute(t *testing.T) {
   121  	t.Run("new connection", func(t *testing.T) {
   122  		t.Run("error while saving attachment handling state", func(t *testing.T) {
   123  			expected := errors.New("test")
   124  			ctx := &context{Invitation: &Invitation{
   125  				Requests: []*decorator.Attachment{{
   126  					ID: uuid.New().String(),
   127  					Data: decorator.AttachmentData{
   128  						JSON: map[string]interface{}{},
   129  					},
   130  				}},
   131  			}}
   132  			deps := &dependencies{
   133  				connections: nil,
   134  				didSvc:      &mockdidexchange.MockDIDExchangeSvc{},
   135  				saveAttchStateFunc: func(*attachmentHandlingState) error {
   136  					return expected
   137  				},
   138  			}
   139  			s := &statePrepareResponse{}
   140  
   141  			_, _, _, err := s.Execute(ctx, deps)
   142  			require.ErrorIs(t, err, expected)
   143  		})
   144  	})
   145  
   146  	t.Run("connection reuse", func(t *testing.T) {
   147  		t.Run("advances to next state and sends handshake-reuse", func(t *testing.T) {
   148  			savedAttachmentState := false
   149  			ctx := &context{
   150  				Invitation: &Invitation{
   151  					Services: []interface{}{theirDID},
   152  					Requests: []*decorator.Attachment{{
   153  						ID: uuid.New().String(),
   154  						Data: decorator.AttachmentData{
   155  							JSON: map[string]interface{}{},
   156  						},
   157  					}},
   158  				},
   159  				ReuseConnection: theirDID,
   160  			}
   161  			deps := &dependencies{
   162  				connections: &mockConnRecorder{queryConnRecordsVal: []*connection.Record{{
   163  					TheirDID: theirDID,
   164  					State:    didexchange.StateIDCompleted,
   165  				}}},
   166  				saveAttchStateFunc: func(*attachmentHandlingState) error {
   167  					savedAttachmentState = true
   168  
   169  					return nil
   170  				},
   171  			}
   172  			s := &statePrepareResponse{}
   173  
   174  			next, finish, halt, err := s.Execute(ctx, deps)
   175  			require.NoError(t, err)
   176  			require.IsType(t, &stateAwaitResponse{}, next)
   177  			require.True(t, halt)
   178  
   179  			sent := false
   180  
   181  			messenger := &mockservice.MockMessenger{
   182  				ReplyToMsgFunc: func(_ service.DIDCommMsgMap, out service.DIDCommMsgMap, _ string, _ string) error {
   183  					require.Equal(t, HandshakeReuseMsgType, out.Type())
   184  					sent = true
   185  
   186  					return nil
   187  				},
   188  			}
   189  
   190  			err = finish(messenger)
   191  			require.NoError(t, err)
   192  			require.True(t, savedAttachmentState)
   193  			require.True(t, sent)
   194  		})
   195  
   196  		t.Run("error if cannot query connection records", func(t *testing.T) {
   197  			expected := errors.New("test")
   198  			ctx := &context{
   199  				Inbound:            true,
   200  				ReuseAnyConnection: true,
   201  			}
   202  			deps := &dependencies{
   203  				connections: &mockConnRecorder{queryConnRecordsErr: expected},
   204  			}
   205  			s := &statePrepareResponse{}
   206  
   207  			_, _, _, err := s.Execute(ctx, deps)
   208  			require.ErrorIs(t, err, expected)
   209  		})
   210  
   211  		t.Run("error if cannot find matching connection record", func(t *testing.T) {
   212  			ctx := &context{
   213  				Inbound:            true,
   214  				ReuseAnyConnection: true,
   215  				Invitation: &Invitation{
   216  					Services: []interface{}{theirDID},
   217  				},
   218  			}
   219  			deps := &dependencies{
   220  				connections: &mockConnRecorder{},
   221  			}
   222  			s := &statePrepareResponse{}
   223  
   224  			_, _, _, err := s.Execute(ctx, deps)
   225  			require.Error(t, err)
   226  			require.Contains(t, err.Error(), "no existing connection record found for the invitation")
   227  		})
   228  
   229  		t.Run("error when saving attachment handling state", func(t *testing.T) {
   230  			expected := errors.New("test")
   231  			ctx := &context{
   232  				Inbound:         true,
   233  				ReuseConnection: theirDID,
   234  				Invitation: &Invitation{
   235  					Services: []interface{}{theirDID},
   236  					Requests: []*decorator.Attachment{{
   237  						ID: uuid.New().String(),
   238  						Data: decorator.AttachmentData{
   239  							JSON: map[string]interface{}{},
   240  						},
   241  					}},
   242  				},
   243  			}
   244  			deps := &dependencies{
   245  				connections: &mockConnRecorder{queryConnRecordsVal: []*connection.Record{{
   246  					TheirDID: theirDID,
   247  					State:    didexchange.StateIDCompleted,
   248  				}}},
   249  				saveAttchStateFunc: func(*attachmentHandlingState) error {
   250  					return expected
   251  				},
   252  			}
   253  			s := &statePrepareResponse{}
   254  
   255  			_, _, _, err := s.Execute(ctx, deps)
   256  			require.ErrorIs(t, err, expected)
   257  		})
   258  	})
   259  }
   260  
   261  type mockConnRecorder struct {
   262  	saveInvErr          error
   263  	getConnRecordVal    *connection.Record
   264  	getConnRecordErr    error
   265  	getConnIDByDIDsVal  string
   266  	getConnIDByDIDsErr  error
   267  	queryConnRecordsVal []*connection.Record
   268  	queryConnRecordsErr error
   269  }
   270  
   271  func (m *mockConnRecorder) SaveInvitation(string, interface{}) error {
   272  	return m.saveInvErr
   273  }
   274  
   275  func (m *mockConnRecorder) GetConnectionRecord(string) (*connection.Record, error) {
   276  	return m.getConnRecordVal, m.getConnRecordErr
   277  }
   278  
   279  func (m *mockConnRecorder) GetConnectionIDByDIDs(string, string) (string, error) {
   280  	return m.getConnIDByDIDsVal, m.getConnIDByDIDsErr
   281  }
   282  
   283  func (m *mockConnRecorder) QueryConnectionRecords() ([]*connection.Record, error) {
   284  	return m.queryConnRecordsVal, m.queryConnRecordsErr
   285  }