/*
Copyright SecureKey Technologies Inc. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/

package didexchange

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"

	"github.com/google/uuid"

	"github.com/hyperledger/aries-framework-go/pkg/common/log"
	"github.com/hyperledger/aries-framework-go/pkg/common/model"
	"github.com/hyperledger/aries-framework-go/pkg/crypto"
	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
	"github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher"
	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator"
	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/mediator"
	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
	"github.com/hyperledger/aries-framework-go/pkg/doc/did"
	vdrapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr"
	"github.com/hyperledger/aries-framework-go/pkg/internal/logutil"
	"github.com/hyperledger/aries-framework-go/pkg/kms"
	"github.com/hyperledger/aries-framework-go/pkg/store/connection"
	didstore "github.com/hyperledger/aries-framework-go/pkg/store/did"
	"github.com/hyperledger/aries-framework-go/pkg/vdr"
	"github.com/hyperledger/aries-framework-go/spi/storage"
)

var logger = log.New("aries-framework/did-exchange/service")

const (
	// DIDExchange is the protocol name reported in events and logs emitted by this service.
	DIDExchange = "didexchange"
	// PIURI is the did-exchange protocol identifier URI.
	PIURI = "https://didcomm.org/didexchange/1.0"
	// InvitationMsgType defines the did-exchange invite message type.
	InvitationMsgType = PIURI + "/invitation"
	// RequestMsgType defines the did-exchange request message type.
	RequestMsgType = PIURI + "/request"
	// ResponseMsgType defines the did-exchange response message type.
	ResponseMsgType = PIURI + "/response"
	// AckMsgType defines the did-exchange ack message type.
	AckMsgType = PIURI + "/ack"
	// CompleteMsgType defines the did-exchange complete message type.
	CompleteMsgType = PIURI + "/complete"
	// oobMsgType is the internal message type for the oob invitation that the didexchange service receives.
	oobMsgType = "oob-invitation"
	// routerConnsMetadataKey is the DIDCommMsg metadata key under which RespondTo stores router
	// connection IDs; retrievingRouterConnections reads them back.
	routerConnsMetadataKey = "routerConnections"
)

// Namespace prefixes used to build the namespaced thread-ID keys of connection records
// (see findNamespace). canTriggerActionEvents pairs "my" with the invitee role and
// "their" with the inviter role.
const (
	myNSPrefix = "my"
	// TODO: https://github.com/hyperledger/aries-framework-go/issues/556 It will not be constant, this namespace
	// will need to be figured with verification key
	theirNSPrefix = "their"
)

// message carries one inbound DID exchange message through the state machine. When an action
// event is raised it is persisted as JSON (only its exported fields are encoded) and, once the
// client responds through the action's Continue or Stop callback, it is handed back to the
// internal listener through the callback channel.
type message struct {
	Msg           service.DIDCommMsgMap
	ThreadID      string
	Options       *options
	NextStateName string
	ConnRecord    *connection.Record
	// err holds the error the client passed to the action event's Stop callback. A non-nil
	// value makes the internal listener abandon the exchange instead of continuing it.
	err error
}

// provider contains dependencies for the DID exchange protocol and is typically created by using aries.Context().
type provider interface {
	OutboundDispatcher() dispatcher.Outbound
	StorageProvider() storage.Provider
	ProtocolStateStorageProvider() storage.Provider
	DIDConnectionStore() didstore.ConnectionStore
	Crypto() crypto.Crypto
	KMS() kms.KeyManager
	VDRegistry() vdrapi.Registry
	Service(id string) (interface{}, error)
	KeyType() kms.KeyType
	KeyAgreementType() kms.KeyType
	MediaTypeProfiles() []string
}

// stateMachineMsg is an internal struct used to pass data to state machine: the DIDComm message
// plus the connection record and client options that accompany it.
type stateMachineMsg struct {
	service.DIDCommMsg
	connRecord *connection.Record
	options    *options
}

// Service implements the DID exchange protocol identified by PIURI. It embeds service.Action and
// service.Message, which supply the action-event channel (ActionEvent) and the message-event
// handlers (MsgEvents) used by this service.
100 type Service struct { 101 service.Action 102 service.Message 103 ctx *context 104 callbackChannel chan *message 105 connectionRecorder *connection.Recorder 106 connectionStore didstore.ConnectionStore 107 initialized bool 108 } 109 110 type context struct { 111 outboundDispatcher dispatcher.Outbound 112 crypto crypto.Crypto 113 kms kms.KeyManager 114 connectionRecorder *connection.Recorder 115 connectionStore didstore.ConnectionStore 116 vdRegistry vdrapi.Registry 117 routeSvc mediator.ProtocolService 118 doACAPyInterop bool 119 keyType kms.KeyType 120 keyAgreementType kms.KeyType 121 mediaTypeProfiles []string 122 } 123 124 // opts are used to provide client properties to DID Exchange service. 125 type opts interface { 126 // PublicDID allows for setting public DID 127 PublicDID() string 128 129 // Label allows for setting label 130 Label() string 131 132 // RouterConnections allows for setting router connections 133 RouterConnections() []string 134 } 135 136 // New return didexchange service. 137 func New(prov provider) (*Service, error) { 138 svc := Service{} 139 140 err := svc.Initialize(prov) 141 if err != nil { 142 return nil, err 143 } 144 145 return &svc, nil 146 } 147 148 // Initialize initializes the Service. If Initialize succeeds, any further call is a no-op. 
149 func (s *Service) Initialize(p interface{}) error { // nolint: funlen 150 if s.initialized { 151 return nil 152 } 153 154 prov, ok := p.(provider) 155 if !ok { 156 return fmt.Errorf("expected provider of type `%T`, got type `%T`", provider(nil), p) 157 } 158 159 connRecorder, err := connection.NewRecorder(prov) 160 if err != nil { 161 return fmt.Errorf("failed to initialize connection recorder: %w", err) 162 } 163 164 routeSvcBase, err := prov.Service(mediator.Coordination) 165 if err != nil { 166 return err 167 } 168 169 routeSvc, ok := routeSvcBase.(mediator.ProtocolService) 170 if !ok { 171 return errors.New("cast service to Route Service failed") 172 } 173 174 const callbackChannelSize = 10 175 176 keyType := prov.KeyType() 177 if keyType == "" { 178 keyType = kms.ED25519Type 179 } 180 181 keyAgreementType := prov.KeyAgreementType() 182 if keyAgreementType == "" { 183 keyAgreementType = kms.X25519ECDHKWType 184 } 185 186 mediaTypeProfiles := prov.MediaTypeProfiles() 187 if len(mediaTypeProfiles) == 0 { 188 mediaTypeProfiles = []string{transport.MediaTypeAIP2RFC0019Profile} 189 } 190 191 s.ctx = &context{ 192 outboundDispatcher: prov.OutboundDispatcher(), 193 crypto: prov.Crypto(), 194 kms: prov.KMS(), 195 vdRegistry: prov.VDRegistry(), 196 connectionRecorder: connRecorder, 197 connectionStore: prov.DIDConnectionStore(), 198 routeSvc: routeSvc, 199 doACAPyInterop: doACAPyInterop, 200 keyType: keyType, 201 keyAgreementType: keyAgreementType, 202 mediaTypeProfiles: mediaTypeProfiles, 203 } 204 205 // TODO channel size - https://github.com/hyperledger/aries-framework-go/issues/246 206 s.callbackChannel = make(chan *message, callbackChannelSize) 207 s.connectionRecorder = connRecorder 208 s.connectionStore = prov.DIDConnectionStore() 209 210 // start the listener 211 go s.startInternalListener() 212 213 s.initialized = true 214 215 return nil 216 } 217 218 func retrievingRouterConnections(msg service.DIDCommMsg) []string { 219 raw, found := 
msg.Metadata()[routerConnsMetadataKey] 220 if !found { 221 return nil 222 } 223 224 connections, ok := raw.([]string) 225 if !ok { 226 return nil 227 } 228 229 return connections 230 } 231 232 // HandleInbound handles inbound didexchange messages. 233 func (s *Service) HandleInbound(msg service.DIDCommMsg, ctx service.DIDCommContext) (string, error) { 234 logger.Debugf("receive inbound message : %s", msg) 235 236 // fetch the thread id 237 thID, err := msg.ThreadID() 238 if err != nil { 239 return "", err 240 } 241 242 // valid state transition and get the next state 243 next, err := s.nextState(msg.Type(), thID) 244 if err != nil { 245 return "", fmt.Errorf("handle inbound - next state : %w", err) 246 } 247 248 // connection record 249 connRecord, err := s.connectionRecord(msg) 250 if err != nil { 251 return "", fmt.Errorf("failed to fetch connection record : %w", err) 252 } 253 254 logger.Debugf("connection record: %+v", connRecord) 255 256 internalMsg := &message{ 257 Options: &options{routerConnections: retrievingRouterConnections(msg)}, 258 Msg: msg.Clone(), 259 ThreadID: thID, 260 NextStateName: next.Name(), 261 ConnRecord: connRecord, 262 } 263 264 go func(msg *message, aEvent chan<- service.DIDCommAction) { 265 if err = s.handle(msg, aEvent); err != nil { 266 logutil.LogError(logger, DIDExchange, "processMessage", err.Error(), 267 logutil.CreateKeyValueString("msgType", msg.Msg.Type()), 268 logutil.CreateKeyValueString("msgID", msg.Msg.ID()), 269 logutil.CreateKeyValueString("connectionID", msg.ConnRecord.ConnectionID)) 270 } 271 272 logutil.LogDebug(logger, DIDExchange, "processMessage", "success", 273 logutil.CreateKeyValueString("msgType", msg.Msg.Type()), 274 logutil.CreateKeyValueString("msgID", msg.Msg.ID()), 275 logutil.CreateKeyValueString("connectionID", msg.ConnRecord.ConnectionID)) 276 }(internalMsg, s.ActionEvent()) 277 278 logutil.LogDebug(logger, DIDExchange, "handleInbound", "success", 279 logutil.CreateKeyValueString("msgType", msg.Type()), 
280 logutil.CreateKeyValueString("msgID", msg.ID()), 281 logutil.CreateKeyValueString("connectionID", internalMsg.ConnRecord.ConnectionID)) 282 283 return connRecord.ConnectionID, nil 284 } 285 286 // Name return service name. 287 func (s *Service) Name() string { 288 return DIDExchange 289 } 290 291 func findNamespace(msgType string) string { 292 namespace := theirNSPrefix 293 if msgType == InvitationMsgType || msgType == ResponseMsgType || msgType == oobMsgType { 294 namespace = myNSPrefix 295 } 296 297 return namespace 298 } 299 300 // Accept msg checks the msg type. 301 func (s *Service) Accept(msgType string) bool { 302 return msgType == InvitationMsgType || 303 msgType == RequestMsgType || 304 msgType == ResponseMsgType || 305 msgType == AckMsgType || 306 msgType == CompleteMsgType 307 } 308 309 // HandleOutbound handles outbound didexchange messages. 310 func (s *Service) HandleOutbound(_ service.DIDCommMsg, _, _ string) (string, error) { 311 return "", errors.New("not implemented") 312 } 313 314 func (s *Service) nextState(msgType, thID string) (state, error) { 315 logger.Debugf("msgType=%s thID=%s", msgType, thID) 316 317 nsThID, err := connection.CreateNamespaceKey(findNamespace(msgType), thID) 318 if err != nil { 319 return nil, err 320 } 321 322 current, err := s.currentState(nsThID) 323 if err != nil { 324 return nil, err 325 } 326 327 logger.Debugf("retrieved current state [%s] using nsThID [%s]", current.Name(), nsThID) 328 329 next, err := stateFromMsgType(msgType) 330 if err != nil { 331 return nil, err 332 } 333 334 logger.Debugf("check if current state [%s] can transition to [%s]", current.Name(), next.Name()) 335 336 if !current.CanTransitionTo(next) { 337 return nil, fmt.Errorf("invalid state transition: %s -> %s", current.Name(), next.Name()) 338 } 339 340 return next, nil 341 } 342 343 func (s *Service) handle(msg *message, aEvent chan<- service.DIDCommAction) error { //nolint:funlen,gocyclo 344 logger.Debugf("handling msg: %+v", msg) 345 346 
next, err := stateFromName(msg.NextStateName) 347 if err != nil { 348 return fmt.Errorf("invalid state name: %w", err) 349 } 350 351 for !isNoOp(next) { 352 s.sendMsgEvents(&service.StateMsg{ 353 ProtocolName: DIDExchange, 354 Type: service.PreState, 355 Msg: msg.Msg.Clone(), 356 StateID: next.Name(), 357 Properties: createEventProperties(msg.ConnRecord.ConnectionID, msg.ConnRecord.InvitationID), 358 }) 359 logger.Debugf("sent pre event for state %s", next.Name()) 360 361 var ( 362 action stateAction 363 followup state 364 connectionRecord *connection.Record 365 ) 366 367 connectionRecord, followup, action, err = next.ExecuteInbound( 368 &stateMachineMsg{ 369 DIDCommMsg: msg.Msg, 370 connRecord: msg.ConnRecord, 371 options: msg.Options, 372 }, 373 msg.ThreadID, 374 s.ctx) 375 376 if err != nil { 377 return fmt.Errorf("failed to execute state '%s': %w", next.Name(), err) 378 } 379 380 connectionRecord.State = next.Name() 381 logger.Debugf("finished execute state: %s", next.Name()) 382 383 if err = s.update(msg.Msg.Type(), connectionRecord); err != nil { 384 return fmt.Errorf("failed to persist state '%s': %w", next.Name(), err) 385 } 386 387 if connectionRecord.State == StateIDCompleted { 388 err = s.connectionStore.SaveDIDByResolving(connectionRecord.TheirDID, connectionRecord.RecipientKeys...) 
389 if err != nil { 390 return fmt.Errorf("save theirDID: %w", err) 391 } 392 } 393 394 if err = action(); err != nil { 395 return fmt.Errorf("failed to execute state action '%s': %w", next.Name(), err) 396 } 397 398 logger.Debugf("finish execute state action: '%s'", next.Name()) 399 400 prev := next 401 next = followup 402 haltExecution := false 403 404 // trigger action event based on message type for inbound messages 405 if msg.Msg.Type() != oobMsgType && canTriggerActionEvents(connectionRecord.State, connectionRecord.Namespace) { 406 logger.Debugf("action event triggered for msg type: %s", msg.Msg.Type()) 407 408 msg.NextStateName = next.Name() 409 if err = s.sendActionEvent(msg, aEvent); err != nil { 410 return fmt.Errorf("handle inbound: %w", err) 411 } 412 413 haltExecution = true 414 } 415 416 s.sendMsgEvents(&service.StateMsg{ 417 ProtocolName: DIDExchange, 418 Type: service.PostState, 419 Msg: msg.Msg.Clone(), 420 StateID: prev.Name(), 421 Properties: createEventProperties(connectionRecord.ConnectionID, connectionRecord.InvitationID), 422 }) 423 logger.Debugf("sent post event for state %s", prev.Name()) 424 425 if haltExecution { 426 logger.Debugf("halted execution before state=%s", msg.NextStateName) 427 428 break 429 } 430 } 431 432 return nil 433 } 434 435 func (s *Service) handleWithoutAction(msg *message) error { 436 return s.handle(msg, nil) 437 } 438 439 func createEventProperties(connectionID, invitationID string) *didExchangeEvent { 440 return &didExchangeEvent{ 441 connectionID: connectionID, 442 invitationID: invitationID, 443 } 444 } 445 446 func createErrorEventProperties(connectionID, invitationID string, err error) *didExchangeEventError { 447 props := createEventProperties(connectionID, invitationID) 448 449 return &didExchangeEventError{ 450 err: err, 451 didExchangeEvent: *props, 452 } 453 } 454 455 // sendActionEvent triggers the action event. 
This function stores the state of current processing and passes a callback 456 // function in the event message. 457 func (s *Service) sendActionEvent(internalMsg *message, aEvent chan<- service.DIDCommAction) error { 458 // save data to support AcceptExchangeRequest APIs (when client will not be able to invoke the callback function) 459 err := s.storeEventProtocolStateData(internalMsg) 460 if err != nil { 461 return fmt.Errorf("send action event : %w", err) 462 } 463 464 if aEvent != nil { 465 // trigger action event 466 aEvent <- service.DIDCommAction{ 467 ProtocolName: DIDExchange, 468 Message: internalMsg.Msg.Clone(), 469 Continue: func(args interface{}) { 470 switch v := args.(type) { 471 case opts: 472 internalMsg.Options = &options{ 473 publicDID: v.PublicDID(), 474 label: v.Label(), 475 routerConnections: v.RouterConnections(), 476 } 477 default: 478 // nothing to do 479 } 480 481 s.processCallback(internalMsg) 482 }, 483 Stop: func(err error) { 484 // sets an error to the message 485 internalMsg.err = err 486 s.processCallback(internalMsg) 487 }, 488 Properties: createEventProperties(internalMsg.ConnRecord.ConnectionID, internalMsg.ConnRecord.InvitationID), 489 } 490 491 logger.Debugf("dispatched action for msg: %+v", internalMsg.Msg) 492 } 493 494 return nil 495 } 496 497 // sendEvent triggers the message events. 498 func (s *Service) sendMsgEvents(msg *service.StateMsg) { 499 // trigger the message events 500 for _, handler := range s.MsgEvents() { 501 handler <- *msg 502 503 logger.Debugf("sent msg event to handler: %+v", msg) 504 } 505 } 506 507 // startInternalListener listens to messages in gochannel for callback messages from clients. 
508 func (s *Service) startInternalListener() { 509 for msg := range s.callbackChannel { 510 // TODO https://github.com/hyperledger/aries-framework-go/issues/242 - retry logic 511 // if no error - do handle 512 if msg.err == nil { 513 msg.err = s.handleWithoutAction(msg) 514 } 515 516 // no error - continue 517 if msg.err == nil { 518 continue 519 } 520 521 if err := s.abandon(msg.ThreadID, msg.Msg, msg.err); err != nil { 522 logger.Errorf("process callback : %s", err) 523 } 524 } 525 } 526 527 // AcceptInvitation accepts/approves connection invitation. 528 func (s *Service) AcceptInvitation(connectionID, publicDID, label string, routerConnections []string) error { 529 return s.accept(connectionID, publicDID, label, StateIDInvited, 530 "accept exchange invitation", routerConnections) 531 } 532 533 // AcceptExchangeRequest accepts/approves connection request. 534 func (s *Service) AcceptExchangeRequest(connectionID, publicDID, label string, routerConnections []string) error { 535 return s.accept(connectionID, publicDID, label, StateIDRequested, 536 "accept exchange request", routerConnections) 537 } 538 539 // RespondTo this inbound invitation and return with the new connection record's ID. 540 func (s *Service) RespondTo(i *OOBInvitation, routerConnections []string) (string, error) { 541 i.Type = oobMsgType 542 543 msg := service.NewDIDCommMsgMap(i) 544 msg.Metadata()[routerConnsMetadataKey] = routerConnections 545 546 return s.HandleInbound(msg, service.EmptyDIDCommContext()) 547 } 548 549 // SaveInvitation saves this invitation created by you. 
550 func (s *Service) SaveInvitation(i *OOBInvitation) error { 551 i.Type = oobMsgType 552 553 err := s.connectionRecorder.SaveInvitation(i.ThreadID, i) 554 if err != nil { 555 return fmt.Errorf("failed to save oob invitation : %w", err) 556 } 557 558 logger.Debugf("saved invitation: %+v", i) 559 560 return nil 561 } 562 563 func (s *Service) accept(connectionID, publicDID, label, stateID, errMsg string, routerConnections []string) error { 564 msg, err := s.getEventProtocolStateData(connectionID) 565 if err != nil { 566 return fmt.Errorf("failed to accept invitation for connectionID=%s : %s : %w", connectionID, errMsg, err) 567 } 568 569 connRecord, err := s.connectionRecorder.GetConnectionRecord(connectionID) 570 if err != nil { 571 return fmt.Errorf("%s : %w", errMsg, err) 572 } 573 574 if connRecord.State != stateID { 575 return fmt.Errorf("current state (%s) is different from "+ 576 "expected state (%s)", connRecord.State, stateID) 577 } 578 579 msg.Options = &options{publicDID: publicDID, label: label, routerConnections: routerConnections} 580 581 return s.handleWithoutAction(msg) 582 } 583 584 func (s *Service) storeEventProtocolStateData(msg *message) error { 585 bytes, err := json.Marshal(msg) 586 if err != nil { 587 return fmt.Errorf("store protocol state data : %w", err) 588 } 589 590 return s.connectionRecorder.SaveEvent(msg.ConnRecord.ConnectionID, bytes) 591 } 592 593 func (s *Service) getEventProtocolStateData(connectionID string) (*message, error) { 594 val, err := s.connectionRecorder.GetEvent(connectionID) 595 if err != nil { 596 return nil, fmt.Errorf("get protocol state data : %w", err) 597 } 598 599 msg := &message{} 600 601 err = json.Unmarshal(val, msg) 602 if err != nil { 603 return nil, fmt.Errorf("get protocol state data : %w", err) 604 } 605 606 return msg, nil 607 } 608 609 // abandon updates the state to abandoned and trigger failure event. 
610 func (s *Service) abandon(thID string, msg service.DIDCommMsg, processErr error) error { 611 // update the state to abandoned 612 nsThID, err := connection.CreateNamespaceKey(findNamespace(msg.Type()), thID) 613 if err != nil { 614 return err 615 } 616 617 connRec, err := s.connectionRecorder.GetConnectionRecordByNSThreadID(nsThID) 618 if err != nil { 619 return fmt.Errorf("unable to update the state to abandoned: %w", err) 620 } 621 622 connRec.State = (&abandoned{}).Name() 623 624 err = s.update(msg.Type(), connRec) 625 if err != nil { 626 return fmt.Errorf("unable to update the state to abandoned: %w", err) 627 } 628 629 // send the message event 630 s.sendMsgEvents(&service.StateMsg{ 631 ProtocolName: DIDExchange, 632 Type: service.PostState, 633 Msg: msg, 634 StateID: StateIDAbandoned, 635 Properties: createErrorEventProperties(connRec.ConnectionID, "", processErr), 636 }) 637 638 return nil 639 } 640 641 func (s *Service) processCallback(msg *message) { 642 // pass the callback data to internal channel. This is created to unblock consumer go routine and wrap the callback 643 // channel internally. 
644 s.callbackChannel <- msg 645 } 646 647 func isNoOp(s state) bool { 648 _, ok := s.(*noOp) 649 return ok 650 } 651 652 func (s *Service) currentState(nsThID string) (state, error) { 653 connRec, err := s.connectionRecorder.GetConnectionRecordByNSThreadID(nsThID) 654 if err != nil { 655 if errors.Is(err, storage.ErrDataNotFound) { 656 return &null{}, nil 657 } 658 659 return nil, fmt.Errorf("cannot fetch state from store: thID=%s err=%w", nsThID, err) 660 } 661 662 return stateFromName(connRec.State) 663 } 664 665 func (s *Service) update(msgType string, record *connection.Record) error { 666 if (msgType == RequestMsgType && record.State == StateIDRequested) || 667 (msgType == InvitationMsgType && record.State == StateIDInvited) || 668 (msgType == oobMsgType && record.State == StateIDInvited) { 669 return s.connectionRecorder.SaveConnectionRecordWithMappings(record) 670 } 671 672 return s.connectionRecorder.SaveConnectionRecord(record) 673 } 674 675 // CreateConnection saves the record to the connection store and maps TheirDID to their recipient keys in 676 // the did connection store. 
677 func (s *Service) CreateConnection(record *connection.Record, theirDID *did.Doc) error { 678 logger.Debugf("creating connection using record [%+v] and theirDID [%+v]", record, theirDID) 679 680 didMethod, err := vdr.GetDidMethod(theirDID.ID) 681 if err != nil { 682 return err 683 } 684 685 _, err = s.ctx.vdRegistry.Create(didMethod, theirDID, vdrapi.WithOption("store", true)) 686 if err != nil { 687 return fmt.Errorf("vdr failed to store theirDID : %w", err) 688 } 689 690 err = s.connectionStore.SaveDIDFromDoc(theirDID) 691 if err != nil { 692 return fmt.Errorf("failed to save theirDID to the did.ConnectionStore: %w", err) 693 } 694 695 err = s.connectionStore.SaveDIDByResolving(record.MyDID) 696 if err != nil { 697 return fmt.Errorf("failed to save myDID to the did.ConnectionStore: %w", err) 698 } 699 700 if isDIDCommV2(record.MediaTypeProfiles) { 701 record.DIDCommVersion = service.V2 702 } else { 703 record.DIDCommVersion = service.V1 704 } 705 706 return s.connectionRecorder.SaveConnectionRecord(record) 707 } 708 709 func (s *Service) connectionRecord(msg service.DIDCommMsg) (*connection.Record, error) { 710 switch msg.Type() { 711 case oobMsgType: 712 return s.oobInvitationMsgRecord(msg) 713 case InvitationMsgType: 714 return s.invitationMsgRecord(msg) 715 case RequestMsgType: 716 return s.requestMsgRecord(msg) 717 case ResponseMsgType: 718 return s.responseMsgRecord(msg) 719 case AckMsgType, CompleteMsgType: 720 return s.fetchConnectionRecord(theirNSPrefix, msg) 721 } 722 723 return nil, errors.New("invalid message type") 724 } 725 726 //nolint:funlen 727 func (s *Service) oobInvitationMsgRecord(msg service.DIDCommMsg) (*connection.Record, error) { 728 thID, err := msg.ThreadID() 729 if err != nil { 730 return nil, fmt.Errorf("failed to read the oobinvitation threadID : %w", err) 731 } 732 733 var oobInvitation OOBInvitation 734 735 err = msg.Decode(&oobInvitation) 736 if err != nil { 737 return nil, fmt.Errorf("failed to decode the oob invitation : %w", 
err) 738 } 739 740 svc, err := s.ctx.getServiceBlock(&oobInvitation) 741 if err != nil { 742 return nil, fmt.Errorf("failed to get the did service block from oob invitation : %w", err) 743 } 744 745 uri, err := svc.ServiceEndpoint.URI() 746 if err != nil { 747 logger.Debugf("service DIDComm V1 without ServiceEndpoint URI: %w, skipping it", err) 748 } 749 750 var connRecord *connection.Record 751 752 if accept, err := svc.ServiceEndpoint.Accept(); err == nil && isDIDCommV2(accept) { 753 connRecord = &connection.Record{ 754 ConnectionID: generateRandomID(), 755 ThreadID: thID, 756 ParentThreadID: oobInvitation.ThreadID, 757 State: stateNameNull, 758 InvitationID: oobInvitation.ID, 759 ServiceEndPoint: svc.ServiceEndpoint, 760 RecipientKeys: svc.RecipientKeys, // TODO: recipient keys should be 'theirs' not 'mine'. 761 TheirLabel: oobInvitation.TheirLabel, 762 Namespace: findNamespace(msg.Type()), 763 DIDCommVersion: service.V2, 764 } 765 } else { 766 connRecord = &connection.Record{ 767 ConnectionID: generateRandomID(), 768 ThreadID: thID, 769 ParentThreadID: oobInvitation.ThreadID, 770 State: stateNameNull, 771 InvitationID: oobInvitation.ID, 772 ServiceEndPoint: model.NewDIDCommV1Endpoint(uri), 773 RecipientKeys: svc.RecipientKeys, // TODO: recipient keys should be 'theirs' not 'mine'. 
774 TheirLabel: oobInvitation.TheirLabel, 775 Namespace: findNamespace(msg.Type()), 776 MediaTypeProfiles: svc.Accept, 777 DIDCommVersion: service.V1, 778 } 779 } 780 781 publicDID, ok := oobInvitation.Target.(string) 782 if ok { 783 connRecord.Implicit = true 784 connRecord.InvitationDID = publicDID 785 } 786 787 if err := s.connectionRecorder.SaveConnectionRecord(connRecord); err != nil { 788 return nil, err 789 } 790 791 return connRecord, nil 792 } 793 794 func (s *Service) invitationMsgRecord(msg service.DIDCommMsg) (*connection.Record, error) { 795 thID, msgErr := msg.ThreadID() 796 if msgErr != nil { 797 return nil, msgErr 798 } 799 800 invitation := &Invitation{} 801 802 err := msg.Decode(invitation) 803 if err != nil { 804 return nil, err 805 } 806 807 recKey, err := s.ctx.getInvitationRecipientKey(invitation) 808 if err != nil { 809 return nil, err 810 } 811 812 connRecord := &connection.Record{ 813 ConnectionID: generateRandomID(), 814 ThreadID: thID, 815 State: stateNameNull, 816 InvitationID: invitation.ID, 817 InvitationDID: invitation.DID, 818 ServiceEndPoint: model.NewDIDCommV1Endpoint(invitation.ServiceEndpoint), 819 RecipientKeys: []string{recKey}, 820 TheirLabel: invitation.Label, 821 Namespace: findNamespace(msg.Type()), 822 DIDCommVersion: service.V1, 823 } 824 825 if err := s.connectionRecorder.SaveConnectionRecord(connRecord); err != nil { 826 return nil, err 827 } 828 829 return connRecord, nil 830 } 831 832 // nolint:gomnd 833 func pad(b64 string) string { 834 mod := len(b64) % 4 835 if mod <= 1 { 836 return b64 837 } 838 839 return b64 + strings.Repeat("=", 4-mod) 840 } 841 842 func getRequestConnection(r *Request) (*Connection, error) { 843 if r.DocAttach == nil { 844 return nil, fmt.Errorf("missing did_doc~attach from request") 845 } 846 847 docData, err := r.DocAttach.Data.Fetch() 848 if err != nil { 849 return nil, fmt.Errorf("failed to parse base64 attachment data: %w", err) 850 } 851 852 doc, err := did.ParseDocument(docData) 853 if 
err != nil { 854 logger.Errorf("doc bytes: '%s'", string(docData)) 855 return nil, fmt.Errorf("failed to parse did document: %w", err) 856 } 857 858 return &Connection{ 859 DID: r.DID, 860 DIDDoc: doc, 861 }, nil 862 } 863 864 func (s *Service) requestMsgRecord(msg service.DIDCommMsg) (*connection.Record, error) { 865 request := Request{} 866 867 err := msg.Decode(&request) 868 if err != nil { 869 return nil, fmt.Errorf("unmarshalling failed: %w", err) 870 } 871 872 invitationID := msg.ParentThreadID() 873 if invitationID == "" { 874 return nil, fmt.Errorf("missing parent thread ID on didexchange request with @id=%s", request.ID) 875 } 876 877 connRecord := &connection.Record{ 878 TheirLabel: request.Label, 879 ConnectionID: generateRandomID(), 880 ThreadID: request.ID, 881 State: stateNameNull, 882 InvitationID: invitationID, 883 Namespace: theirNSPrefix, 884 DIDCommVersion: service.V1, 885 } 886 887 connRecord.TheirDID = request.DID 888 889 // ACA-Py Interop: https://github.com/hyperledger/aries-cloudagent-python/issues/1048 890 if !strings.HasPrefix(connRecord.TheirDID, "did") { 891 connRecord.TheirDID = "did:peer:" + connRecord.TheirDID 892 } 893 894 if err := s.connectionRecorder.SaveConnectionRecord(connRecord); err != nil { 895 return nil, err 896 } 897 898 return connRecord, nil 899 } 900 901 func (s *Service) responseMsgRecord(payload service.DIDCommMsg) (*connection.Record, error) { 902 return s.fetchConnectionRecord(myNSPrefix, payload) 903 } 904 905 func (s *Service) fetchConnectionRecord(nsPrefix string, payload service.DIDCommMsg) (*connection.Record, error) { 906 msg := &struct { 907 Thread decorator.Thread `json:"~thread,omitempty"` 908 }{} 909 910 err := payload.Decode(msg) 911 if err != nil { 912 return nil, err 913 } 914 915 key, err := connection.CreateNamespaceKey(nsPrefix, msg.Thread.ID) 916 if err != nil { 917 return nil, err 918 } 919 920 return s.connectionRecorder.GetConnectionRecordByNSThreadID(key) 921 } 922 923 func generateRandomID() 
string { 924 return uuid.New().String() 925 } 926 927 // canTriggerActionEvents true based on role and state. 928 // 1. Role is invitee and state is invited. 929 // 2. Role is inviter and state is requested. 930 func canTriggerActionEvents(stateID, ns string) bool { 931 return (stateID == StateIDInvited && ns == myNSPrefix) || (stateID == StateIDRequested && ns == theirNSPrefix) 932 } 933 934 type options struct { 935 publicDID string 936 routerConnections []string 937 label string 938 } 939 940 // CreateImplicitInvitation creates implicit invitation. Inviter DID is required, invitee DID is optional. 941 // If invitee DID is not provided new peer DID will be created for implicit invitation exchange request. 942 //nolint:funlen 943 func (s *Service) CreateImplicitInvitation(inviterLabel, inviterDID, 944 inviteeLabel, inviteeDID string, routerConnections []string) (string, error) { 945 logger.Debugf("implicit invitation requested inviterDID[%s] inviteeDID[%s]", inviterDID, inviteeDID) 946 947 docResolution, err := s.ctx.vdRegistry.Resolve(inviterDID) 948 if err != nil { 949 return "", fmt.Errorf("resolve public did[%s]: %w", inviterDID, err) 950 } 951 952 dest, err := service.CreateDestination(docResolution.DIDDocument) 953 if err != nil { 954 return "", err 955 } 956 957 thID := generateRandomID() 958 959 var connRecord *connection.Record 960 961 if accept, e := dest.ServiceEndpoint.Accept(); e == nil && isDIDCommV2(accept) { 962 connRecord = &connection.Record{ 963 ConnectionID: generateRandomID(), 964 ThreadID: thID, 965 State: stateNameNull, 966 InvitationDID: inviterDID, 967 Implicit: true, 968 ServiceEndPoint: dest.ServiceEndpoint, 969 RecipientKeys: dest.RecipientKeys, 970 TheirLabel: inviterLabel, 971 Namespace: findNamespace(InvitationMsgType), 972 } 973 } else { 974 connRecord = &connection.Record{ 975 ConnectionID: generateRandomID(), 976 ThreadID: thID, 977 State: stateNameNull, 978 InvitationDID: inviterDID, 979 Implicit: true, 980 ServiceEndPoint: 
dest.ServiceEndpoint, 981 RecipientKeys: dest.RecipientKeys, 982 RoutingKeys: dest.RoutingKeys, 983 MediaTypeProfiles: dest.MediaTypeProfiles, 984 TheirLabel: inviterLabel, 985 Namespace: findNamespace(InvitationMsgType), 986 } 987 } 988 989 if e := s.connectionRecorder.SaveConnectionRecordWithMappings(connRecord); e != nil { 990 return "", fmt.Errorf("failed to save new connection record for implicit invitation: %w", e) 991 } 992 993 invitation := &Invitation{ 994 ID: uuid.New().String(), 995 Label: inviterLabel, 996 DID: inviterDID, 997 Type: InvitationMsgType, 998 } 999 1000 msg, err := createDIDCommMsg(invitation) 1001 if err != nil { 1002 return "", fmt.Errorf("failed to create DIDCommMsg for implicit invitation: %w", err) 1003 } 1004 1005 next := &requested{} 1006 internalMsg := &message{ 1007 Msg: msg.Clone(), 1008 ThreadID: thID, 1009 NextStateName: next.Name(), 1010 ConnRecord: connRecord, 1011 } 1012 internalMsg.Options = &options{publicDID: inviteeDID, label: inviteeLabel, routerConnections: routerConnections} 1013 1014 go func(msg *message, aEvent chan<- service.DIDCommAction) { 1015 if err = s.handle(msg, aEvent); err != nil { 1016 logger.Errorf("error from handle for implicit invitation: %s", err) 1017 } 1018 }(internalMsg, s.ActionEvent()) 1019 1020 return connRecord.ConnectionID, nil 1021 } 1022 1023 func createDIDCommMsg(invitation *Invitation) (service.DIDCommMsg, error) { 1024 payload, err := json.Marshal(invitation) 1025 if err != nil { 1026 return nil, fmt.Errorf("marshal invitation: %w", err) 1027 } 1028 1029 return service.ParseDIDCommMsgMap(payload) 1030 }