github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/protocol/outofband/service_test.go (about) 1 /* 2 Copyright SecureKey Technologies Inc. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package outofband 8 9 import ( 10 "encoding/json" 11 "errors" 12 "fmt" 13 "strings" 14 "testing" 15 "time" 16 17 "github.com/google/uuid" 18 "github.com/stretchr/testify/require" 19 20 commonmodel "github.com/hyperledger/aries-framework-go/pkg/common/model" 21 "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/model" 22 "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" 23 "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" 24 "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/didexchange" 25 "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" 26 "github.com/hyperledger/aries-framework-go/pkg/doc/did" 27 "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api" 28 "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol" 29 mockdidexchange "github.com/hyperledger/aries-framework-go/pkg/mock/didcomm/protocol/didexchange" 30 mockstore "github.com/hyperledger/aries-framework-go/pkg/mock/storage" 31 "github.com/hyperledger/aries-framework-go/pkg/store/connection" 32 "github.com/hyperledger/aries-framework-go/spi/storage" 33 ) 34 35 const ( 36 myDID = "did:example:mine" 37 theirDID = "did:example:theirs" 38 ) 39 40 func TestNew(t *testing.T) { 41 t.Run("returns the service", func(t *testing.T) { 42 s, err := New(testProvider()) 43 require.NoError(t, err) 44 require.NotNil(t, s) 45 }) 46 t.Run("fails if no didexchange service is registered", func(t *testing.T) { 47 provider := testProvider() 48 provider.ServiceErr = api.ErrSvcNotFound 49 _, err := New(provider) 50 require.Error(t, err) 51 }) 52 t.Run("fails if the didexchange service cannot be cast to an inboundhandler", func(t *testing.T) { 53 provider := testProvider() 54 provider.ServiceMap[didexchange.DIDExchange] = &struct{}{} 55 _, err := New(provider) 56 require.Error(t, err) 57 }) 58 t.Run("wraps error thrown from protocol state store when it cannot be opened", func(t *testing.T) { 59 expected := errors.New("test") 60 provider := testProvider() 61 provider.ProtocolStateStoreProvider = &mockstore.MockStoreProvider{ 62 ErrOpenStoreHandle: expected, 63 } 64 _, err := New(provider) 65 require.Error(t, err) 66 require.True(t, errors.Is(err, expected)) 67 }) 68 t.Run("wraps error thrown from persistent store when it cannot be opened", func(t *testing.T) { 69 expected := errors.New("test") 70 provider := testProvider() 71 provider.StoreProvider = &mockstore.MockStoreProvider{ 72 ErrOpenStoreHandle: expected, 73 } 74 _, err := New(provider) 75 require.Error(t, err) 76 require.True(t, errors.Is(err, expected)) 77 }) 78 t.Run("fails if the didexchange service cannot be cast to service.Event", func(t *testing.T) { 79 provider := testProvider() 80 provider.ServiceMap[didexchange.DIDExchange] = &struct{ service.InboundHandler }{} 81 _, err := New(provider) 82 require.Error(t, err) 83 }) 84 t.Run("wraps error thrown when attempting to register to listen for didexchange events", func(t *testing.T) { 85 expected := errors.New("test") 86 provider := testProvider() 87 provider.ServiceMap = map[string]interface{}{ 88 didexchange.DIDExchange: &mockdidexchange.MockDIDExchangeSvc{ 89 RegisterMsgEventErr: expected, 90 }, 91 } 92 _, err := New(provider) 93 require.Error(t, err) 94 require.True(t, errors.Is(err, expected)) 95 }) 96 } 97 98 func TestService_Initialize(t *testing.T) { 99 t.Run("success", func(t *testing.T) { 100 prov := testProvider() 101 svc := Service{} 102 103 err := svc.Initialize(prov) 104 require.NoError(t, err) 105 106 // second init is no-op 107 err = svc.Initialize(prov) 108 require.NoError(t, err) 109 }) 110 111 t.Run("failure, not given a valid provider", func(t *testing.T) { 112 svc := Service{} 113 114 err := svc.Initialize("not a provider") 115 require.Error(t, err) 116 require.Contains(t, err.Error(), "expected provider of type") 117 }) 118 } 119 120 func TestName(t *testing.T) { 121 s, err := New(testProvider()) 122 require.NoError(t, err) 123 require.Equal(t, s.Name(), "out-of-band") 124 } 125 126 func TestAccept(t *testing.T) { 127 t.Run("accepts out-of-band invitation messages", func(t *testing.T) { 128 s, err := New(testProvider()) 129 require.NoError(t, err) 130 require.True(t, s.Accept("https://didcomm.org/out-of-band/1.0/invitation")) 131 }) 132 t.Run("rejects unsupported messages", func(t *testing.T) { 133 s, err := New(testProvider()) 134 require.NoError(t, err) 135 require.False(t, s.Accept("unsupported")) 136 }) 137 } 138 139 func TestHandleInbound(t *testing.T) { 140 t.Run("accepts out-of-band invitation messages", func(t *testing.T) { 141 s := newAutoService(t, testProvider()) 142 _, err := s.HandleInbound(service.NewDIDCommMsgMap(newInvitation()), service.NewDIDCommContext(myDID, theirDID, nil)) 143 require.NoError(t, err) 144 }) 145 146 t.Run("accepts out-of-band invitation messages with service as map[string]interface{} and serviceEndpoint as "+ 147 "string (DIDCommV1)", func(t *testing.T) { 148 s := newAutoService(t, testProvider()) 149 customServiceMap := map[string]interface{}{ 150 "recipientKeys": []string{"did:key:123"}, 151 "serviceEndpoint": "http://user.agent.aries.js.example.com:10081", 152 "type": "did-communication", 153 } 154 _, err := s.HandleInbound(service.NewDIDCommMsgMap(newInvitationWithService(customServiceMap)), 155 service.NewDIDCommContext(myDID, theirDID, nil)) 156 require.NoError(t, err) 157 }) 158 159 t.Run("accepts out-of-band invitation messages with service as map[string]interface{} and serviceEndpoint as "+ 160 "list (DIDCommV2)", func(t *testing.T) { 161 s := newAutoService(t, testProvider()) 162 customServiceMap := map[string]interface{}{ 163 "recipientKeys": []string{"did:key:123"}, 164 "serviceEndpoint": []interface{}{ 165 map[string]interface{}{ 166 "accept": []interface{}{ 167 "didcomm/v2", "didcomm/aip2;env=rfc19", "didcomm/aip2;env=rfc587", 168 }, 169 "uri": "https://alice.aries.example.com:8081", 170 }, 171 }, 172 "type": "DIDCommMessaging", 173 } 174 _, err := s.HandleInbound(service.NewDIDCommMsgMap(newInvitationWithService(customServiceMap)), 175 service.NewDIDCommContext(myDID, theirDID, nil)) 176 require.NoError(t, err) 177 }) 178 179 t.Run("rejects unsupported message types", func(t *testing.T) { 180 s, err := New(testProvider()) 181 require.NoError(t, err) 182 req := newInvitation() 183 req.Type = "invalid" 184 _, err = s.HandleInbound(service.NewDIDCommMsgMap(req), service.NewDIDCommContext(myDID, theirDID, nil)) 185 require.Error(t, err) 186 }) 187 t.Run("fires off an action event", func(t *testing.T) { 188 expected := service.NewDIDCommMsgMap(newInvitation()) 189 s, err := New(testProvider()) 190 require.NoError(t, err) 191 events := make(chan service.DIDCommAction) 192 err = s.RegisterActionEvent(events) 193 require.NoError(t, err) 194 _, err = s.HandleInbound(expected, service.NewDIDCommContext(myDID, theirDID, nil)) 195 require.NoError(t, err) 196 select { 197 case e := <-events: 198 require.Equal(t, Name, e.ProtocolName) 199 require.Equal(t, expected, e.Message) 200 require.Nil(t, e.Properties) 201 case <-time.After(1 * time.Second): 202 t.Error("timeout waiting for action event") 203 } 204 }) 205 t.Run("ThreadID not found", func(t *testing.T) { 206 expected := service.NewDIDCommMsgMap(&HandshakeReuseAccepted{ 207 Type: HandshakeReuseAcceptedMsgType, 208 }) 209 s, err := New(testProvider()) 210 require.NoError(t, err) 211 events := make(chan service.DIDCommAction) 212 err = s.RegisterActionEvent(events) 213 require.NoError(t, err) 214 _, err = s.HandleInbound(expected, service.NewDIDCommContext(myDID, theirDID, nil)) 215 require.Error(t, err) 216 require.Contains(t, err.Error(), "threadID not found") 217 }) 218 t.Run("Load context (error)", func(t *testing.T) { 219 expected := service.NewDIDCommMsgMap(newInvitation()) 220 s := &Service{ 221 transientStore: &mockstore.MockStore{ 222 Store: make(map[string]mockstore.DBEntry), 223 ErrPut: fmt.Errorf("db error"), 224 }, 225 } 226 events := make(chan service.DIDCommAction) 227 err := s.RegisterActionEvent(events) 228 require.NoError(t, err) 229 _, err = s.HandleInbound(expected, service.NewDIDCommContext(myDID, theirDID, nil)) 230 require.Error(t, err) 231 require.Contains(t, err.Error(), "unable to load current context") 232 }) 233 t.Run("sends pre-state msg event", func(t *testing.T) { 234 expected := service.NewDIDCommMsgMap(&HandshakeReuseAccepted{ 235 ID: uuid.New().String(), 236 Type: HandshakeReuseAcceptedMsgType, 237 }) 238 provider := testProvider() 239 provider.ProtocolStateStoreProvider = &mockstore.MockStoreProvider{ 240 Store: &mockstore.MockStore{ 241 Store: map[string]mockstore.DBEntry{ 242 fmt.Sprintf(contextKey, expected.ID()): { 243 Value: marshal(t, &context{ 244 CurrentStateName: StateNameAwaitResponse, 245 Inbound: true, 246 Invitation: newInvitation(), 247 Action: Action{ 248 Msg: expected, 249 }, 250 }), 251 }, 252 }, 253 }, 254 } 255 s, err := New(provider) 256 require.NoError(t, err) 257 stateMsgs := make(chan service.StateMsg) 258 err = s.RegisterMsgEvent(stateMsgs) 259 require.NoError(t, err) 260 err = s.RegisterActionEvent(make(chan service.DIDCommAction)) 261 require.NoError(t, err) 262 _, err = s.HandleInbound(expected, service.NewDIDCommContext(myDID, theirDID, nil)) 263 require.NoError(t, err) 264 265 done := false 266 267 for !done { 268 select { 269 case result := <-stateMsgs: 270 if result.Type != service.PreState || result.StateID != StateNameAwaitResponse { 271 continue 272 } 273 274 done = true 275 276 require.Equal(t, Name, result.ProtocolName) 277 require.Equal(t, expected, result.Msg) 278 props, ok := result.Properties.(*eventProps) 279 require.True(t, ok) 280 require.Empty(t, props.ConnectionID()) 281 require.Nil(t, props.Error()) 282 case <-time.After(time.Second): 283 t.Error("timeout waiting for action event") 284 } 285 } 286 }) 287 t.Run("fails if no listeners have been registered for action events", func(t *testing.T) { 288 s, err := New(testProvider()) 289 require.NoError(t, err) 290 _, err = s.HandleInbound(service.NewDIDCommMsgMap(newInvitation()), service.NewDIDCommContext(myDID, theirDID, nil)) 291 require.Error(t, err) 292 }) 293 } 294 295 func TestService_ActionContinue(t *testing.T) { 296 t.Run("Success", func(t *testing.T) { 297 msg := service.NewDIDCommMsgMap(&HandshakeReuse{ 298 ID: uuid.New().String(), 299 Type: HandshakeReuseMsgType, 300 }) 301 provider := testProvider() 302 connID := uuid.New().String() 303 304 // Note: copied from store/connection/connection_lookup.go 305 mockDIDTagFunc := func(dids ...string) string { 306 for i, v := range dids { 307 dids[i] = strings.ReplaceAll(v, ":", "$") 308 } 309 310 return strings.Join(dids, "|") 311 } 312 313 provider.StoreProvider = &mockstore.MockStoreProvider{ 314 Store: &mockstore.MockStore{ 315 Store: map[string]mockstore.DBEntry{ 316 fmt.Sprintf("didconn_%s_%s", myDID, theirDID): { 317 Value: []byte(connID), 318 }, 319 fmt.Sprintf("conn_%s", connID): { 320 Value: marshal(t, &connection.Record{ 321 ConnectionID: connID, 322 State: "completed", 323 }), 324 Tags: []storage.Tag{ 325 { 326 Name: "bothDIDs", 327 Value: mockDIDTagFunc(myDID, theirDID), 328 }, 329 }, 330 }, 331 }, 332 }, 333 } 334 s, err := New(provider) 335 require.NoError(t, err) 336 337 states := make(chan service.StateMsg, 50) 338 actions := make(chan service.DIDCommAction) 339 340 require.NoError(t, err, s.RegisterMsgEvent(states)) 341 require.NoError(t, s.RegisterActionEvent(actions)) 342 _, err = s.HandleInbound(msg, service.NewDIDCommContext(myDID, theirDID, nil)) 343 require.NoError(t, err) 344 345 var remainingActions []Action 346 347 select { 348 case <-actions: 349 remainingActions, err = s.Actions() 350 require.NoError(t, err) 351 require.Equal(t, 1, len(remainingActions)) 352 require.NoError(t, s.ActionContinue(remainingActions[0].PIID, &userOptions{})) 353 case <-time.After(time.Second): 354 t.Error("timeout") 355 } 356 357 done := false 358 359 for !done { 360 select { 361 case s := <-states: 362 if s.Type == service.PostState && s.StateID == StateNameDone { 363 done = true 364 } 365 case <-time.After(5 * time.Second): 366 require.Fail(t, "timeout waiting for state done") 367 } 368 } 369 370 remainingActions, err = s.Actions() 371 require.NoError(t, err) 372 require.Equal(t, 0, len(remainingActions)) 373 }) 374 t.Run("Error", func(t *testing.T) { 375 require.EqualError(t, (&Service{ 376 transientStore: &mockstore.MockStore{ 377 Store: make(map[string]mockstore.DBEntry), 378 ErrGet: fmt.Errorf("db error"), 379 }, 380 }).ActionContinue("piid", nil), "load context: transientStore get: db error") 381 }) 382 } 383 384 func TestService_ActionStop(t *testing.T) { 385 t.Run("Success", func(t *testing.T) { 386 msg := service.NewDIDCommMsgMap(newInvitation()) 387 s, err := New(testProvider()) 388 require.NoError(t, err) 389 390 actions := make(chan service.DIDCommAction) 391 392 require.NoError(t, s.RegisterActionEvent(actions)) 393 s.callbackChannel = make(chan *callback, 2) 394 _, err = s.HandleInbound(msg, service.NewDIDCommContext(myDID, theirDID, nil)) 395 require.NoError(t, err) 396 397 var remainingActions []Action 398 399 select { 400 case <-actions: 401 remainingActions, err = s.Actions() 402 require.NoError(t, err) 403 require.Equal(t, 1, len(remainingActions)) 404 require.NoError(t, s.ActionStop(remainingActions[0].PIID, nil)) 405 case <-time.After(1 * time.Second): 406 t.Error("timeout") 407 } 408 409 remainingActions, err = s.Actions() 410 require.NoError(t, err) 411 require.Equal(t, 0, len(remainingActions)) 412 }) 413 414 t.Run("Error", func(t *testing.T) { 415 require.EqualError(t, (&Service{ 416 transientStore: &mockstore.MockStore{ 417 Store: make(map[string]mockstore.DBEntry), 418 ErrGet: fmt.Errorf("db error"), 419 }, 420 }).ActionStop("piid", nil), "get context: transientStore get: db error") 421 }) 422 } 423 424 func TestServiceStop(t *testing.T) { 425 t.Run("Success", func(t *testing.T) { 426 msg := service.NewDIDCommMsgMap(newInvitation()) 427 s, err := New(testProvider()) 428 require.NoError(t, err) 429 430 actions := make(chan service.DIDCommAction) 431 432 require.NoError(t, s.RegisterActionEvent(actions)) 433 s.callbackChannel = make(chan *callback, 2) 434 _, err = s.HandleInbound(msg, service.NewDIDCommContext(myDID, theirDID, nil)) 435 require.NoError(t, err) 436 437 select { 438 case action := <-actions: 439 action.Stop(nil) 440 case <-time.After(1 * time.Second): 441 t.Error("timeout") 442 } 443 444 remainingActions, err := s.Actions() 445 require.NoError(t, err) 446 require.Equal(t, 0, len(remainingActions)) 447 }) 448 } 449 450 func TestServiceContinue(t *testing.T) { 451 t.Run("enqueues callback", func(t *testing.T) { 452 msg := service.NewDIDCommMsgMap(newInvitation()) 453 s, err := New(testProvider()) 454 require.NoError(t, err) 455 456 actions := make(chan service.DIDCommAction) 457 458 require.NoError(t, s.RegisterActionEvent(actions)) 459 s.callbackChannel = make(chan *callback, 2) 460 _, err = s.HandleInbound(msg, service.NewDIDCommContext(myDID, theirDID, nil)) 461 require.NoError(t, err) 462 463 select { 464 case action := <-actions: 465 action.Continue(nil) 466 case <-time.After(1 * time.Second): 467 t.Error("timeout") 468 } 469 470 select { 471 case c := <-s.callbackChannel: 472 require.Equal(t, msg, c.msg) 473 require.Equal(t, myDID, c.myDID) 474 require.Equal(t, theirDID, c.theirDID) 475 case <-time.After(1 * time.Second): 476 t.Error("timeout") 477 } 478 }) 479 } 480 481 func TestHandleRequestCallback(t *testing.T) { 482 t.Run("invokes the didexchange service", func(t *testing.T) { 483 invoked := make(chan struct{}, 2) 484 c := newCallback() 485 provider := testProvider() 486 provider.ServiceMap = map[string]interface{}{ 487 didexchange.DIDExchange: &mockdidexchange.MockDIDExchangeSvc{ 488 RespondToFunc: func(*didexchange.OOBInvitation, []string) (string, error) { 489 invoked <- struct{}{} 490 return "", nil 491 }, 492 }, 493 } 494 495 s := newAutoService(t, provider) 496 497 _, err := s.handleInvitationCallback(c) 498 require.NoError(t, err) 499 select { 500 case <-invoked: 501 case <-time.After(time.Second): 502 t.Error("timeout") 503 } 504 }) 505 t.Run("passes a didexchange.OOBInvitation to the didexchange service", func(t *testing.T) { 506 provider := testProvider() 507 provider.ServiceMap = map[string]interface{}{ 508 didexchange.DIDExchange: &mockdidexchange.MockDIDExchangeSvc{ 509 RespondToFunc: func(i *didexchange.OOBInvitation, _ []string) (string, error) { 510 require.NotNil(t, i) 511 return "", nil 512 }, 513 }, 514 } 515 s := newAutoService(t, provider) 516 _, err := s.handleInvitationCallback(newCallback()) 517 require.NoError(t, err) 518 }) 519 t.Run("wraps error thrown when decoding the message", func(t *testing.T) { 520 expected := errors.New("test") 521 s := newAutoService(t, testProvider()) 522 _, err := s.handleInvitationCallback(&callback{ 523 msg: &testDIDCommMsg{errDecode: expected}, 524 ctx: &context{}, 525 }) 526 require.Error(t, err) 527 require.True(t, errors.Is(err, expected)) 528 }) 529 t.Run("wraps error returned by the didexchange service", func(t *testing.T) { 530 expected := errors.New("test") 531 provider := testProvider() 532 provider.ServiceMap = map[string]interface{}{ 533 didexchange.DIDExchange: &mockdidexchange.MockDIDExchangeSvc{ 534 RespondToFunc: func(_ *didexchange.OOBInvitation, _ []string) (string, error) { 535 return "", expected 536 }, 537 }, 538 } 539 s := newAutoService(t, provider) 540 _, err := s.handleInvitationCallback(newCallback()) 541 require.Error(t, err) 542 require.True(t, errors.Is(err, expected)) 543 }) 544 t.Run("wraps error returned by the protocol state store", func(t *testing.T) { 545 expected := errors.New("test") 546 provider := testProvider() 547 provider.ProtocolStateStoreProvider = &mockstore.MockStoreProvider{ 548 Store: &mockstore.MockStore{ 549 ErrPut: expected, 550 }, 551 } 552 s := newAutoService(t, provider) 553 _, err := s.handleInvitationCallback(newCallback()) 554 require.Error(t, err) 555 require.True(t, errors.Is(err, expected)) 556 }) 557 } 558 559 func TestHandleDIDEvent(t *testing.T) { 560 t.Run("invokes inbound msg handler", func(t *testing.T) { 561 invoked := make(chan struct{}, 2) 562 connID := uuid.New().String() 563 pthid := uuid.New().String() 564 565 provider := testProvider() 566 provider.InboundDIDCommMsgHandlerFunc = func() service.InboundHandler { 567 return &inboundMsgHandler{handleFunc: func(service.DIDCommMsg, service.DIDCommContext) (string, error) { 568 invoked <- struct{}{} 569 return "", nil 570 }} 571 } 572 573 // setup connection state 574 r, err := connection.NewRecorder(provider) 575 require.NoError(t, err) 576 err = r.SaveConnectionRecord(&connection.Record{ 577 ConnectionID: connID, 578 MyDID: myDID, 579 TheirDID: theirDID, 580 ParentThreadID: pthid, 581 }) 582 require.NoError(t, err) 583 584 s := newAutoService(t, provider, 585 withState(t, &attachmentHandlingState{ 586 ID: pthid, 587 ConnectionID: connID, 588 Invitation: newInvitation(), 589 Done: false, 590 }, 591 )) 592 593 err = s.handleDIDEvent(service.StateMsg{ 594 ProtocolName: didexchange.DIDExchange, 595 Type: service.PostState, 596 Msg: service.NewDIDCommMsgMap(newAck(pthid)), 597 StateID: didexchange.StateIDCompleted, 598 Properties: &mockdidexchange.MockEventProperties{ConnID: connID}, 599 }) 600 require.NoError(t, err) 601 602 select { 603 case <-invoked: 604 case <-time.After(1 * time.Second): 605 t.Error("timeout") 606 } 607 }) 608 t.Run("wraps error returned by the protocol state store", func(t *testing.T) { 609 expected := errors.New("test") 610 const connID = "123" 611 provider := testProvider() 612 provider.ProtocolStateStoreProvider = &mockstore.MockStoreProvider{ 613 Store: &mockstore.MockStore{ 614 Store: make(map[string]mockstore.DBEntry), 615 ErrGet: expected, 616 }, 617 } 618 r, err := connection.NewRecorder(provider) 619 require.NoError(t, err) 620 err = r.SaveConnectionRecord(&connection.Record{ 621 ConnectionID: connID, 622 }) 623 require.NoError(t, err) 624 s := newAutoService(t, provider) 625 err = s.handleDIDEvent(service.StateMsg{ 626 ProtocolName: didexchange.DIDExchange, 627 Type: service.PostState, 628 Msg: service.NewDIDCommMsgMap(newAck()), 629 StateID: didexchange.StateIDCompleted, 630 Properties: &mockdidexchange.MockEventProperties{}, 631 }) 632 require.Error(t, err) 633 require.True(t, errors.Is(err, expected)) 634 }) 635 t.Run("wraps error returned by the persistent store", func(t *testing.T) { 636 expected := errors.New("test") 637 pthid := uuid.New().String() 638 639 provider := testProvider() 640 provider.StoreProvider = &mockstore.MockStoreProvider{ 641 Store: &mockstore.MockStore{ 642 ErrGet: expected, 643 }, 644 } 645 s := newAutoService(t, provider, 646 withState(t, &attachmentHandlingState{ 647 ID: pthid, 648 ConnectionID: uuid.New().String(), 649 Invitation: newInvitation(), 650 Done: false, 651 }, 652 )) 653 err := s.handleDIDEvent(service.StateMsg{ 654 ProtocolName: didexchange.DIDExchange, 655 Type: service.PostState, 656 Msg: service.NewDIDCommMsgMap(newAck(pthid)), 657 StateID: didexchange.StateIDCompleted, 658 Properties: &mockdidexchange.MockEventProperties{}, 659 }) 660 require.Error(t, err) 661 require.True(t, errors.Is(err, expected)) 662 }) 663 t.Run("wraps error thrown by the dispatcher", func(t *testing.T) { 664 expected := errors.New("test") 665 pthid := uuid.New().String() 666 connID := uuid.New().String() 667 668 provider := testProvider() 669 provider.InboundDIDCommMsgHandlerFunc = func() service.InboundHandler { 670 return &inboundMsgHandler{ 671 handleFunc: func(service.DIDCommMsg, service.DIDCommContext) (string, error) { 672 return "", expected 673 }, 674 } 675 } 676 677 // setup connection state 678 r, err := connection.NewRecorder(provider) 679 require.NoError(t, err) 680 err = r.SaveConnectionRecord(&connection.Record{ 681 ConnectionID: connID, 682 MyDID: myDID, 683 TheirDID: theirDID, 684 ParentThreadID: pthid, 685 }) 686 require.NoError(t, err) 687 688 s := newAutoService(t, provider, 689 withState(t, &attachmentHandlingState{ 690 ID: pthid, 691 ConnectionID: connID, 692 Invitation: newInvitation(), 693 Done: false, 694 }, 695 )) 696 err = s.handleDIDEvent(service.StateMsg{ 697 ProtocolName: didexchange.DIDExchange, 698 Type: service.PostState, 699 Msg: service.NewDIDCommMsgMap(newAck(pthid)), 700 StateID: didexchange.StateIDCompleted, 701 Properties: &mockdidexchange.MockEventProperties{ConnID: connID}, 702 }) 703 require.Error(t, err) 704 require.True(t, errors.Is(err, expected)) 705 }) 706 t.Run("wraps error from store when saving state", func(t *testing.T) { 707 expected := errors.New("test") 708 pthid := uuid.New().String() 709 connID := uuid.New().String() 710 711 protocolStateStoreProvider := mockstore.NewMockStoreProvider() 712 provider := &protocol.MockProvider{ 713 StoreProvider: mockstore.NewMockStoreProvider(), 714 ProtocolStateStoreProvider: protocolStateStoreProvider, 715 ServiceMap: map[string]interface{}{ 716 didexchange.DIDExchange: &mockdidexchange.MockDIDExchangeSvc{}, 717 }, 718 } 719 provider.InboundMsgHandler = func(envelope *transport.Envelope) error { 720 return nil 721 } 722 723 // setup connection state 724 r, err := connection.NewRecorder(provider) 725 require.NoError(t, err) 726 err = r.SaveConnectionRecord(&connection.Record{ 727 ConnectionID: connID, 728 MyDID: myDID, 729 TheirDID: theirDID, 730 ParentThreadID: pthid, 731 }) 732 require.NoError(t, err) 733 734 s := newAutoService(t, provider, 735 withState(t, &attachmentHandlingState{ 736 ID: pthid, 737 ConnectionID: connID, 738 Invitation: newInvitation(), 739 Done: false, 740 }, 741 )) 742 743 s.transientStore = &mockstore.MockStore{ 744 Store: protocolStateStoreProvider.Store.Store, 745 ErrPut: expected, 746 } 747 748 err = s.handleDIDEvent(service.StateMsg{ 749 ProtocolName: didexchange.DIDExchange, 750 Type: service.PostState, 751 Msg: service.NewDIDCommMsgMap(newAck(pthid)), 752 StateID: didexchange.StateIDCompleted, 753 Properties: &mockdidexchange.MockEventProperties{ConnID: connID}, 754 }) 755 require.Error(t, err) 756 require.True(t, errors.Is(err, expected)) 757 }) 758 t.Run("ignores non-poststate did events", func(t *testing.T) { 759 s := newAutoService(t, testProvider()) 760 err := s.handleDIDEvent(service.StateMsg{ 761 ProtocolName: didexchange.DIDExchange, 762 Type: service.PreState, 763 }) 764 require.Error(t, err) 765 require.True(t, errors.Is(err, errIgnoredDidEvent)) 766 }) 767 t.Run("ignores msgs that are not didexchange acks", func(t *testing.T) { 768 s := newAutoService(t, testProvider()) 769 err := s.handleDIDEvent(service.StateMsg{ 770 ProtocolName: didexchange.DIDExchange, 771 Type: service.PostState, 772 Msg: service.NewDIDCommMsgMap(&didexchange.Request{}), 773 }) 774 require.Error(t, err) 775 require.True(t, errors.Is(err, errIgnoredDidEvent)) 776 }) 777 t.Run("ignores acks with no parent thread id", func(t *testing.T) { 778 s := newAutoService(t, testProvider()) 779 err := s.handleDIDEvent(service.StateMsg{ 780 ProtocolName: didexchange.DIDExchange, 781 Type: service.PostState, 782 Msg: service.NewDIDCommMsgMap(&model.Ack{ 783 Type: didexchange.AckMsgType, 784 ID: uuid.New().String(), 785 Status: "great", 786 Thread: &decorator.Thread{ 787 ID: uuid.New().String(), 788 }, 789 }), 790 }) 791 require.Error(t, err) 792 require.True(t, errors.Is(err, errIgnoredDidEvent)) 793 }) 794 t.Run("ignores did event if no more requests are to be dispatched", func(t *testing.T) { 795 pthid := uuid.New().String() 796 connID := uuid.New().String() 797 798 s := newAutoService(t, testProvider(), 799 withState(t, &attachmentHandlingState{ 800 ID: pthid, 801 ConnectionID: connID, 802 Invitation: newInvitation(), 803 Done: false, 804 }), 805 ) 806 s.chooseAttachmentFunc = func(*attachmentHandlingState) (*decorator.Attachment, error) { 807 return nil, nil 808 } 809 err := s.handleDIDEvent(service.StateMsg{ 810 ProtocolName: didexchange.DIDExchange, 811 Type: service.PostState, 812 Msg: service.NewDIDCommMsgMap(newAck(pthid)), 813 }) 814 require.Error(t, err) 815 require.True(t, errors.Is(err, errIgnoredDidEvent)) 816 }) 817 t.Run("wraps error thrown while extracting didcomm msg bytes from request", func(t *testing.T) { 818 expected := errors.New("test") 819 pthid := uuid.New().String() 820 connID := uuid.New().String() 821 822 provider := testProvider() 823 824 // setup connection state 825 r, err := connection.NewRecorder(provider) 826 require.NoError(t, err) 827 err = r.SaveConnectionRecord(&connection.Record{ 828 ConnectionID: connID, 829 MyDID: myDID, 830 TheirDID: theirDID, 831 ParentThreadID: pthid, 832 }) 833 834 require.NoError(t, err) 835 s := newAutoService(t, provider, 836 withState(t, &attachmentHandlingState{ 837 ID: pthid, 838 ConnectionID: connID, 839 Invitation: newInvitation(), 840 Done: false, 841 }), 842 ) 843 s.extractDIDCommMsgBytesFunc = func(*decorator.Attachment) ([]byte, error) { 844 return nil, expected 845 } 846 err = s.handleDIDEvent(service.StateMsg{ 847 ProtocolName: didexchange.DIDExchange, 848 Type: service.PostState, 849 Msg: service.NewDIDCommMsgMap(newAck(pthid)), 850 StateID: didexchange.StateIDCompleted, 851 Properties: &mockdidexchange.MockEventProperties{ConnID: connID}, 852 }) 853 require.Error(t, err) 854 require.True(t, errors.Is(err, expected)) 855 }) 856 } 857 858 func TestListener(t *testing.T) { 859 t.Run("invokes handleReqFunc", func(t *testing.T) { 860 invoked := make(chan struct{}) 861 callbacks := make(chan *callback) 862 handleReqFunc := func(*callback) (string, error) { 863 invoked <- struct{}{} 864 return "", nil 865 } 866 go listener(callbacks, nil, handleReqFunc, nil)() 867 868 callbacks <- &callback{ 869 msg: service.NewDIDCommMsgMap(newInvitation()), 870 } 871 872 select { 873 case <-invoked: 874 case <-time.After(1 * time.Second): 875 t.Error("timeout") 876 } 877 }) 878 t.Run("invokes handleDidEventFunc", func(t *testing.T) { 879 invoked := make(chan struct{}) 880 didEvents := make(chan service.StateMsg) 881 handleDidEventFunc := func(msg service.StateMsg) error { 882 invoked <- struct{}{} 883 return nil 884 } 885 go listener(nil, didEvents, nil, handleDidEventFunc)() 886 didEvents <- service.StateMsg{} 887 888 select { 889 case <-invoked: 890 case <-time.After(1 * time.Second): 891 t.Error("timeout") 892 } 893 }) 894 } 895 896 func TestAcceptInvitation(t *testing.T) { 897 t.Run("returns connectionID", func(t *testing.T) { 898 expected := "123456" 899 provider := testProvider() 900 provider.ServiceMap = map[string]interface{}{ 901 didexchange.DIDExchange: &mockdidexchange.MockDIDExchangeSvc{ 902 RespondToFunc: func(_ *didexchange.OOBInvitation, _ []string) (string, error) { 903 return expected, nil 904 }, 905 }, 906 } 907 s := newAutoService(t, provider) 908 result, err := s.AcceptInvitation(newInvitation(), &userOptions{}) 909 require.NoError(t, err) 910 require.Equal(t, expected, result) 911 }) 912 t.Run("wraps error from didexchange service", func(t *testing.T) { 913 expected := errors.New("test") 914 provider := testProvider() 915 provider.ServiceMap = map[string]interface{}{ 916 didexchange.DIDExchange: &mockdidexchange.MockDIDExchangeSvc{ 917 RespondToFunc: func(_ *didexchange.OOBInvitation, _ []string) (string, error) { 918 return "", expected 919 }, 920 }, 921 } 922 s := newAutoService(t, provider) 923 _, err := s.AcceptInvitation(newInvitation(), &userOptions{}) 924 require.Error(t, err) 925 require.True(t, errors.Is(err, expected)) 926 }) 927 t.Run("error if invitation has invalid accept values", func(t *testing.T) { 928 provider := testProvider() 929 s := newAutoService(t, provider) 930 inv := newInvitation() 931 inv.Accept = []string{"INVALID"} 932 _, err := s.AcceptInvitation(inv, &userOptions{}) 933 require.Error(t, err) 934 require.Contains(t, err.Error(), "no acceptable media type profile found in invitation") 935 }) 936 } 937 938 func TestSaveInvitation(t *testing.T) { 939 t.Run("saves invitation", func(t *testing.T) { 940 savedInStore := false 941 savedInDidSvc := false 942 expected := newInvitation() 943 provider := testProvider() 944 provider.StoreProvider = mockstore.NewCustomMockStoreProvider(&stubStore{ 945 putFunc: func(k string, v []byte) error { 946 savedInStore = true 947 result := &Invitation{} 948 err := json.Unmarshal(v, result) 949 require.NoError(t, err) 950 require.Equal(t, expected, result) 951 return nil 952 }, 953 }) 954 provider.ServiceMap[didexchange.DIDExchange] = &mockdidexchange.MockDIDExchangeSvc{ 955 SaveFunc: func(i *didexchange.OOBInvitation) error { 956 savedInDidSvc = true 957 require.NotNil(t, i) 958 require.NotEmpty(t, i.ID) 959 require.Equal(t, expected.ID, i.ThreadID) 960 require.Equal(t, expected.Label, i.TheirLabel) 961 require.Equal(t, expected.Services[0], i.Target) 962 return nil 963 }, 964 } 965 s := newAutoService(t, provider) 966 err := s.SaveInvitation(expected) 967 require.NoError(t, err) 968 require.True(t, savedInStore) 969 require.True(t, savedInDidSvc) 970 }) 971 t.Run("wraps error from store", func(t *testing.T) { 972 expected := errors.New("test") 973 provider := testProvider() 974 provider.StoreProvider = &mockstore.MockStoreProvider{ 975 Store: &mockstore.MockStore{ 976 ErrPut: expected, 977 }, 978 } 979 s := newAutoService(t, provider) 980 err := s.SaveInvitation(newInvitation()) 981 require.Error(t, err) 982 require.True(t, errors.Is(err, expected)) 983 }) 984 t.Run("fails when invitation does not have services", func(t *testing.T) { 985 inv := newInvitation() 986 inv.Services = []interface{}{} 987 s := newAutoService(t, testProvider()) 988 err := s.SaveInvitation(inv) 989 require.Error(t, err) 990 }) 991 t.Run("wraps error from didexchange service", func(t *testing.T) { 992 expected := errors.New("test") 993 provider := testProvider() 994 provider.ServiceMap[didexchange.DIDExchange] = &mockdidexchange.MockDIDExchangeSvc{ 995 SaveFunc: func(*didexchange.OOBInvitation) error { 996 return expected 997 }, 998 } 999 s := newAutoService(t, provider) 1000 err := s.SaveInvitation(newInvitation()) 1001 require.Error(t, err) 1002 require.True(t, errors.Is(err, expected)) 1003 }) 1004 } 1005 1006 func TestChooseTarget(t *testing.T) { 1007 t.Run("chooses a string", func(t *testing.T) { 1008 expected := "abc123" 1009 result, err := chooseTarget([]interface{}{expected}) 1010 require.NoError(t, err) 1011 require.Equal(t, expected, result) 1012 }) 1013 t.Run("chooses a did service entry", func(t *testing.T) { 1014 expected := &did.Service{ 1015 ID: uuid.New().String(), 1016 Type: "did-communication", 1017 Priority: 0, 1018 RecipientKeys: []string{"my ver key"}, 1019 ServiceEndpoint: commonmodel.NewDIDCommV1Endpoint("my service endpoint"), 1020 RoutingKeys: []string{"my routing key"}, 1021 } 1022 result, err := chooseTarget([]interface{}{expected}) 1023 require.NoError(t, err) 1024 require.Equal(t, expected, result) 1025 }) 1026 t.Run("chooses a map-type service", func(t *testing.T) { 1027 expected := map[string]interface{}{ 1028 "id": uuid.New().String(), 1029 "type": "did-communication", 1030 "priority": 0, 1031 "recipientKeys": []string{"my ver key"}, 1032 "serviceEndpoint": commonmodel.NewDIDCommV1Endpoint("my service endpoint"), 1033 "RoutingKeys": []string{"my routing key"}, 1034 } 1035 svc, err := chooseTarget([]interface{}{expected}) 1036 require.NoError(t, err) 1037 result, ok := svc.(*did.Service) 1038 require.True(t, ok) 1039 require.Equal(t, expected["id"], result.ID) 1040 require.Equal(t, expected["type"], result.Type) 1041 require.Equal(t, expected["priority"], result.Priority) 1042 require.Equal(t, expected["recipientKeys"], result.RecipientKeys) 1043 require.Equal(t, expected["serviceEndpoint"], result.ServiceEndpoint) 1044 }) 1045 t.Run("fails if not services are specified", func(t *testing.T) { 1046 _, err := chooseTarget([]interface{}{}) 1047 require.Error(t, err) 1048 }) 1049 } 1050 1051 func testProvider() *protocol.MockProvider { 1052 return &protocol.MockProvider{ 1053 StoreProvider: mockstore.NewMockStoreProvider(), 1054 ProtocolStateStoreProvider: mockstore.NewMockStoreProvider(), 1055 ServiceMap: map[string]interface{}{ 1056 didexchange.DIDExchange: &mockdidexchange.MockDIDExchangeSvc{}, 1057 }, 1058 } 1059 } 1060 1061 func newInvitation() *Invitation { 1062 return newInvitationWithService("did:example:1235") 1063 } 1064 1065 func newInvitationWithService(svc interface{}) *Invitation { 1066 return &Invitation{ 1067 ID: uuid.New().String(), 1068 Type: InvitationMsgType, 1069 Label: "test", 1070 Goal: "test", 1071 GoalCode: "test", 1072 Services: []interface{}{svc}, 1073 Protocols: []string{didexchange.PIURI}, 1074 Requests: []*decorator.Attachment{ 1075 { 1076 ID: uuid.New().String(), 1077 Description: "test", 1078 FileName: "dont_open_this.exe", 1079 MimeType: "text/plain", 1080 Data: decorator.AttachmentData{ 1081 JSON: map[string]interface{}{ 1082 "@id": "123", 1083 "@type": "test-type", 1084 }, 1085 }, 1086 }, 1087 }, 1088 } 1089 } 1090 1091 func newCallback() *callback { 1092 inv := newInvitation() 1093 1094 return &callback{ 1095 myDID: fmt.Sprintf("did:example:%s", uuid.New().String()), 1096 theirDID: fmt.Sprintf("did:example:%s", uuid.New().String()), 1097 msg: service.NewDIDCommMsgMap(inv), 1098 ctx: &context{ 1099 CurrentStateName: StateNameInitial, 1100 Inbound: true, 1101 Invitation: inv, 1102 }, 1103 } 1104 } 1105 1106 func withState(t *testing.T, states ...*attachmentHandlingState) func(*Service) { 1107 return func(s *Service) { 1108 for i := range states { 1109 err := s.save(states[i]) 1110 require.NoError(t, err) 1111 } 1112 } 1113 } 1114 1115 func newAutoService(t *testing.T, 1116 provider *protocol.MockProvider, opts ...func(*Service)) *Service { 1117 s, err := New(provider) 1118 require.NoError(t, err) 1119 1120 for i := range opts { 1121 opts[i](s) 1122 } 1123 1124 events := make(chan service.DIDCommAction) 1125 require.NoError(t, s.RegisterActionEvent(events)) 1126 1127 go service.AutoExecuteActionEvent(events) 1128 1129 return s 1130 } 1131 1132 func newAck(pthid ...string) *model.Ack { 1133 a := &model.Ack{ 1134 Type: didexchange.AckMsgType, 1135 ID: uuid.New().String(), 1136 Status: "good", 1137 Thread: &decorator.Thread{ 1138 ID: uuid.New().String(), 1139 PID: uuid.New().String(), 1140 }, 1141 } 1142 1143 if len(pthid) > 0 { 1144 a.Thread.PID = pthid[0] 1145 } 1146 1147 return a 1148 } 1149 1150 type testDIDCommMsg struct { 1151 msgType string 1152 errDecode error 1153 } 1154 1155 func (t *testDIDCommMsg) ID() string { 1156 panic("implement me") 1157 } 1158 1159 func (t *testDIDCommMsg) SetID(id string, opts ...service.Opt) { 1160 panic("implement me") 1161 } 1162 1163 func (t *testDIDCommMsg) SetThread(tid, pid string, opts ...service.Opt) { 1164 panic("implement me") 1165 } 1166 1167 func (t *testDIDCommMsg) UnsetThread() { 1168 panic("implement me") 1169 } 1170 1171 func (t *testDIDCommMsg) Type() string { 1172 return t.msgType 1173 } 1174 1175 func (t *testDIDCommMsg) ThreadID() (string, error) { 1176 panic("implement me") 1177 } 1178 1179 func (t *testDIDCommMsg) ParentThreadID() string { 1180 panic("implement me") 1181 } 1182 1183 func (t *testDIDCommMsg) Clone() service.DIDCommMsgMap { 1184 panic("implement me") 1185 } 1186 1187 func (t *testDIDCommMsg) Metadata() map[string]interface{} { 1188 panic("implement me") 1189 } 1190 1191 func (t *testDIDCommMsg) Decode(v interface{}) error { 1192 return t.errDecode 1193 } 1194 1195 type stubStore struct { 1196 putFunc func(k string, v []byte) error 1197 } 1198 1199 func (s *stubStore) GetTags(key string) ([]storage.Tag, error) { 1200 panic("implement me") 1201 } 1202 1203 func (s *stubStore) GetBulk(keys ...string) ([][]byte, error) { 1204 panic("implement me") 1205 } 1206 1207 func (s *stubStore) Query(expression string, options ...storage.QueryOption) (storage.Iterator, error) { 1208 panic("implement me") 1209 } 1210 1211 func (s *stubStore) Batch(operations []storage.Operation) error { 1212 panic("implement me") 1213 } 1214 1215 func (s *stubStore) Flush() error { 1216 panic("implement me") 1217 } 1218 1219 func (s *stubStore) Close() error { 1220 panic("implement me") 1221 } 1222 1223 func (s *stubStore) Put(k string, v []byte, tags ...storage.Tag) error { 1224 if s.putFunc != nil { 1225 return s.putFunc(k, v) 1226 } 1227 1228 return nil 1229 } 1230 1231 func (s *stubStore) Get(k string) ([]byte, error) { 1232 panic("implement me") 1233 } 1234 1235 func (s *stubStore) Iterator(start, limit string) storage.Iterator { 1236 panic("implement me") 1237 } 1238 1239 func (s *stubStore) Delete(k string) error { 1240 panic("implement me") 1241 } 1242 1243 type inboundMsgHandler struct { 1244 handleFunc func(msg service.DIDCommMsg, ctx service.DIDCommContext) (string, error) 1245 } 1246 1247 func (i *inboundMsgHandler) HandleInbound(msg service.DIDCommMsg, ctx service.DIDCommContext) (string, error) { 1248 return i.handleFunc(msg, ctx) 1249 } 1250 1251 func marshal(t *testing.T, v interface{}) []byte { 1252 t.Helper() 1253 1254 raw, err := json.Marshal(v) 1255 require.NoError(t, err) 1256 1257 return raw 1258 }