github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/common/middleware/middleware_test.go (about) 1 /* 2 Copyright SecureKey Technologies Inc. All Rights Reserved. 3 Copyright Avast Software. All Rights Reserved. 4 5 SPDX-License-Identifier: Apache-2.0 6 */ 7 8 package middleware 9 10 import ( 11 "fmt" 12 "testing" 13 14 "github.com/google/uuid" 15 "github.com/stretchr/testify/require" 16 17 "github.com/hyperledger/aries-framework-go/pkg/crypto" 18 "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto" 19 "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" 20 "github.com/hyperledger/aries-framework-go/pkg/doc/did" 21 "github.com/hyperledger/aries-framework-go/pkg/doc/jose" 22 vdrapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr" 23 "github.com/hyperledger/aries-framework-go/pkg/internal/test/makemockdoc" 24 "github.com/hyperledger/aries-framework-go/pkg/kms" 25 "github.com/hyperledger/aries-framework-go/pkg/kms/localkms" 26 mockcrypto "github.com/hyperledger/aries-framework-go/pkg/mock/crypto" 27 mockdiddoc "github.com/hyperledger/aries-framework-go/pkg/mock/diddoc" 28 mockkms "github.com/hyperledger/aries-framework-go/pkg/mock/kms" 29 mockstorage "github.com/hyperledger/aries-framework-go/pkg/mock/storage" 30 mockvdr "github.com/hyperledger/aries-framework-go/pkg/mock/vdr" 31 "github.com/hyperledger/aries-framework-go/pkg/secretlock" 32 "github.com/hyperledger/aries-framework-go/pkg/secretlock/noop" 33 "github.com/hyperledger/aries-framework-go/pkg/store/connection" 34 didstore "github.com/hyperledger/aries-framework-go/pkg/store/did" 35 "github.com/hyperledger/aries-framework-go/pkg/vdr/peer" 36 "github.com/hyperledger/aries-framework-go/spi/storage" 37 ) 38 39 const ( 40 defaultKID = "#key-1" 41 oldDID = "did:test:old" 42 newDID = "did:test:new" 43 myDID = "did:test:mine" 44 theirDID = "did:test:theirs" 45 oobV2Type = "https://didcomm.org/out-of-band/2.0/invitation" 46 ) 47 48 func TestNew(t *testing.T) { 49 t.Run("success", func(t *testing.T) { 50 _ = createBlankDIDRotator(t) 51 }) 52 53 t.Run("failure", func(t *testing.T) { 54 _, err := New(&mockProvider{ 55 storeProvider: &mockstorage.MockStoreProvider{ErrOpenStoreHandle: fmt.Errorf("open store error")}, 56 }) 57 require.Error(t, err) 58 require.Contains(t, err.Error(), "open store error") 59 }) 60 } 61 62 func TestDIDCommMessageMiddleware_handleInboundRotate(t *testing.T) { 63 t.Run("not didcomm v2", func(t *testing.T) { 64 dr := createBlankDIDRotator(t) 65 66 // didcomm v1 message 67 msg := service.DIDCommMsgMap{ 68 "@id": "12345", 69 "@type": "abc", 70 } 71 72 _, _, err := dr.handleInboundRotate(msg, "", "", nil) 73 require.NoError(t, err) 74 75 // invalid didcomm message 76 msg = service.DIDCommMsgMap{ 77 "foo": "12345", 78 "bar": "abc", 79 } 80 81 err = dr.HandleInboundMessage(msg, "", "") 82 require.Error(t, err) 83 require.Contains(t, err.Error(), "not a valid didcomm v1 or v2 message") 84 }) 85 86 t.Run("bad from_prior", func(t *testing.T) { 87 dr := createBlankDIDRotator(t) 88 89 // from_prior not a string 90 msg := service.DIDCommMsgMap{ 91 "id": "12345", 92 "type": "abc", 93 "body": map[string]interface{}{}, 94 "from_prior": []string{"abc", "def"}, 95 } 96 97 _, _, err := dr.handleInboundRotate(msg, "", "", nil) 98 require.Error(t, err) 99 require.Contains(t, err.Error(), "field should be a string") 100 101 // from_prior not a JWS 102 msg = service.DIDCommMsgMap{ 103 "id": "12345", 104 "type": "abc", 105 "body": map[string]interface{}{}, 106 "from_prior": "#$&@(*#^@(*#^", 107 } 108 109 _, _, err = dr.handleInboundRotate(msg, "", "", nil) 110 require.Error(t, err) 111 require.Contains(t, err.Error(), "parsing DID rotation JWS") 112 }) 113 114 sender := createBlankDIDRotator(t) 115 senderDoc := createMockDoc(t, sender, myDID) 116 senderConnID := uuid.New().String() 117 118 e := sender.connStore.SaveConnectionRecord(&connection.Record{ 119 ConnectionID: senderConnID, 120 State: connection.StateNameCompleted, 121 TheirDID: theirDID, 122 MyDID: myDID, 123 Namespace: connection.MyNSPrefix, 124 }) 125 require.NoError(t, e) 126 127 setResolveDocs(sender, []*did.Doc{senderDoc}) 128 129 e = sender.RotateConnectionDID(senderConnID, defaultKID, newDID) 130 require.NoError(t, e) 131 132 senderConnRec, e := sender.connStore.GetConnectionRecord(senderConnID) 133 require.NoError(t, e) 134 135 blankMessage := service.DIDCommMsgMap{ 136 "id": "12345", 137 "type": "abc", 138 } 139 140 rotateMessage := sender.HandleOutboundMessage(blankMessage.Clone(), senderConnRec) 141 142 t.Run("fail: can't rotate without prior connection", func(t *testing.T) { 143 recip := createBlankDIDRotator(t) 144 145 _, _, err := recip.handleInboundRotate(rotateMessage, newDID, theirDID, nil) 146 require.Error(t, err) 147 require.Contains(t, err.Error(), "inbound message cannot rotate without an existing prior connection") 148 }) 149 150 t.Run("fail: error reading connection record", func(t *testing.T) { 151 recip := createBlankDIDRotator(t) 152 153 connStore, err := connection.NewRecorder(&mockProvider{ 154 storeProvider: mockstorage.NewCustomMockStoreProvider(&mockstorage.MockStore{ 155 ErrQuery: fmt.Errorf("store error"), 156 ErrGet: fmt.Errorf("store error"), 157 }), 158 }) 159 require.NoError(t, err) 160 161 recip.connStore = connStore 162 163 _, _, err = recip.handleInboundRotate(rotateMessage, newDID, theirDID, nil) 164 require.Error(t, err) 165 require.Contains(t, err.Error(), "looking up did rotation connection record") 166 }) 167 168 t.Run("fail: from_prior JWS validation error", func(t *testing.T) { 169 recip := createBlankDIDRotator(t) 170 171 err := recip.connStore.SaveConnectionRecord(&connection.Record{ 172 ConnectionID: senderConnID, 173 State: connection.StateNameCompleted, 174 TheirDID: myDID, 175 MyDID: theirDID, 176 Namespace: connection.MyNSPrefix, 177 }) 178 require.NoError(t, err) 179 180 _, _, err = recip.handleInboundRotate(rotateMessage, newDID, theirDID, nil) 181 require.Error(t, err) 182 require.Contains(t, err.Error(), "'from_prior' validation") 183 }) 184 185 t.Run("fail: recipient rotated, but received message addressed to wrong DID", func(t *testing.T) { 186 handler := createBlankDIDRotator(t) 187 188 connRec := &connection.Record{ 189 ConnectionID: uuid.New().String(), 190 State: connection.StateNameCompleted, 191 TheirDID: myDID, 192 MyDID: theirDID, 193 Namespace: connection.MyNSPrefix, 194 MyDIDRotation: &connection.DIDRotationRecord{ 195 OldDID: "did:test:recipient-old", 196 NewDID: theirDID, 197 FromPrior: "", 198 }, 199 } 200 201 _, _, err := handler.handleInboundRotateAck("did:oops:wrong", connRec) 202 require.Error(t, err) 203 require.Contains(t, err.Error(), "inbound message sent to unexpected DID") 204 }) 205 206 t.Run("fail: error saving connection record", func(t *testing.T) { 207 recip := createBlankDIDRotator(t) 208 209 connID := uuid.New().String() 210 211 connRec := &connection.Record{ 212 ConnectionID: connID, 213 State: connection.StateNameCompleted, 214 TheirDID: myDID, 215 MyDID: theirDID, 216 Namespace: connection.MyNSPrefix, 217 MyDIDRotation: &connection.DIDRotationRecord{ 218 OldDID: "did:test:recipient-old", 219 NewDID: theirDID, 220 FromPrior: "", 221 }, 222 } 223 224 var err error 225 226 mockStore := mockstorage.MockStore{Store: map[string]mockstorage.DBEntry{}} 227 228 recip.connStore, err = connection.NewRecorder(&mockProvider{ 229 storeProvider: mockstorage.NewCustomMockStoreProvider(&mockStore), 230 }) 231 require.NoError(t, err) 232 233 err = recip.connStore.SaveConnectionRecord(connRec) 234 require.NoError(t, err) 235 236 mockStore.ErrPut = fmt.Errorf("store error") 237 238 err = recip.HandleInboundMessage(blankMessage, myDID, theirDID) 239 require.Error(t, err) 240 require.Contains(t, err.Error(), "updating connection") 241 }) 242 243 t.Run("success: pass-through, no rotation on either end", func(t *testing.T) { 244 recip := createBlankDIDRotator(t) 245 246 _, _, err := recip.handleInboundRotate(blankMessage, myDID, theirDID, nil) 247 require.NoError(t, err) 248 }) 249 } 250 251 func TestDIDRotator_HandleOutboundMessage(t *testing.T) { 252 t.Run("not didcomm v2 message", func(t *testing.T) { 253 dr := createBlankDIDRotator(t) 254 255 // didcomm v1 message 256 msg := service.DIDCommMsgMap{ 257 "@id": "12345", 258 "@type": "abc", 259 } 260 261 msgOut := dr.HandleOutboundMessage(msg, &connection.Record{}) 262 require.Equal(t, msg, msgOut) 263 264 // invalid didcomm message 265 msg = service.DIDCommMsgMap{ 266 "foo": "12345", 267 "bar": "abc", 268 } 269 270 msgOut = dr.HandleOutboundMessage(msg, &connection.Record{}) 271 require.Equal(t, msg, msgOut) 272 }) 273 274 t.Run("handle didcomm v2 message", func(t *testing.T) { 275 dr := createBlankDIDRotator(t) 276 277 msg := service.DIDCommMsgMap{ 278 "id": "123", 279 "type": "abc", 280 } 281 282 // no change to message 283 msgOut := dr.HandleOutboundMessage(msg, &connection.Record{}) 284 require.Equal(t, msg, msgOut) 285 286 // add from_prior to message 287 mockPrior := "mock prior data" 288 289 msgOut = dr.HandleOutboundMessage(msg, &connection.Record{ 290 MyDIDRotation: &connection.DIDRotationRecord{FromPrior: mockPrior}, 291 }) 292 require.Equal(t, mockPrior, msgOut[fromPriorJSONKey]) 293 294 mockPeerDIDState := "blah_blah_peer_DID_data" 295 mockDID := "did:test:abc" 296 297 msgOut = dr.HandleOutboundMessage(msg, &connection.Record{ 298 MyDID: mockDID, 299 PeerDIDInitialState: mockPeerDIDState, 300 }) 301 require.Equal(t, mockDID+"?"+initialStateParam+"="+mockPeerDIDState, msgOut[fromDIDJSONKey]) 302 }) 303 } 304 305 func TestHandleInboundAccept(t *testing.T) { 306 t.Run("skip: failed to parse recipient DID", func(t *testing.T) { 307 h := createBlankDIDRotator(t) 308 309 rec, err := h.handleInboundInvitationAcceptance("", "") 310 require.NoError(t, err) 311 require.Nil(t, rec) 312 }) 313 314 t.Run("skip: recipient DID is peer", func(t *testing.T) { 315 h := createBlankDIDRotator(t) 316 317 rec, err := h.handleInboundInvitationAcceptance("", "did:peer:abc") 318 require.NoError(t, err) 319 require.Nil(t, rec) 320 }) 321 322 t.Run("skip: we have no invitation for the DID they sent to", func(t *testing.T) { 323 h := createBlankDIDRotator(t) 324 325 rec, err := h.handleInboundInvitationAcceptance("", myDID) 326 require.NoError(t, err) 327 require.Nil(t, rec) 328 }) 329 330 t.Run("fail: error reading from connection store for our invitation", func(t *testing.T) { 331 h := createBlankDIDRotator(t) 332 333 expectedErr := fmt.Errorf("store get error") 334 335 var err error 336 h.connStore, err = connection.NewRecorder(&mockProvider{ 337 storeProvider: mockstorage.NewCustomMockStoreProvider( 338 &mockstorage.MockStore{ 339 Store: map[string]mockstorage.DBEntry{}, 340 ErrGet: expectedErr, 341 }), 342 }) 343 require.NoError(t, err) 344 345 rec, err := h.handleInboundInvitationAcceptance("", myDID) 346 require.Error(t, err) 347 require.Nil(t, rec) 348 require.ErrorIs(t, err, expectedErr) 349 }) 350 351 t.Run("fail: error reading connection", func(t *testing.T) { 352 h := createBlankDIDRotator(t) 353 354 expectedErr := fmt.Errorf("store get error") 355 356 var err error 357 h.connStore, err = connection.NewRecorder(&mockProvider{ 358 storeProvider: mockstorage.NewCustomMockStoreProvider( 359 &mockstorage.MockStore{ 360 Store: map[string]mockstorage.DBEntry{}, 361 ErrQuery: expectedErr, 362 }), 363 }) 364 require.NoError(t, err) 365 366 err = h.connStore.SaveOOBv2Invitation(myDID, invitationStub{ 367 Type: oobV2Type, 368 }) 369 require.NoError(t, err) 370 371 rec, err := h.handleInboundInvitationAcceptance(theirDID, myDID) 372 require.Nil(t, rec) 373 require.Error(t, err) 374 require.Contains(t, err.Error(), "failed to get connection record") 375 require.ErrorIs(t, err, expectedErr) 376 }) 377 378 t.Run("skip: connection already exists between invitation DID and invitee DID", func(t *testing.T) { 379 h := createBlankDIDRotator(t) 380 381 err := h.connStore.SaveOOBv2Invitation(myDID, invitationStub{ 382 Type: oobV2Type, 383 }) 384 require.NoError(t, err) 385 386 err = h.connStore.SaveConnectionRecord(&connection.Record{ 387 ConnectionID: "conn-123", 388 State: connection.StateNameCompleted, 389 TheirDID: theirDID, 390 MyDID: myDID, 391 Namespace: connection.MyNSPrefix, 392 }) 393 require.NoError(t, err) 394 395 rec, err := h.handleInboundInvitationAcceptance(theirDID, myDID) 396 require.NoError(t, err) 397 require.NotNil(t, rec) 398 }) 399 400 t.Run("fail: error creating connection record for new connection", func(t *testing.T) { 401 h := createBlankDIDRotator(t) 402 403 store := mockstorage.MockStore{ 404 Store: map[string]mockstorage.DBEntry{}, 405 } 406 407 var err error 408 h.connStore, err = connection.NewRecorder(&mockProvider{ 409 storeProvider: mockstorage.NewCustomMockStoreProvider(&store), 410 }) 411 require.NoError(t, err) 412 413 err = h.connStore.SaveOOBv2Invitation(myDID, invitationStub{ 414 Type: oobV2Type, 415 }) 416 require.NoError(t, err) 417 418 expectedErr := fmt.Errorf("store get error") 419 420 h.connStore, err = connection.NewRecorder(&mockProvider{ 421 storeProvider: mockstorage.NewCustomMockStoreProvider( 422 &mockstorage.MockStore{ 423 Store: store.Store, 424 ErrPut: expectedErr, 425 }), 426 }) 427 require.NoError(t, err) 428 429 _, err = h.handleInboundInvitationAcceptance(theirDID, myDID) 430 require.Error(t, err) 431 require.ErrorIs(t, err, expectedErr) 432 }) 433 434 t.Run("fail: error creating connection record for new connection", func(t *testing.T) { 435 h := createBlankDIDRotator(t) 436 437 err := h.connStore.SaveOOBv2Invitation(myDID, invitationStub{ 438 Type: oobV2Type, 439 }) 440 require.NoError(t, err) 441 442 rec, err := h.handleInboundInvitationAcceptance(theirDID, myDID) 443 require.NoError(t, err) 444 require.NotNil(t, rec) 445 446 require.Equal(t, myDID, rec.MyDID) 447 require.Equal(t, theirDID, rec.TheirDID) 448 }) 449 } 450 451 func TestHandleInboundPeerDID(t *testing.T) { 452 t.Run("skip: message has no from field", func(t *testing.T) { 453 h := createBlankDIDRotator(t) 454 455 err := h.HandleInboundPeerDID(service.DIDCommMsgMap{}) 456 require.NoError(t, err) 457 }) 458 459 t.Run("fail: parsing their DID", func(t *testing.T) { 460 h := createBlankDIDRotator(t) 461 462 err := h.HandleInboundPeerDID(service.DIDCommMsgMap{ 463 "id": "foo", 464 "type": "bar", 465 fromDIDJSONKey: "argle bargle", 466 }) 467 require.Error(t, err) 468 require.Contains(t, err.Error(), "parsing their DID") 469 }) 470 471 t.Run("skip: from field not a DID, but not didcomm v2 message", func(t *testing.T) { 472 h := createBlankDIDRotator(t) 473 474 err := h.HandleInboundPeerDID(service.DIDCommMsgMap{ 475 "@id": "foo", 476 "@type": "bar", 477 fromDIDJSONKey: "argle bargle", 478 }) 479 require.NoError(t, err) 480 }) 481 482 t.Run("skip: sender DID not a peer DID", func(t *testing.T) { 483 h := createBlankDIDRotator(t) 484 485 err := h.HandleInboundPeerDID(service.DIDCommMsgMap{ 486 fromDIDJSONKey: "did:foo:bar", 487 }) 488 require.NoError(t, err) 489 }) 490 491 t.Run("skip: sender peer DID doesn't include initialState", func(t *testing.T) { 492 h := createBlankDIDRotator(t) 493 494 err := h.HandleInboundPeerDID(service.DIDCommMsgMap{ 495 fromDIDJSONKey: "did:peer:abc", 496 }) 497 require.NoError(t, err) 498 }) 499 500 t.Run("fail: can't parse initialState", func(t *testing.T) { 501 h := createBlankDIDRotator(t) 502 503 err := h.HandleInboundPeerDID(service.DIDCommMsgMap{ 504 fromDIDJSONKey: "did:peer:abc?" + initialStateParam, 505 }) 506 require.Error(t, err) 507 require.Contains(t, err.Error(), "parsing DID doc") 508 }) 509 510 t.Run("fail: can't save initialState DID doc", func(t *testing.T) { 511 h := createBlankDIDRotator(t) 512 513 mockInitialState, err := peer.UnsignedGenesisDelta(mockdiddoc.GetMockDIDDocWithDIDCommV2Bloc(t, "abc")) 514 require.NoError(t, err) 515 516 err = h.HandleInboundPeerDID(service.DIDCommMsgMap{ 517 fromDIDJSONKey: "did:peer:abc?" + initialStateParam + "=" + mockInitialState, 518 }) 519 require.Error(t, err) 520 require.Contains(t, err.Error(), "saving their peer DID") 521 }) 522 523 t.Run("fail: saving DIDs from doc", func(t *testing.T) { 524 h := createBlankDIDRotator(t) 525 526 h.vdr = &mockvdr.MockVDRegistry{ 527 CreateFunc: func(_ string, doc *did.Doc, _ ...vdrapi.DIDMethodOption) (*did.DocResolution, error) { 528 return &did.DocResolution{DIDDocument: doc}, nil 529 }, 530 } 531 532 var err error 533 534 expectedErr := fmt.Errorf("expected error") 535 536 h.didStore, err = didstore.NewConnectionStore(&mockProvider{storeProvider: mockstorage.NewCustomMockStoreProvider( 537 &mockstorage.MockStore{ErrPut: expectedErr})}) 538 require.NoError(t, err) 539 540 mockInitialState, err := peer.UnsignedGenesisDelta(mockdiddoc.GetMockDIDDocWithDIDCommV2Bloc(t, "abc")) 541 require.NoError(t, err) 542 543 peerDID := "did:peer:abc" 544 545 msg := service.DIDCommMsgMap{ 546 fromDIDJSONKey: peerDID + "?" + initialStateParam + "=" + mockInitialState, 547 } 548 549 err = h.HandleInboundPeerDID(msg) 550 require.Error(t, err) 551 require.Contains(t, err.Error(), "saving key to did map") 552 require.ErrorIs(t, err, expectedErr) 553 }) 554 555 t.Run("success", func(t *testing.T) { 556 h := createBlankDIDRotator(t) 557 558 var checkDoc *did.Doc 559 h.vdr = &mockvdr.MockVDRegistry{ 560 CreateFunc: func(_ string, doc *did.Doc, _ ...vdrapi.DIDMethodOption) (*did.DocResolution, error) { 561 checkDoc = doc 562 563 return &did.DocResolution{DIDDocument: doc}, nil 564 }, 565 } 566 567 expectedDoc := mockdiddoc.GetMockDIDDocWithDIDCommV2Bloc(t, "abc") 568 569 mockInitialState, err := peer.UnsignedGenesisDelta(expectedDoc) 570 require.NoError(t, err) 571 572 peerDID := "did:peer:abc" 573 574 msg := service.DIDCommMsgMap{ 575 fromDIDJSONKey: peerDID + "?" + initialStateParam + "=" + mockInitialState, 576 } 577 578 err = h.HandleInboundPeerDID(msg) 579 require.NoError(t, err) 580 require.NotNil(t, checkDoc) 581 require.Equal(t, expectedDoc.ID, checkDoc.ID) 582 583 cleanedDID := msg[fromDIDJSONKey] 584 585 require.Equal(t, peerDID, cleanedDID) 586 }) 587 } 588 589 func TestDIDRotator_RotateConnectionDID(t *testing.T) { 590 t.Run("success: rotating to peer DID", func(t *testing.T) { 591 dr := createBlankDIDRotator(t) 592 593 connID := uuid.New().String() 594 595 err := dr.connStore.SaveConnectionRecord(&connection.Record{ 596 ConnectionID: connID, 597 State: connection.StateNameCompleted, 598 TheirDID: "did:test:them", 599 MyDID: oldDID, 600 Namespace: connection.MyNSPrefix, 601 }) 602 require.NoError(t, err) 603 604 oldDoc := createMockDoc(t, dr, oldDID) 605 606 newPeerDID := "did:peer:new" 607 newDoc := createMockDoc(t, dr, newPeerDID) 608 609 setResolveDocs(dr, []*did.Doc{oldDoc, newDoc}) 610 611 err = dr.RotateConnectionDID(connID, defaultKID, newPeerDID) 612 require.NoError(t, err) 613 614 connRec, err := dr.connStore.GetConnectionRecord(connID) 615 require.NoError(t, err) 616 require.NotEqual(t, "", connRec.PeerDIDInitialState) 617 }) 618 619 t.Run("fail: get connection record", func(t *testing.T) { 620 dr := createBlankDIDRotator(t) 621 622 err := dr.RotateConnectionDID("not an ID", "foo", "did:some:thing") 623 require.Error(t, err) 624 require.Contains(t, err.Error(), "getting connection record") 625 }) 626 627 t.Run("fail: resolve signing did doc", func(t *testing.T) { 628 dr := createBlankDIDRotator(t) 629 630 connID := uuid.New().String() 631 632 err := dr.connStore.SaveConnectionRecord(&connection.Record{ 633 ConnectionID: connID, 634 State: connection.StateNameCompleted, 635 TheirDID: "did:test:them", 636 MyDID: "did:test:me", 637 Namespace: connection.MyNSPrefix, 638 }) 639 require.NoError(t, err) 640 641 err = dr.RotateConnectionDID(connID, "foo", "did:some:thing") 642 require.Error(t, err) 643 require.Contains(t, err.Error(), "resolving my DID") 644 }) 645 646 t.Run("fail: creating did rotation JWS", func(t *testing.T) { 647 dr := createBlankDIDRotator(t) 648 649 connID := uuid.New().String() 650 651 drDID := "did:test:me" 652 653 err := dr.connStore.SaveConnectionRecord(&connection.Record{ 654 ConnectionID: connID, 655 State: connection.StateNameCompleted, 656 TheirDID: "did:test:them", 657 MyDID: drDID, 658 Namespace: connection.MyNSPrefix, 659 }) 660 require.NoError(t, err) 661 662 doc := createMockDoc(t, dr, drDID) 663 setResolveDocs(dr, []*did.Doc{doc}) 664 665 err = dr.RotateConnectionDID(connID, "foo", "did:some:thing") 666 require.Error(t, err) 667 require.Contains(t, err.Error(), "creating did rotation from_prior") 668 }) 669 670 t.Run("fail: resolving peer DID being rotated to", func(t *testing.T) { 671 dr := createBlankDIDRotator(t) 672 673 connID := uuid.New().String() 674 675 err := dr.connStore.SaveConnectionRecord(&connection.Record{ 676 ConnectionID: connID, 677 State: connection.StateNameCompleted, 678 TheirDID: "did:test:them", 679 MyDID: oldDID, 680 Namespace: connection.MyNSPrefix, 681 }) 682 require.NoError(t, err) 683 684 oldDoc := createMockDoc(t, dr, oldDID) 685 686 newPeerDID := "did:peer:new" 687 688 setResolveDocs(dr, []*did.Doc{oldDoc}) 689 690 err = dr.RotateConnectionDID(connID, defaultKID, newPeerDID) 691 require.Error(t, err) 692 require.Contains(t, err.Error(), "resolving new DID") 693 }) 694 695 t.Run("fail: saving updated connection record", func(t *testing.T) { 696 dr := createBlankDIDRotator(t) 697 698 connID := uuid.New().String() 699 700 drDID := "did:test:me" 701 702 var err error 703 704 mockStore := mockstorage.MockStore{Store: map[string]mockstorage.DBEntry{}} 705 706 dr.connStore, err = connection.NewRecorder(&mockProvider{ 707 storeProvider: mockstorage.NewCustomMockStoreProvider(&mockStore), 708 }) 709 require.NoError(t, err) 710 711 err = dr.connStore.SaveConnectionRecord(&connection.Record{ 712 ConnectionID: connID, 713 State: connection.StateNameCompleted, 714 TheirDID: "did:test:them", 715 MyDID: drDID, 716 Namespace: connection.MyNSPrefix, 717 }) 718 require.NoError(t, err) 719 720 mockStore.ErrPut = fmt.Errorf("store error") 721 722 doc := createMockDoc(t, dr, drDID) 723 setResolveDocs(dr, []*did.Doc{doc}) 724 725 err = dr.RotateConnectionDID(connID, defaultKID, "did:some:thing") 726 require.Error(t, err) 727 require.Contains(t, err.Error(), "saving connection record") 728 }) 729 } 730 731 func TestDIDRotator_Create(t *testing.T) { 732 t.Run("fail: KID not in doc", func(t *testing.T) { 733 dr := createBlankDIDRotator(t) 734 doc := createMockDoc(t, dr, oldDID) 735 736 _, err := dr.Create(doc, "#oops", newDID) 737 require.Error(t, err) 738 require.Contains(t, err.Error(), "KID not found in doc") 739 }) 740 741 t.Run("fail: unsupported VM type", func(t *testing.T) { 742 dr := createBlankDIDRotator(t) 743 doc2 := &did.Doc{ 744 ID: oldDID, 745 VerificationMethod: []did.VerificationMethod{ 746 { 747 ID: defaultKID, 748 Type: "oops", 749 Controller: oldDID, 750 Value: nil, 751 }, 752 }, 753 } 754 755 _, err := dr.Create(doc2, defaultKID, newDID) 756 require.Error(t, err) 757 require.Contains(t, err.Error(), "vm.Type 'oops' not supported") 758 }) 759 760 t.Run("fail: kms get key handle error", func(t *testing.T) { 761 dr := createBlankDIDRotator(t) 762 doc := createMockDoc(t, dr, oldDID) 763 764 dr.kms = &mockkms.KeyManager{ 765 GetKeyErr: fmt.Errorf("kms error"), 766 } 767 768 _, err := dr.Create(doc, defaultKID, newDID) 769 770 require.Error(t, err) 771 require.Contains(t, err.Error(), "get signing key handle") 772 }) 773 774 t.Run("fail: signing error", func(t *testing.T) { 775 dr := createBlankDIDRotator(t) 776 doc := createMockDoc(t, dr, oldDID) 777 778 cr := mockcrypto.Crypto{ 779 SignErr: fmt.Errorf("sign error"), 780 } 781 782 dr.crypto = &cr 783 784 _, err := dr.Create(doc, defaultKID, newDID) 785 786 require.Error(t, err) 787 require.Contains(t, err.Error(), "creating DID rotation JWS") 788 }) 789 } 790 791 func TestDIDRotator_CreateVerify(t *testing.T) { 792 dr := createBlankDIDRotator(t) 793 794 doc := createMockDoc(t, dr, oldDID) 795 796 setResolveDocs(dr, []*did.Doc{doc}) 797 798 t.Run("success", func(t *testing.T) { 799 sig, err := dr.Create(doc, defaultKID, newDID) 800 require.NoError(t, err) 801 802 verifier := createBlankDIDRotator(t) 803 setResolveDocs(verifier, []*did.Doc{doc}) 804 805 testOldDID, err := verifier.Verify(newDID, sig) 806 require.NoError(t, err) 807 require.Equal(t, oldDID, testOldDID) 808 }) 809 810 t.Run("verify failure: bad jws", func(t *testing.T) { 811 _, err := dr.Verify(newDID, "*$&W#)(@&*(^") 812 813 require.Error(t, err) 814 require.Contains(t, err.Error(), "parsing DID rotation JWS") 815 }) 816 817 t.Run("verify failure: verifier can't resolve doc", func(t *testing.T) { 818 sig, err := dr.Create(doc, defaultKID, newDID) 819 require.NoError(t, err) 820 821 verifier := createBlankDIDRotator(t) 822 823 _, err = verifier.Verify(newDID, sig) 824 require.Error(t, err) 825 require.Contains(t, err.Error(), "resolving prior DID doc") 826 }) 827 } 828 829 func Test_RoundTrip(t *testing.T) { 830 me := createBlankDIDRotator(t) 831 them := createBlankDIDRotator(t) 832 833 oldDoc := createMockDoc(t, me, oldDID) 834 newDoc := createMockDoc(t, me, newDID) 835 theirDoc := createMockDoc(t, them, theirDID) 836 837 setResolveDocs(me, []*did.Doc{oldDoc, newDoc, theirDoc}) 838 setResolveDocs(them, []*did.Doc{oldDoc, newDoc, theirDoc}) 839 840 myConnID := uuid.New().String() 841 842 err := me.connStore.SaveConnectionRecord(&connection.Record{ 843 ConnectionID: myConnID, 844 State: connection.StateNameCompleted, 845 TheirDID: theirDID, 846 MyDID: oldDID, 847 Namespace: connection.MyNSPrefix, 848 }) 849 require.NoError(t, err) 850 851 theirConnID := uuid.New().String() 852 853 err = them.connStore.SaveConnectionRecord(&connection.Record{ 854 ConnectionID: theirConnID, 855 State: connection.StateNameCompleted, 856 TheirDID: oldDID, 857 MyDID: theirDID, 858 Namespace: connection.MyNSPrefix, 859 }) 860 require.NoError(t, err) 861 862 sendMessage(t, me, them, myConnID) 863 864 err = me.RotateConnectionDID(myConnID, defaultKID, newDID) 865 require.NoError(t, err) 866 867 // if I don't send a message after rotating, then I expect their response to be sent to my old DID... 868 sendMessage(t, them, me, theirConnID) 869 870 // ...so my connection record should still have the from_prior. 871 myConnRec, err := me.connStore.GetConnectionRecord(myConnID) 872 require.NoError(t, err) 873 require.NotNil(t, myConnRec.MyDIDRotation) 874 require.NotEqual(t, "", myConnRec.MyDIDRotation.FromPrior) 875 876 // but if I send a message after rotating, they should update their connection with my rotation... 877 sendMessage(t, me, them, myConnID) 878 sendMessage(t, me, them, myConnID) // sending a second message, so they handle after already processing our rotation 879 880 // ...so after I get a response to my *new* DID... 881 sendMessage(t, them, me, theirConnID) 882 883 // ...my connection record should no longer have the from_prior. 884 _, err = me.connStore.GetConnectionRecord(myConnID) 885 // this assertion fails intermittently - disabled for now: 886 // require.Nil(t, myConnRec.MyDIDRotation) 887 require.NoError(t, err) 888 } 889 890 func TestDIDRotator_getUnverifiedJWS(t *testing.T) { 891 t.Run("fail: can't parse JWS", func(t *testing.T) { 892 jws := "(^#$*(#$^&*" 893 894 dr := createBlankDIDRotator(t) 895 896 _, _, err := dr.getUnverifiedJWS("foo", jws) 897 require.Error(t, err) 898 require.Contains(t, err.Error(), "parsing DID rotation JWS") 899 }) 900 901 t.Run("fail: can't parse payload", func(t *testing.T) { 902 jws, err := jose.NewJWS( 903 jose.Headers{"alg": "blahblah"}, nil, []byte("abcdefg"), &mockSigner{}) 904 require.NoError(t, err) 905 906 sig, err := jws.SerializeCompact(false) 907 require.NoError(t, err) 908 909 dr := createBlankDIDRotator(t) 910 911 _, _, err = dr.getUnverifiedJWS("foo", sig) 912 require.Error(t, err) 913 require.Contains(t, err.Error(), "parsing DID rotation payload") 914 }) 915 916 t.Run("fail: payload missing iss or sub", func(t *testing.T) { 917 jws, err := jose.NewJWS( 918 jose.Headers{"alg": "blahblah"}, nil, []byte("{}"), &mockSigner{}) 919 require.NoError(t, err) 920 921 sig, err := jws.SerializeCompact(false) 922 require.NoError(t, err) 923 924 dr := createBlankDIDRotator(t) 925 926 _, _, err = dr.getUnverifiedJWS("foo", sig) 927 require.Error(t, err) 928 require.Contains(t, err.Error(), "payload missing iss or sub") 929 }) 930 931 t.Run("fail: payload subject mismatch", func(t *testing.T) { 932 jws, err := jose.NewJWS(jose.Headers{"alg": "blahblah"}, nil, 933 []byte(`{"iss":"abc","sub":"def"}`), &mockSigner{}) 934 require.NoError(t, err) 935 936 sig, err := jws.SerializeCompact(false) 937 require.NoError(t, err) 938 939 dr := createBlankDIDRotator(t) 940 941 _, _, err = dr.getUnverifiedJWS("foo", sig) 942 require.Error(t, err) 943 require.Contains(t, err.Error(), "payload sub must be the DID of the message sender") 944 }) 945 } 946 947 func TestDIDRotator_verifyJWSAndPayload(t *testing.T) { 948 rotator := createBlankDIDRotator(t) 949 oldDID := "did:test:rotator" 950 newDID := "did:test:new" 951 doc := createMockDoc(t, rotator, oldDID) 952 setResolveDocs(rotator, []*did.Doc{doc}) 953 954 fromPrior, e := rotator.Create(doc, defaultKID, newDID) 955 require.NoError(t, e) 956 957 t.Run("success", func(t *testing.T) { 958 verifier := createBlankDIDRotator(t) 959 setResolveDocs(verifier, []*did.Doc{doc}) 960 961 jws, payload, err := verifier.getUnverifiedJWS(newDID, fromPrior) 962 require.NoError(t, err) 963 964 err = verifier.verifyJWSAndPayload(jws, payload) 965 require.NoError(t, err) 966 }) 967 968 t.Run("fail: JWS headers missing KID", func(t *testing.T) { 969 verifier := createBlankDIDRotator(t) 970 err := verifier.verifyJWSAndPayload(&jose.JSONWebSignature{ 971 ProtectedHeaders: jose.Headers{}, 972 }, nil) 973 require.Error(t, err) 974 require.Contains(t, err.Error(), "protected headers missing KID") 975 }) 976 977 t.Run("fail: resolving signer DID doc", func(t *testing.T) { 978 verifier := createBlankDIDRotator(t) 979 jws, payload, err := verifier.getUnverifiedJWS(newDID, fromPrior) 980 require.NoError(t, err) 981 982 err = verifier.verifyJWSAndPayload(jws, payload) 983 require.Error(t, err) 984 require.Contains(t, err.Error(), "resolving prior DID doc") 985 }) 986 987 t.Run("fail: selecting verification method from signer DID doc", func(t *testing.T) { 988 verifier := createBlankDIDRotator(t) 989 setResolveDocs(verifier, []*did.Doc{doc}) 990 jws, payload, err := verifier.getUnverifiedJWS(newDID, fromPrior) 991 require.NoError(t, err) 992 993 jws.ProtectedHeaders["kid"] = "AAAAAAAAAA" 994 995 err = verifier.verifyJWSAndPayload(jws, payload) 996 require.Error(t, err) 997 require.Contains(t, err.Error(), "kid not found in doc") 998 }) 999 1000 t.Run("fail: did doc VM has unsupported type", func(t *testing.T) { 1001 verifier := createBlankDIDRotator(t) 1002 setResolveDocs(verifier, []*did.Doc{{ 1003 ID: oldDID, 1004 VerificationMethod: []did.VerificationMethod{ 1005 { 1006 ID: defaultKID, 1007 Type: "oops", 1008 Controller: oldDID, 1009 Value: nil, 1010 }, 1011 }, 1012 }}) 1013 1014 jws, payload, err := verifier.getUnverifiedJWS(newDID, fromPrior) 1015 require.NoError(t, err) 1016 1017 err = verifier.verifyJWSAndPayload(jws, payload) 1018 require.Error(t, err) 1019 require.Contains(t, err.Error(), "vm.Type 'oops' not supported") 1020 }) 1021 1022 t.Run("fail: kms get key handle error", func(t *testing.T) { 1023 verifier := createBlankDIDRotator(t) 1024 setResolveDocs(verifier, []*did.Doc{doc}) 1025 1026 verifier.kms = &mockkms.KeyManager{ 1027 PubKeyBytesToHandleErr: fmt.Errorf("kms error"), 1028 } 1029 1030 jws, payload, err := verifier.getUnverifiedJWS(newDID, fromPrior) 1031 require.NoError(t, err) 1032 1033 err = verifier.verifyJWSAndPayload(jws, payload) 1034 require.Error(t, err) 1035 require.Contains(t, err.Error(), "get verification key handle") 1036 }) 1037 1038 t.Run("fail: signature verification error", func(t *testing.T) { 1039 verifier := createBlankDIDRotator(t) 1040 setResolveDocs(verifier, []*did.Doc{doc}) 1041 1042 verifier.crypto = &mockcrypto.Crypto{ 1043 VerifyErr: fmt.Errorf("verify error"), 1044 } 1045 1046 jws, payload, err := verifier.getUnverifiedJWS(newDID, fromPrior) 1047 require.NoError(t, err) 1048 1049 err = verifier.verifyJWSAndPayload(jws, payload) 1050 require.Error(t, err) 1051 require.Contains(t, err.Error(), "signature verification") 1052 }) 1053 } 1054 1055 type kmsProvider struct { 1056 kmsStore kms.Store 1057 secretLockService secretlock.Service 1058 } 1059 1060 func (k *kmsProvider) StorageProvider() kms.Store { 1061 return k.kmsStore 1062 } 1063 1064 func (k *kmsProvider) SecretLock() secretlock.Service { 1065 return k.secretLockService 1066 } 1067 1068 func createMockProvider(t *testing.T) *mockProvider { 1069 t.Helper() 1070 1071 kmsStore, err := kms.NewAriesProviderWrapper(mockstorage.NewMockStoreProvider()) 1072 require.NoError(t, err) 1073 1074 kmsStorage, err := localkms.New("local-lock://test/master/key/", &kmsProvider{ 1075 kmsStore: kmsStore, 1076 secretLockService: &noop.NoLock{}, 1077 }) 1078 require.NoError(t, err) 1079 1080 cr, err := tinkcrypto.New() 1081 require.NoError(t, err) 1082 1083 vdr := &mockvdr.MockVDRegistry{ 1084 CreateFunc: func(didID string, doc *did.Doc, option ...vdrapi.DIDMethodOption) (*did.DocResolution, error) { 1085 return nil, fmt.Errorf("not created") 1086 }, 1087 ResolveFunc: func(didID string, opts ...vdrapi.DIDMethodOption) (*did.DocResolution, error) { 1088 return nil, fmt.Errorf("not found") 1089 }, 1090 } 1091 1092 didStore, err := didstore.NewConnectionStore(&mockProvider{storeProvider: mockstorage.NewMockStoreProvider()}) 1093 require.NoError(t, err) 1094 1095 return &mockProvider{ 1096 kms: kmsStorage, 1097 crypto: cr, 1098 vdr: vdr, 1099 storeProvider: mockstorage.NewMockStoreProvider(), 1100 didStore: didStore, 1101 } 1102 } 1103 1104 func createBlankDIDRotator(t *testing.T) *DIDCommMessageMiddleware { 1105 t.Helper() 1106 1107 dr, err := New(createMockProvider(t)) 1108 require.NoError(t, err) 1109 1110 return dr 1111 } 1112 1113 func createMockDoc(t *testing.T, dr *DIDCommMessageMiddleware, docDID string) *did.Doc { 1114 t.Helper() 1115 1116 keyType := kms.ECDSAP384TypeIEEEP1363 1117 1118 return makemockdoc.MakeMockDoc(t, dr.kms, docDID, keyType) 1119 } 1120 1121 func setResolveDocs(dr *DIDCommMessageMiddleware, docs []*did.Doc) { 1122 dr.vdr = &mockvdr.MockVDRegistry{ 1123 ResolveFunc: func(didID string, opts ...vdrapi.DIDMethodOption) (*did.DocResolution, error) { 1124 for _, doc := range docs { 1125 if didID == doc.ID { 1126 return &did.DocResolution{DIDDocument: doc}, nil 1127 } 1128 } 1129 1130 return nil, vdrapi.ErrNotFound 1131 }, 1132 } 1133 } 1134 1135 func sendMessage(t *testing.T, sender, recipient *DIDCommMessageMiddleware, senderConnID string) { 1136 t.Helper() 1137 1138 msgTemplate := service.DIDCommMsgMap{ 1139 "id": "12345", 1140 "type": "message", 1141 "body": map[string]interface{}{}, 1142 } 1143 1144 myConnRec, err := sender.connStore.GetConnectionRecord(senderConnID) 1145 require.NoError(t, err) 1146 1147 msg := sender.HandleOutboundMessage(msgTemplate, myConnRec) 1148 1149 err = recipient.HandleInboundMessage(msg, myConnRec.MyDID, myConnRec.TheirDID) 1150 require.NoError(t, err) 1151 } 1152 1153 type mockProvider struct { 1154 kms kms.KeyManager 1155 crypto crypto.Crypto 1156 storeProvider storage.Provider 1157 secretLock secretlock.Service 1158 vdr vdrapi.Registry 1159 mediaTypes []string 1160 didStore didstore.ConnectionStore 1161 } 1162 1163 func (m *mockProvider) DIDConnectionStore() didstore.ConnectionStore { 1164 return m.didStore 1165 } 1166 1167 func (m *mockProvider) MediaTypeProfiles() []string { 1168 return m.mediaTypes 1169 } 1170 1171 func (m *mockProvider) VDRegistry() vdrapi.Registry { 1172 return m.vdr 1173 } 1174 1175 func (m *mockProvider) StorageProvider() storage.Provider { 1176 return m.storeProvider 1177 } 1178 1179 func (m *mockProvider) SecretLock() secretlock.Service { 1180 return m.secretLock 1181 } 1182 1183 func (m *mockProvider) Crypto() crypto.Crypto { 1184 return m.crypto 1185 } 1186 1187 func (m *mockProvider) KMS() kms.KeyManager { 1188 return m.kms 1189 } 1190 1191 func (m *mockProvider) ProtocolStateStorageProvider() storage.Provider { 1192 return m.storeProvider 1193 } 1194 1195 type mockSigner struct{} 1196 1197 // Sign mock sign. 1198 func (m *mockSigner) Sign(data []byte) ([]byte, error) { 1199 return data, nil 1200 } 1201 1202 // Headers returns nil. 1203 func (m *mockSigner) Headers() jose.Headers { 1204 return nil 1205 }