github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/protocol/outofband/states_test.go (about) 1 /* 2 Copyright SecureKey Technologies Inc. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package outofband 8 9 import ( 10 "errors" 11 "testing" 12 13 "github.com/google/uuid" 14 "github.com/stretchr/testify/require" 15 16 "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" 17 "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" 18 "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/didexchange" 19 mockdidexchange "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol/didexchange" 20 mockservice "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/service" 21 "github.com/hyperledger/aries-framework-go/pkg/store/connection" 22 ) 23 24 func TestStateFromName(t *testing.T) { 25 t.Run("valid state names", func(t *testing.T) { 26 states := []state{ 27 &stateInitial{}, 28 &statePrepareResponse{}, 29 &stateAwaitResponse{}, 30 &stateDone{}, 31 } 32 33 for _, expected := range states { 34 actual, err := stateFromName(expected.Name()) 35 require.NoError(t, err) 36 require.Equal(t, expected, actual) 37 } 38 }) 39 40 t.Run("invalid state name", func(t *testing.T) { 41 _, err := stateFromName("invalid") 42 require.Error(t, err) 43 }) 44 } 45 46 func TestStateInitial_Execute(t *testing.T) { 47 t.Run("handles inbound invitation", func(t *testing.T) { 48 s := &stateInitial{} 49 next, finish, halt, err := s.Execute(&context{Inbound: true}, nil) 50 require.NoError(t, err) 51 require.IsType(t, &statePrepareResponse{}, next) 52 require.NotNil(t, finish) 53 require.False(t, halt) 54 }) 55 } 56 57 func TestStateAwaitResponse_Execute(t *testing.T) { 58 t.Run("error if not an inbound message", func(t *testing.T) { 59 s := &stateAwaitResponse{} 60 61 _, _, _, err := s.Execute(&context{}, nil) 62 require.Error(t, err) 63 require.Contains(t, err.Error(), "cannot execute") 64 }) 65 66 t.Run("handshake-reuse", func(t *testing.T) { 67 t.Run("error if cannot fetch connection ID", func(t *testing.T) { 68 expected := errors.New("test") 69 ctx := &context{ 70 Inbound: true, 71 Action: Action{Msg: service.NewDIDCommMsgMap(&HandshakeReuse{Type: HandshakeReuseMsgType})}, 72 } 73 deps := &dependencies{ 74 connections: &mockConnRecorder{ 75 getConnIDByDIDsErr: expected, 76 }, 77 } 78 s := &stateAwaitResponse{} 79 80 _, _, _, err := s.Execute(ctx, deps) 81 require.ErrorIs(t, err, expected) 82 }) 83 84 t.Run("error if cannot fetch connection record", func(t *testing.T) { 85 expected := errors.New("test") 86 ctx := &context{ 87 Inbound: true, 88 Action: Action{Msg: service.NewDIDCommMsgMap(&HandshakeReuse{Type: HandshakeReuseMsgType})}, 89 } 90 deps := &dependencies{ 91 connections: &mockConnRecorder{ 92 getConnRecordErr: expected, 93 }, 94 } 95 s := &stateAwaitResponse{} 96 97 _, _, _, err := s.Execute(ctx, deps) 98 require.ErrorIs(t, err, expected) 99 }) 100 101 t.Run("error if connection is not in state 'completed'", func(t *testing.T) { 102 ctx := &context{ 103 Inbound: true, 104 Action: Action{Msg: service.NewDIDCommMsgMap(&HandshakeReuse{Type: HandshakeReuseMsgType})}, 105 } 106 deps := &dependencies{ 107 connections: &mockConnRecorder{ 108 getConnRecordVal: &connection.Record{State: "initial"}, 109 }, 110 } 111 s := &stateAwaitResponse{} 112 113 _, _, _, err := s.Execute(ctx, deps) 114 require.Error(t, err) 115 require.Contains(t, err.Error(), "unexpected state for connection") 116 }) 117 }) 118 } 119 120 func TestStatePrepareResponse_Execute(t *testing.T) { 121 t.Run("new connection", func(t *testing.T) { 122 t.Run("error while saving attachment handling state", func(t *testing.T) { 123 expected := errors.New("test") 124 ctx := &context{Invitation: &Invitation{ 125 Requests: []*decorator.Attachment{{ 126 ID: uuid.New().String(), 127 Data: decorator.AttachmentData{ 128 JSON: map[string]interface{}{}, 129 }, 130 }}, 131 }} 132 deps := &dependencies{ 133 connections: nil, 134 didSvc: &mockdidexchange.MockDIDExchangeSvc{}, 135 saveAttchStateFunc: func(*attachmentHandlingState) error { 136 return expected 137 }, 138 } 139 s := &statePrepareResponse{} 140 141 _, _, _, err := s.Execute(ctx, deps) 142 require.ErrorIs(t, err, expected) 143 }) 144 }) 145 146 t.Run("connection reuse", func(t *testing.T) { 147 t.Run("advances to next state and sends handshake-reuse", func(t *testing.T) { 148 savedAttachmentState := false 149 ctx := &context{ 150 Invitation: &Invitation{ 151 Services: []interface{}{theirDID}, 152 Requests: []*decorator.Attachment{{ 153 ID: uuid.New().String(), 154 Data: decorator.AttachmentData{ 155 JSON: map[string]interface{}{}, 156 }, 157 }}, 158 }, 159 ReuseConnection: theirDID, 160 } 161 deps := &dependencies{ 162 connections: &mockConnRecorder{queryConnRecordsVal: []*connection.Record{{ 163 TheirDID: theirDID, 164 State: didexchange.StateIDCompleted, 165 }}}, 166 saveAttchStateFunc: func(*attachmentHandlingState) error { 167 savedAttachmentState = true 168 169 return nil 170 }, 171 } 172 s := &statePrepareResponse{} 173 174 next, finish, halt, err := s.Execute(ctx, deps) 175 require.NoError(t, err) 176 require.IsType(t, &stateAwaitResponse{}, next) 177 require.True(t, halt) 178 179 sent := false 180 181 messenger := &mockservice.MockMessenger{ 182 ReplyToMsgFunc: func(_ service.DIDCommMsgMap, out service.DIDCommMsgMap, _ string, _ string) error { 183 require.Equal(t, HandshakeReuseMsgType, out.Type()) 184 sent = true 185 186 return nil 187 }, 188 } 189 190 err = finish(messenger) 191 require.NoError(t, err) 192 require.True(t, savedAttachmentState) 193 require.True(t, sent) 194 }) 195 196 t.Run("error if cannot query connection records", func(t *testing.T) { 197 expected := errors.New("test") 198 ctx := &context{ 199 Inbound: true, 200 ReuseAnyConnection: true, 201 } 202 deps := &dependencies{ 203 connections: &mockConnRecorder{queryConnRecordsErr: expected}, 204 } 205 s := &statePrepareResponse{} 206 207 _, _, _, err := s.Execute(ctx, deps) 208 require.ErrorIs(t, err, expected) 209 }) 210 211 t.Run("error if cannot find matching connection record", func(t *testing.T) { 212 ctx := &context{ 213 Inbound: true, 214 ReuseAnyConnection: true, 215 Invitation: &Invitation{ 216 Services: []interface{}{theirDID}, 217 }, 218 } 219 deps := &dependencies{ 220 connections: &mockConnRecorder{}, 221 } 222 s := &statePrepareResponse{} 223 224 _, _, _, err := s.Execute(ctx, deps) 225 require.Error(t, err) 226 require.Contains(t, err.Error(), "no existing connection record found for the invitation") 227 }) 228 229 t.Run("error when saving attachment handling state", func(t *testing.T) { 230 expected := errors.New("test") 231 ctx := &context{ 232 Inbound: true, 233 ReuseConnection: theirDID, 234 Invitation: &Invitation{ 235 Services: []interface{}{theirDID}, 236 Requests: []*decorator.Attachment{{ 237 ID: uuid.New().String(), 238 Data: decorator.AttachmentData{ 239 JSON: map[string]interface{}{}, 240 }, 241 }}, 242 }, 243 } 244 deps := &dependencies{ 245 connections: &mockConnRecorder{queryConnRecordsVal: []*connection.Record{{ 246 TheirDID: theirDID, 247 State: didexchange.StateIDCompleted, 248 }}}, 249 saveAttchStateFunc: func(*attachmentHandlingState) error { 250 return expected 251 }, 252 } 253 s := &statePrepareResponse{} 254 255 _, _, _, err := s.Execute(ctx, deps) 256 require.ErrorIs(t, err, expected) 257 }) 258 }) 259 } 260 261 type mockConnRecorder struct { 262 saveInvErr error 263 getConnRecordVal *connection.Record 264 getConnRecordErr error 265 getConnIDByDIDsVal string 266 getConnIDByDIDsErr error 267 queryConnRecordsVal []*connection.Record 268 queryConnRecordsErr error 269 } 270 271 func (m *mockConnRecorder) SaveInvitation(string, interface{}) error { 272 return m.saveInvErr 273 } 274 275 func (m *mockConnRecorder) GetConnectionRecord(string) (*connection.Record, error) { 276 return m.getConnRecordVal, m.getConnRecordErr 277 } 278 279 func (m *mockConnRecorder) GetConnectionIDByDIDs(string, string) (string, error) { 280 return m.getConnIDByDIDsVal, m.getConnIDByDIDsErr 281 } 282 283 func (m *mockConnRecorder) QueryConnectionRecords() ([]*connection.Record, error) { 284 return m.queryConnRecordsVal, m.queryConnRecordsErr 285 }