github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/protocol/didexchange/service.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package didexchange
     8  
     9  import (
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  	"strings"
    14  
    15  	"github.com/google/uuid"
    16  
    17  	"github.com/hyperledger/aries-framework-go/pkg/common/log"
    18  	"github.com/hyperledger/aries-framework-go/pkg/common/model"
    19  	"github.com/hyperledger/aries-framework-go/pkg/crypto"
    20  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    21  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher"
    22  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator"
    23  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/mediator"
    24  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
    25  	"github.com/hyperledger/aries-framework-go/pkg/doc/did"
    26  	vdrapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr"
    27  	"github.com/hyperledger/aries-framework-go/pkg/internal/logutil"
    28  	"github.com/hyperledger/aries-framework-go/pkg/kms"
    29  	"github.com/hyperledger/aries-framework-go/pkg/store/connection"
    30  	didstore "github.com/hyperledger/aries-framework-go/pkg/store/did"
    31  	"github.com/hyperledger/aries-framework-go/pkg/vdr"
    32  	"github.com/hyperledger/aries-framework-go/spi/storage"
    33  )
    34  
// logger is the package-level logger of the DID exchange service.
var logger = log.New("aries-framework/did-exchange/service")

const (
	// DIDExchange did exchange protocol.
	DIDExchange = "didexchange"
	// PIURI is the did-exchange protocol identifier URI.
	PIURI = "https://didcomm.org/didexchange/1.0"
	// InvitationMsgType defines the did-exchange invite message type.
	InvitationMsgType = PIURI + "/invitation"
	// RequestMsgType defines the did-exchange request message type.
	RequestMsgType = PIURI + "/request"
	// ResponseMsgType defines the did-exchange response message type.
	ResponseMsgType = PIURI + "/response"
	// AckMsgType defines the did-exchange ack message type.
	AckMsgType = PIURI + "/ack"
	// CompleteMsgType defines the did-exchange complete message type.
	CompleteMsgType = PIURI + "/complete"
	// oobMsgType is the internal message type for the oob invitation that the didexchange service receives.
	// RespondTo and SaveInvitation assign it to OOB invitations; Accept does not report it as a handled type.
	oobMsgType = "oob-invitation"
	// routerConnsMetadataKey is the DIDComm message metadata key under which RespondTo stores the router
	// connection IDs; retrievingRouterConnections reads them back.
	routerConnsMetadataKey = "routerConnections"
)

const (
	// myNSPrefix is the connection-record namespace that findNamespace selects for invitation, response and
	// OOB-invitation messages.
	myNSPrefix = "my"
	// theirNSPrefix is the namespace findNamespace selects for every other message type (e.g. request, ack,
	// complete).
	// TODO: https://github.com/hyperledger/aries-framework-go/issues/556 It will not be constant, this namespace
	//  will need to be figured with verification key
	theirNSPrefix = "their"
)
    63  
// message type to store data for eventing. This is retrieved during callback.
//
// message is the internal envelope that carries an inbound DIDComm message through the state machine:
// the message itself (Msg), its thread (ThreadID), client options (Options), the name of the state to run
// next (NextStateName) and the associated connection record (ConnRecord).
// storeEventProtocolStateData JSON-serializes it when an action event is raised.
// getEventProtocolStateData restores it in accept, so AcceptInvitation/AcceptExchangeRequest can resume
// processing later. encoding/json skips unexported fields, so err is never persisted.
type message struct {
	Msg           service.DIDCommMsgMap
	ThreadID      string
	Options       *options
	NextStateName string
	ConnRecord    *connection.Record
	// err is used to determine whether callback was stopped
	// e.g the user received an action event and executes Stop(err) function
	// in that case `err` is equal to `err` which was passing to Stop function.
	// startInternalListener also stores here the error returned by handleWithoutAction; a non-nil err makes
	// it abandon the exchange.
	err error
}
    76  
// provider contains dependencies for the DID exchange protocol and is typically created by using aries.Context().
//
// Initialize type-asserts its argument to provider. There, an empty KeyType or KeyAgreementType and an
// empty MediaTypeProfiles slice are replaced with package defaults.
type provider interface {
	// OutboundDispatcher is stored in the state-machine context (context.outboundDispatcher).
	OutboundDispatcher() dispatcher.Outbound
	// StorageProvider and ProtocolStateStorageProvider are presumably consumed by connection.NewRecorder,
	// which receives the whole provider in Initialize — TODO confirm.
	StorageProvider() storage.Provider
	ProtocolStateStorageProvider() storage.Provider
	// DIDConnectionStore backs both Service.connectionStore and context.connectionStore.
	DIDConnectionStore() didstore.ConnectionStore
	// Crypto, KMS and VDRegistry are stored in the state-machine context.
	Crypto() crypto.Crypto
	KMS() kms.KeyManager
	VDRegistry() vdrapi.Registry
	// Service is called by Initialize with mediator.Coordination; the result must be a
	// mediator.ProtocolService.
	Service(id string) (interface{}, error)
	// KeyType defaults to kms.ED25519Type when empty.
	KeyType() kms.KeyType
	// KeyAgreementType defaults to kms.X25519ECDHKWType when empty.
	KeyAgreementType() kms.KeyType
	// MediaTypeProfiles defaults to []string{transport.MediaTypeAIP2RFC0019Profile} when empty.
	MediaTypeProfiles() []string
}
    91  
// stateMachineMsg is an internal struct used to pass data to state machine.
//
// handle builds one for every executed state and passes it to that state's ExecuteInbound. It bundles the
// inbound DIDComm message with its current connection record and the client-supplied options.
type stateMachineMsg struct {
	service.DIDCommMsg
	connRecord *connection.Record
	options    *options
}
    98  
// Service for DID exchange protocol.
//
// The embedded service.Action and service.Message provide the action-event channel (ActionEvent) and the
// message-event handlers (MsgEvents) used while processing messages. The other fields are set by
// Initialize (called by New):
//   - ctx holds the dependencies handed to each state,
//   - callbackChannel feeds client callbacks to startInternalListener,
//   - connectionRecorder and connectionStore persist connection records and DIDs,
//   - initialized makes repeated Initialize calls no-ops.
type Service struct {
	service.Action
	service.Message
	ctx                *context
	callbackChannel    chan *message
	connectionRecorder *connection.Recorder
	connectionStore    didstore.ConnectionStore
	initialized        bool
}
   109  
// context bundles the dependencies that handle passes to each state's ExecuteInbound. Initialize builds it
// from the provider, with key types and media type profiles already defaulted.
//
// doACAPyInterop is copied from a package-level doACAPyInterop value declared outside this view (presumably
// a build-time interop switch — verify). Helper methods such as getServiceBlock and
// getInvitationRecipientKey are also declared on this type elsewhere.
type context struct {
	outboundDispatcher dispatcher.Outbound
	crypto             crypto.Crypto
	kms                kms.KeyManager
	connectionRecorder *connection.Recorder
	connectionStore    didstore.ConnectionStore
	vdRegistry         vdrapi.Registry
	routeSvc           mediator.ProtocolService
	doACAPyInterop     bool
	keyType            kms.KeyType
	keyAgreementType   kms.KeyType
	mediaTypeProfiles  []string
}
   123  
// opts are used to provide client properties to DID Exchange service.
//
// A value implementing opts may be passed to an action event's Continue callback. sendActionEvent then
// replaces the message options with the values returned here before processing resumes; any other argument
// type leaves the options untouched.
type opts interface {
	// PublicDID allows for setting public DID (copied into the options' publicDID).
	PublicDID() string

	// Label allows for setting label (copied into the options' label).
	Label() string

	// RouterConnections allows for setting router connections (copied into the options' routerConnections).
	RouterConnections() []string
}
   135  
   136  // New return didexchange service.
   137  func New(prov provider) (*Service, error) {
   138  	svc := Service{}
   139  
   140  	err := svc.Initialize(prov)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	return &svc, nil
   146  }
   147  
   148  // Initialize initializes the Service. If Initialize succeeds, any further call is a no-op.
   149  func (s *Service) Initialize(p interface{}) error { // nolint: funlen
   150  	if s.initialized {
   151  		return nil
   152  	}
   153  
   154  	prov, ok := p.(provider)
   155  	if !ok {
   156  		return fmt.Errorf("expected provider of type `%T`, got type `%T`", provider(nil), p)
   157  	}
   158  
   159  	connRecorder, err := connection.NewRecorder(prov)
   160  	if err != nil {
   161  		return fmt.Errorf("failed to initialize connection recorder: %w", err)
   162  	}
   163  
   164  	routeSvcBase, err := prov.Service(mediator.Coordination)
   165  	if err != nil {
   166  		return err
   167  	}
   168  
   169  	routeSvc, ok := routeSvcBase.(mediator.ProtocolService)
   170  	if !ok {
   171  		return errors.New("cast service to Route Service failed")
   172  	}
   173  
   174  	const callbackChannelSize = 10
   175  
   176  	keyType := prov.KeyType()
   177  	if keyType == "" {
   178  		keyType = kms.ED25519Type
   179  	}
   180  
   181  	keyAgreementType := prov.KeyAgreementType()
   182  	if keyAgreementType == "" {
   183  		keyAgreementType = kms.X25519ECDHKWType
   184  	}
   185  
   186  	mediaTypeProfiles := prov.MediaTypeProfiles()
   187  	if len(mediaTypeProfiles) == 0 {
   188  		mediaTypeProfiles = []string{transport.MediaTypeAIP2RFC0019Profile}
   189  	}
   190  
   191  	s.ctx = &context{
   192  		outboundDispatcher: prov.OutboundDispatcher(),
   193  		crypto:             prov.Crypto(),
   194  		kms:                prov.KMS(),
   195  		vdRegistry:         prov.VDRegistry(),
   196  		connectionRecorder: connRecorder,
   197  		connectionStore:    prov.DIDConnectionStore(),
   198  		routeSvc:           routeSvc,
   199  		doACAPyInterop:     doACAPyInterop,
   200  		keyType:            keyType,
   201  		keyAgreementType:   keyAgreementType,
   202  		mediaTypeProfiles:  mediaTypeProfiles,
   203  	}
   204  
   205  	// TODO channel size - https://github.com/hyperledger/aries-framework-go/issues/246
   206  	s.callbackChannel = make(chan *message, callbackChannelSize)
   207  	s.connectionRecorder = connRecorder
   208  	s.connectionStore = prov.DIDConnectionStore()
   209  
   210  	// start the listener
   211  	go s.startInternalListener()
   212  
   213  	s.initialized = true
   214  
   215  	return nil
   216  }
   217  
   218  func retrievingRouterConnections(msg service.DIDCommMsg) []string {
   219  	raw, found := msg.Metadata()[routerConnsMetadataKey]
   220  	if !found {
   221  		return nil
   222  	}
   223  
   224  	connections, ok := raw.([]string)
   225  	if !ok {
   226  		return nil
   227  	}
   228  
   229  	return connections
   230  }
   231  
   232  // HandleInbound handles inbound didexchange messages.
   233  func (s *Service) HandleInbound(msg service.DIDCommMsg, ctx service.DIDCommContext) (string, error) {
   234  	logger.Debugf("receive inbound message : %s", msg)
   235  
   236  	// fetch the thread id
   237  	thID, err := msg.ThreadID()
   238  	if err != nil {
   239  		return "", err
   240  	}
   241  
   242  	// valid state transition and get the next state
   243  	next, err := s.nextState(msg.Type(), thID)
   244  	if err != nil {
   245  		return "", fmt.Errorf("handle inbound - next state : %w", err)
   246  	}
   247  
   248  	// connection record
   249  	connRecord, err := s.connectionRecord(msg)
   250  	if err != nil {
   251  		return "", fmt.Errorf("failed to fetch connection record : %w", err)
   252  	}
   253  
   254  	logger.Debugf("connection record: %+v", connRecord)
   255  
   256  	internalMsg := &message{
   257  		Options:       &options{routerConnections: retrievingRouterConnections(msg)},
   258  		Msg:           msg.Clone(),
   259  		ThreadID:      thID,
   260  		NextStateName: next.Name(),
   261  		ConnRecord:    connRecord,
   262  	}
   263  
   264  	go func(msg *message, aEvent chan<- service.DIDCommAction) {
   265  		if err = s.handle(msg, aEvent); err != nil {
   266  			logutil.LogError(logger, DIDExchange, "processMessage", err.Error(),
   267  				logutil.CreateKeyValueString("msgType", msg.Msg.Type()),
   268  				logutil.CreateKeyValueString("msgID", msg.Msg.ID()),
   269  				logutil.CreateKeyValueString("connectionID", msg.ConnRecord.ConnectionID))
   270  		}
   271  
   272  		logutil.LogDebug(logger, DIDExchange, "processMessage", "success",
   273  			logutil.CreateKeyValueString("msgType", msg.Msg.Type()),
   274  			logutil.CreateKeyValueString("msgID", msg.Msg.ID()),
   275  			logutil.CreateKeyValueString("connectionID", msg.ConnRecord.ConnectionID))
   276  	}(internalMsg, s.ActionEvent())
   277  
   278  	logutil.LogDebug(logger, DIDExchange, "handleInbound", "success",
   279  		logutil.CreateKeyValueString("msgType", msg.Type()),
   280  		logutil.CreateKeyValueString("msgID", msg.ID()),
   281  		logutil.CreateKeyValueString("connectionID", internalMsg.ConnRecord.ConnectionID))
   282  
   283  	return connRecord.ConnectionID, nil
   284  }
   285  
   286  // Name return service name.
   287  func (s *Service) Name() string {
   288  	return DIDExchange
   289  }
   290  
   291  func findNamespace(msgType string) string {
   292  	namespace := theirNSPrefix
   293  	if msgType == InvitationMsgType || msgType == ResponseMsgType || msgType == oobMsgType {
   294  		namespace = myNSPrefix
   295  	}
   296  
   297  	return namespace
   298  }
   299  
   300  // Accept msg checks the msg type.
   301  func (s *Service) Accept(msgType string) bool {
   302  	return msgType == InvitationMsgType ||
   303  		msgType == RequestMsgType ||
   304  		msgType == ResponseMsgType ||
   305  		msgType == AckMsgType ||
   306  		msgType == CompleteMsgType
   307  }
   308  
   309  // HandleOutbound handles outbound didexchange messages.
   310  func (s *Service) HandleOutbound(_ service.DIDCommMsg, _, _ string) (string, error) {
   311  	return "", errors.New("not implemented")
   312  }
   313  
   314  func (s *Service) nextState(msgType, thID string) (state, error) {
   315  	logger.Debugf("msgType=%s thID=%s", msgType, thID)
   316  
   317  	nsThID, err := connection.CreateNamespaceKey(findNamespace(msgType), thID)
   318  	if err != nil {
   319  		return nil, err
   320  	}
   321  
   322  	current, err := s.currentState(nsThID)
   323  	if err != nil {
   324  		return nil, err
   325  	}
   326  
   327  	logger.Debugf("retrieved current state [%s] using nsThID [%s]", current.Name(), nsThID)
   328  
   329  	next, err := stateFromMsgType(msgType)
   330  	if err != nil {
   331  		return nil, err
   332  	}
   333  
   334  	logger.Debugf("check if current state [%s] can transition to [%s]", current.Name(), next.Name())
   335  
   336  	if !current.CanTransitionTo(next) {
   337  		return nil, fmt.Errorf("invalid state transition: %s -> %s", current.Name(), next.Name())
   338  	}
   339  
   340  	return next, nil
   341  }
   342  
// handle drives the DID exchange state machine for msg. It starts at the state named by
// msg.NextStateName and follows each state's follow-up until a noOp state is reached or execution halts.
//
// For every state it:
//   - emits a PreState message event,
//   - executes the state's inbound logic,
//   - stamps the returned connection record with the state name and persists it,
//   - saves their DID once the record reaches StateIDCompleted,
//   - runs the state's action.
//
// For non-OOB messages where canTriggerActionEvents reports true for the record's state and namespace,
// sendActionEvent stores the message (and, when aEvent is non-nil, dispatches an action event). Execution
// then halts after the PostState event, leaving msg.NextStateName pointing at the follow-up state.
// aEvent may be nil (see handleWithoutAction). The first failure aborts processing and is returned wrapped
// with the name of the failing state.
func (s *Service) handle(msg *message, aEvent chan<- service.DIDCommAction) error { //nolint:funlen,gocyclo
	logger.Debugf("handling msg: %+v", msg)

	next, err := stateFromName(msg.NextStateName)
	if err != nil {
		return fmt.Errorf("invalid state name: %w", err)
	}

	for !isNoOp(next) {
		// notify message-event subscribers that `next` is about to execute
		s.sendMsgEvents(&service.StateMsg{
			ProtocolName: DIDExchange,
			Type:         service.PreState,
			Msg:          msg.Msg.Clone(),
			StateID:      next.Name(),
			Properties:   createEventProperties(msg.ConnRecord.ConnectionID, msg.ConnRecord.InvitationID),
		})
		logger.Debugf("sent pre event for state %s", next.Name())

		var (
			action           stateAction
			followup         state
			connectionRecord *connection.Record
		)

		// execute the state: it yields the updated record, the state to run next, and an action that is
		// invoked below only after the record has been persisted
		connectionRecord, followup, action, err = next.ExecuteInbound(
			&stateMachineMsg{
				DIDCommMsg: msg.Msg,
				connRecord: msg.ConnRecord,
				options:    msg.Options,
			},
			msg.ThreadID,
			s.ctx)

		if err != nil {
			return fmt.Errorf("failed to execute state '%s': %w", next.Name(), err)
		}

		// the persisted record always reflects the state that just executed
		connectionRecord.State = next.Name()
		logger.Debugf("finished execute state: %s", next.Name())

		if err = s.update(msg.Msg.Type(), connectionRecord); err != nil {
			return fmt.Errorf("failed to persist state '%s': %w", next.Name(), err)
		}

		// on completion, remember their DID together with the record's recipient keys
		if connectionRecord.State == StateIDCompleted {
			err = s.connectionStore.SaveDIDByResolving(connectionRecord.TheirDID, connectionRecord.RecipientKeys...)
			if err != nil {
				return fmt.Errorf("save theirDID: %w", err)
			}
		}

		if err = action(); err != nil {
			return fmt.Errorf("failed to execute state action '%s': %w", next.Name(), err)
		}

		logger.Debugf("finish execute state action: '%s'", next.Name())

		// keep the executed state for the post-state event and advance to the follow-up
		prev := next
		next = followup
		haltExecution := false

		// trigger action event based on message type for inbound messages
		// (OOB invitations, tagged oobMsgType by RespondTo, never pause here)
		if msg.Msg.Type() != oobMsgType && canTriggerActionEvents(connectionRecord.State, connectionRecord.Namespace) {
			logger.Debugf("action event triggered for msg type: %s", msg.Msg.Type())

			// record where processing must resume once the client continues
			msg.NextStateName = next.Name()
			if err = s.sendActionEvent(msg, aEvent); err != nil {
				return fmt.Errorf("handle inbound: %w", err)
			}

			haltExecution = true
		}

		s.sendMsgEvents(&service.StateMsg{
			ProtocolName: DIDExchange,
			Type:         service.PostState,
			Msg:          msg.Msg.Clone(),
			StateID:      prev.Name(),
			Properties:   createEventProperties(connectionRecord.ConnectionID, connectionRecord.InvitationID),
		})
		logger.Debugf("sent post event for state %s", prev.Name())

		if haltExecution {
			logger.Debugf("halted execution before state=%s", msg.NextStateName)

			break
		}
	}

	return nil
}
   434  
   435  func (s *Service) handleWithoutAction(msg *message) error {
   436  	return s.handle(msg, nil)
   437  }
   438  
   439  func createEventProperties(connectionID, invitationID string) *didExchangeEvent {
   440  	return &didExchangeEvent{
   441  		connectionID: connectionID,
   442  		invitationID: invitationID,
   443  	}
   444  }
   445  
   446  func createErrorEventProperties(connectionID, invitationID string, err error) *didExchangeEventError {
   447  	props := createEventProperties(connectionID, invitationID)
   448  
   449  	return &didExchangeEventError{
   450  		err:              err,
   451  		didExchangeEvent: *props,
   452  	}
   453  }
   454  
   455  // sendActionEvent triggers the action event. This function stores the state of current processing and passes a callback
   456  // function in the event message.
   457  func (s *Service) sendActionEvent(internalMsg *message, aEvent chan<- service.DIDCommAction) error {
   458  	// save data to support AcceptExchangeRequest APIs (when client will not be able to invoke the callback function)
   459  	err := s.storeEventProtocolStateData(internalMsg)
   460  	if err != nil {
   461  		return fmt.Errorf("send action event : %w", err)
   462  	}
   463  
   464  	if aEvent != nil {
   465  		// trigger action event
   466  		aEvent <- service.DIDCommAction{
   467  			ProtocolName: DIDExchange,
   468  			Message:      internalMsg.Msg.Clone(),
   469  			Continue: func(args interface{}) {
   470  				switch v := args.(type) {
   471  				case opts:
   472  					internalMsg.Options = &options{
   473  						publicDID:         v.PublicDID(),
   474  						label:             v.Label(),
   475  						routerConnections: v.RouterConnections(),
   476  					}
   477  				default:
   478  					// nothing to do
   479  				}
   480  
   481  				s.processCallback(internalMsg)
   482  			},
   483  			Stop: func(err error) {
   484  				// sets an error to the message
   485  				internalMsg.err = err
   486  				s.processCallback(internalMsg)
   487  			},
   488  			Properties: createEventProperties(internalMsg.ConnRecord.ConnectionID, internalMsg.ConnRecord.InvitationID),
   489  		}
   490  
   491  		logger.Debugf("dispatched action for msg: %+v", internalMsg.Msg)
   492  	}
   493  
   494  	return nil
   495  }
   496  
   497  // sendEvent triggers the message events.
   498  func (s *Service) sendMsgEvents(msg *service.StateMsg) {
   499  	// trigger the message events
   500  	for _, handler := range s.MsgEvents() {
   501  		handler <- *msg
   502  
   503  		logger.Debugf("sent msg event to handler: %+v", msg)
   504  	}
   505  }
   506  
   507  // startInternalListener listens to messages in gochannel for callback messages from clients.
   508  func (s *Service) startInternalListener() {
   509  	for msg := range s.callbackChannel {
   510  		// TODO https://github.com/hyperledger/aries-framework-go/issues/242 - retry logic
   511  		// if no error - do handle
   512  		if msg.err == nil {
   513  			msg.err = s.handleWithoutAction(msg)
   514  		}
   515  
   516  		// no error - continue
   517  		if msg.err == nil {
   518  			continue
   519  		}
   520  
   521  		if err := s.abandon(msg.ThreadID, msg.Msg, msg.err); err != nil {
   522  			logger.Errorf("process callback : %s", err)
   523  		}
   524  	}
   525  }
   526  
   527  // AcceptInvitation accepts/approves connection invitation.
   528  func (s *Service) AcceptInvitation(connectionID, publicDID, label string, routerConnections []string) error {
   529  	return s.accept(connectionID, publicDID, label, StateIDInvited,
   530  		"accept exchange invitation", routerConnections)
   531  }
   532  
   533  // AcceptExchangeRequest accepts/approves connection request.
   534  func (s *Service) AcceptExchangeRequest(connectionID, publicDID, label string, routerConnections []string) error {
   535  	return s.accept(connectionID, publicDID, label, StateIDRequested,
   536  		"accept exchange request", routerConnections)
   537  }
   538  
   539  // RespondTo this inbound invitation and return with the new connection record's ID.
   540  func (s *Service) RespondTo(i *OOBInvitation, routerConnections []string) (string, error) {
   541  	i.Type = oobMsgType
   542  
   543  	msg := service.NewDIDCommMsgMap(i)
   544  	msg.Metadata()[routerConnsMetadataKey] = routerConnections
   545  
   546  	return s.HandleInbound(msg, service.EmptyDIDCommContext())
   547  }
   548  
   549  // SaveInvitation saves this invitation created by you.
   550  func (s *Service) SaveInvitation(i *OOBInvitation) error {
   551  	i.Type = oobMsgType
   552  
   553  	err := s.connectionRecorder.SaveInvitation(i.ThreadID, i)
   554  	if err != nil {
   555  		return fmt.Errorf("failed to save oob invitation : %w", err)
   556  	}
   557  
   558  	logger.Debugf("saved invitation: %+v", i)
   559  
   560  	return nil
   561  }
   562  
   563  func (s *Service) accept(connectionID, publicDID, label, stateID, errMsg string, routerConnections []string) error {
   564  	msg, err := s.getEventProtocolStateData(connectionID)
   565  	if err != nil {
   566  		return fmt.Errorf("failed to accept invitation for connectionID=%s : %s : %w", connectionID, errMsg, err)
   567  	}
   568  
   569  	connRecord, err := s.connectionRecorder.GetConnectionRecord(connectionID)
   570  	if err != nil {
   571  		return fmt.Errorf("%s : %w", errMsg, err)
   572  	}
   573  
   574  	if connRecord.State != stateID {
   575  		return fmt.Errorf("current state (%s) is different from "+
   576  			"expected state (%s)", connRecord.State, stateID)
   577  	}
   578  
   579  	msg.Options = &options{publicDID: publicDID, label: label, routerConnections: routerConnections}
   580  
   581  	return s.handleWithoutAction(msg)
   582  }
   583  
   584  func (s *Service) storeEventProtocolStateData(msg *message) error {
   585  	bytes, err := json.Marshal(msg)
   586  	if err != nil {
   587  		return fmt.Errorf("store protocol state data : %w", err)
   588  	}
   589  
   590  	return s.connectionRecorder.SaveEvent(msg.ConnRecord.ConnectionID, bytes)
   591  }
   592  
   593  func (s *Service) getEventProtocolStateData(connectionID string) (*message, error) {
   594  	val, err := s.connectionRecorder.GetEvent(connectionID)
   595  	if err != nil {
   596  		return nil, fmt.Errorf("get protocol state data : %w", err)
   597  	}
   598  
   599  	msg := &message{}
   600  
   601  	err = json.Unmarshal(val, msg)
   602  	if err != nil {
   603  		return nil, fmt.Errorf("get protocol state data : %w", err)
   604  	}
   605  
   606  	return msg, nil
   607  }
   608  
   609  // abandon updates the state to abandoned and trigger failure event.
   610  func (s *Service) abandon(thID string, msg service.DIDCommMsg, processErr error) error {
   611  	// update the state to abandoned
   612  	nsThID, err := connection.CreateNamespaceKey(findNamespace(msg.Type()), thID)
   613  	if err != nil {
   614  		return err
   615  	}
   616  
   617  	connRec, err := s.connectionRecorder.GetConnectionRecordByNSThreadID(nsThID)
   618  	if err != nil {
   619  		return fmt.Errorf("unable to update the state to abandoned: %w", err)
   620  	}
   621  
   622  	connRec.State = (&abandoned{}).Name()
   623  
   624  	err = s.update(msg.Type(), connRec)
   625  	if err != nil {
   626  		return fmt.Errorf("unable to update the state to abandoned: %w", err)
   627  	}
   628  
   629  	// send the message event
   630  	s.sendMsgEvents(&service.StateMsg{
   631  		ProtocolName: DIDExchange,
   632  		Type:         service.PostState,
   633  		Msg:          msg,
   634  		StateID:      StateIDAbandoned,
   635  		Properties:   createErrorEventProperties(connRec.ConnectionID, "", processErr),
   636  	})
   637  
   638  	return nil
   639  }
   640  
   641  func (s *Service) processCallback(msg *message) {
   642  	// pass the callback data to internal channel. This is created to unblock consumer go routine and wrap the callback
   643  	// channel internally.
   644  	s.callbackChannel <- msg
   645  }
   646  
   647  func isNoOp(s state) bool {
   648  	_, ok := s.(*noOp)
   649  	return ok
   650  }
   651  
   652  func (s *Service) currentState(nsThID string) (state, error) {
   653  	connRec, err := s.connectionRecorder.GetConnectionRecordByNSThreadID(nsThID)
   654  	if err != nil {
   655  		if errors.Is(err, storage.ErrDataNotFound) {
   656  			return &null{}, nil
   657  		}
   658  
   659  		return nil, fmt.Errorf("cannot fetch state from store: thID=%s err=%w", nsThID, err)
   660  	}
   661  
   662  	return stateFromName(connRec.State)
   663  }
   664  
   665  func (s *Service) update(msgType string, record *connection.Record) error {
   666  	if (msgType == RequestMsgType && record.State == StateIDRequested) ||
   667  		(msgType == InvitationMsgType && record.State == StateIDInvited) ||
   668  		(msgType == oobMsgType && record.State == StateIDInvited) {
   669  		return s.connectionRecorder.SaveConnectionRecordWithMappings(record)
   670  	}
   671  
   672  	return s.connectionRecorder.SaveConnectionRecord(record)
   673  }
   674  
   675  // CreateConnection saves the record to the connection store and maps TheirDID to their recipient keys in
   676  // the did connection store.
   677  func (s *Service) CreateConnection(record *connection.Record, theirDID *did.Doc) error {
   678  	logger.Debugf("creating connection using record [%+v] and theirDID [%+v]", record, theirDID)
   679  
   680  	didMethod, err := vdr.GetDidMethod(theirDID.ID)
   681  	if err != nil {
   682  		return err
   683  	}
   684  
   685  	_, err = s.ctx.vdRegistry.Create(didMethod, theirDID, vdrapi.WithOption("store", true))
   686  	if err != nil {
   687  		return fmt.Errorf("vdr failed to store theirDID : %w", err)
   688  	}
   689  
   690  	err = s.connectionStore.SaveDIDFromDoc(theirDID)
   691  	if err != nil {
   692  		return fmt.Errorf("failed to save theirDID to the did.ConnectionStore: %w", err)
   693  	}
   694  
   695  	err = s.connectionStore.SaveDIDByResolving(record.MyDID)
   696  	if err != nil {
   697  		return fmt.Errorf("failed to save myDID to the did.ConnectionStore: %w", err)
   698  	}
   699  
   700  	if isDIDCommV2(record.MediaTypeProfiles) {
   701  		record.DIDCommVersion = service.V2
   702  	} else {
   703  		record.DIDCommVersion = service.V1
   704  	}
   705  
   706  	return s.connectionRecorder.SaveConnectionRecord(record)
   707  }
   708  
   709  func (s *Service) connectionRecord(msg service.DIDCommMsg) (*connection.Record, error) {
   710  	switch msg.Type() {
   711  	case oobMsgType:
   712  		return s.oobInvitationMsgRecord(msg)
   713  	case InvitationMsgType:
   714  		return s.invitationMsgRecord(msg)
   715  	case RequestMsgType:
   716  		return s.requestMsgRecord(msg)
   717  	case ResponseMsgType:
   718  		return s.responseMsgRecord(msg)
   719  	case AckMsgType, CompleteMsgType:
   720  		return s.fetchConnectionRecord(theirNSPrefix, msg)
   721  	}
   722  
   723  	return nil, errors.New("invalid message type")
   724  }
   725  
   726  //nolint:funlen
   727  func (s *Service) oobInvitationMsgRecord(msg service.DIDCommMsg) (*connection.Record, error) {
   728  	thID, err := msg.ThreadID()
   729  	if err != nil {
   730  		return nil, fmt.Errorf("failed to read the oobinvitation threadID : %w", err)
   731  	}
   732  
   733  	var oobInvitation OOBInvitation
   734  
   735  	err = msg.Decode(&oobInvitation)
   736  	if err != nil {
   737  		return nil, fmt.Errorf("failed to decode the oob invitation : %w", err)
   738  	}
   739  
   740  	svc, err := s.ctx.getServiceBlock(&oobInvitation)
   741  	if err != nil {
   742  		return nil, fmt.Errorf("failed to get the did service block from oob invitation : %w", err)
   743  	}
   744  
   745  	uri, err := svc.ServiceEndpoint.URI()
   746  	if err != nil {
   747  		logger.Debugf("service DIDComm V1 without ServiceEndpoint URI: %w, skipping it", err)
   748  	}
   749  
   750  	var connRecord *connection.Record
   751  
   752  	if accept, err := svc.ServiceEndpoint.Accept(); err == nil && isDIDCommV2(accept) {
   753  		connRecord = &connection.Record{
   754  			ConnectionID:    generateRandomID(),
   755  			ThreadID:        thID,
   756  			ParentThreadID:  oobInvitation.ThreadID,
   757  			State:           stateNameNull,
   758  			InvitationID:    oobInvitation.ID,
   759  			ServiceEndPoint: svc.ServiceEndpoint,
   760  			RecipientKeys:   svc.RecipientKeys, // TODO: recipient keys should be 'theirs' not 'mine'.
   761  			TheirLabel:      oobInvitation.TheirLabel,
   762  			Namespace:       findNamespace(msg.Type()),
   763  			DIDCommVersion:  service.V2,
   764  		}
   765  	} else {
   766  		connRecord = &connection.Record{
   767  			ConnectionID:      generateRandomID(),
   768  			ThreadID:          thID,
   769  			ParentThreadID:    oobInvitation.ThreadID,
   770  			State:             stateNameNull,
   771  			InvitationID:      oobInvitation.ID,
   772  			ServiceEndPoint:   model.NewDIDCommV1Endpoint(uri),
   773  			RecipientKeys:     svc.RecipientKeys, // TODO: recipient keys should be 'theirs' not 'mine'.
   774  			TheirLabel:        oobInvitation.TheirLabel,
   775  			Namespace:         findNamespace(msg.Type()),
   776  			MediaTypeProfiles: svc.Accept,
   777  			DIDCommVersion:    service.V1,
   778  		}
   779  	}
   780  
   781  	publicDID, ok := oobInvitation.Target.(string)
   782  	if ok {
   783  		connRecord.Implicit = true
   784  		connRecord.InvitationDID = publicDID
   785  	}
   786  
   787  	if err := s.connectionRecorder.SaveConnectionRecord(connRecord); err != nil {
   788  		return nil, err
   789  	}
   790  
   791  	return connRecord, nil
   792  }
   793  
   794  func (s *Service) invitationMsgRecord(msg service.DIDCommMsg) (*connection.Record, error) {
   795  	thID, msgErr := msg.ThreadID()
   796  	if msgErr != nil {
   797  		return nil, msgErr
   798  	}
   799  
   800  	invitation := &Invitation{}
   801  
   802  	err := msg.Decode(invitation)
   803  	if err != nil {
   804  		return nil, err
   805  	}
   806  
   807  	recKey, err := s.ctx.getInvitationRecipientKey(invitation)
   808  	if err != nil {
   809  		return nil, err
   810  	}
   811  
   812  	connRecord := &connection.Record{
   813  		ConnectionID:    generateRandomID(),
   814  		ThreadID:        thID,
   815  		State:           stateNameNull,
   816  		InvitationID:    invitation.ID,
   817  		InvitationDID:   invitation.DID,
   818  		ServiceEndPoint: model.NewDIDCommV1Endpoint(invitation.ServiceEndpoint),
   819  		RecipientKeys:   []string{recKey},
   820  		TheirLabel:      invitation.Label,
   821  		Namespace:       findNamespace(msg.Type()),
   822  		DIDCommVersion:  service.V1,
   823  	}
   824  
   825  	if err := s.connectionRecorder.SaveConnectionRecord(connRecord); err != nil {
   826  		return nil, err
   827  	}
   828  
   829  	return connRecord, nil
   830  }
   831  
   832  // nolint:gomnd
   833  func pad(b64 string) string {
   834  	mod := len(b64) % 4
   835  	if mod <= 1 {
   836  		return b64
   837  	}
   838  
   839  	return b64 + strings.Repeat("=", 4-mod)
   840  }
   841  
   842  func getRequestConnection(r *Request) (*Connection, error) {
   843  	if r.DocAttach == nil {
   844  		return nil, fmt.Errorf("missing did_doc~attach from request")
   845  	}
   846  
   847  	docData, err := r.DocAttach.Data.Fetch()
   848  	if err != nil {
   849  		return nil, fmt.Errorf("failed to parse base64 attachment data: %w", err)
   850  	}
   851  
   852  	doc, err := did.ParseDocument(docData)
   853  	if err != nil {
   854  		logger.Errorf("doc bytes: '%s'", string(docData))
   855  		return nil, fmt.Errorf("failed to parse did document: %w", err)
   856  	}
   857  
   858  	return &Connection{
   859  		DID:    r.DID,
   860  		DIDDoc: doc,
   861  	}, nil
   862  }
   863  
   864  func (s *Service) requestMsgRecord(msg service.DIDCommMsg) (*connection.Record, error) {
   865  	request := Request{}
   866  
   867  	err := msg.Decode(&request)
   868  	if err != nil {
   869  		return nil, fmt.Errorf("unmarshalling failed: %w", err)
   870  	}
   871  
   872  	invitationID := msg.ParentThreadID()
   873  	if invitationID == "" {
   874  		return nil, fmt.Errorf("missing parent thread ID on didexchange request with @id=%s", request.ID)
   875  	}
   876  
   877  	connRecord := &connection.Record{
   878  		TheirLabel:     request.Label,
   879  		ConnectionID:   generateRandomID(),
   880  		ThreadID:       request.ID,
   881  		State:          stateNameNull,
   882  		InvitationID:   invitationID,
   883  		Namespace:      theirNSPrefix,
   884  		DIDCommVersion: service.V1,
   885  	}
   886  
   887  	connRecord.TheirDID = request.DID
   888  
   889  	// ACA-Py Interop: https://github.com/hyperledger/aries-cloudagent-python/issues/1048
   890  	if !strings.HasPrefix(connRecord.TheirDID, "did") {
   891  		connRecord.TheirDID = "did:peer:" + connRecord.TheirDID
   892  	}
   893  
   894  	if err := s.connectionRecorder.SaveConnectionRecord(connRecord); err != nil {
   895  		return nil, err
   896  	}
   897  
   898  	return connRecord, nil
   899  }
   900  
   901  func (s *Service) responseMsgRecord(payload service.DIDCommMsg) (*connection.Record, error) {
   902  	return s.fetchConnectionRecord(myNSPrefix, payload)
   903  }
   904  
   905  func (s *Service) fetchConnectionRecord(nsPrefix string, payload service.DIDCommMsg) (*connection.Record, error) {
   906  	msg := &struct {
   907  		Thread decorator.Thread `json:"~thread,omitempty"`
   908  	}{}
   909  
   910  	err := payload.Decode(msg)
   911  	if err != nil {
   912  		return nil, err
   913  	}
   914  
   915  	key, err := connection.CreateNamespaceKey(nsPrefix, msg.Thread.ID)
   916  	if err != nil {
   917  		return nil, err
   918  	}
   919  
   920  	return s.connectionRecorder.GetConnectionRecordByNSThreadID(key)
   921  }
   922  
   923  func generateRandomID() string {
   924  	return uuid.New().String()
   925  }
   926  
   927  // canTriggerActionEvents true based on role and state.
   928  // 1. Role is invitee and state is invited.
   929  // 2. Role is inviter and state is requested.
   930  func canTriggerActionEvents(stateID, ns string) bool {
   931  	return (stateID == StateIDInvited && ns == myNSPrefix) || (stateID == StateIDRequested && ns == theirNSPrefix)
   932  }
   933  
// options carries optional per-exchange settings attached to an internal message through its
// Options field. CreateImplicitInvitation populates it from its invitee arguments.
type options struct {
	// publicDID is set by CreateImplicitInvitation from inviteeDID. Per that function's doc, the
	// invitee DID is optional, and a new peer DID is created when it is not provided.
	publicDID         string
	// routerConnections is passed through from CreateImplicitInvitation's routerConnections
	// argument. NOTE(review): presumably IDs of router/mediator connections; confirm against consumers.
	routerConnections []string
	// label is set by CreateImplicitInvitation from inviteeLabel.
	label             string
}
   939  
   940  // CreateImplicitInvitation creates implicit invitation. Inviter DID is required, invitee DID is optional.
   941  // If invitee DID is not provided new peer DID will be created for implicit invitation exchange request.
   942  //nolint:funlen
   943  func (s *Service) CreateImplicitInvitation(inviterLabel, inviterDID,
   944  	inviteeLabel, inviteeDID string, routerConnections []string) (string, error) {
   945  	logger.Debugf("implicit invitation requested inviterDID[%s] inviteeDID[%s]", inviterDID, inviteeDID)
   946  
   947  	docResolution, err := s.ctx.vdRegistry.Resolve(inviterDID)
   948  	if err != nil {
   949  		return "", fmt.Errorf("resolve public did[%s]: %w", inviterDID, err)
   950  	}
   951  
   952  	dest, err := service.CreateDestination(docResolution.DIDDocument)
   953  	if err != nil {
   954  		return "", err
   955  	}
   956  
   957  	thID := generateRandomID()
   958  
   959  	var connRecord *connection.Record
   960  
   961  	if accept, e := dest.ServiceEndpoint.Accept(); e == nil && isDIDCommV2(accept) {
   962  		connRecord = &connection.Record{
   963  			ConnectionID:    generateRandomID(),
   964  			ThreadID:        thID,
   965  			State:           stateNameNull,
   966  			InvitationDID:   inviterDID,
   967  			Implicit:        true,
   968  			ServiceEndPoint: dest.ServiceEndpoint,
   969  			RecipientKeys:   dest.RecipientKeys,
   970  			TheirLabel:      inviterLabel,
   971  			Namespace:       findNamespace(InvitationMsgType),
   972  		}
   973  	} else {
   974  		connRecord = &connection.Record{
   975  			ConnectionID:      generateRandomID(),
   976  			ThreadID:          thID,
   977  			State:             stateNameNull,
   978  			InvitationDID:     inviterDID,
   979  			Implicit:          true,
   980  			ServiceEndPoint:   dest.ServiceEndpoint,
   981  			RecipientKeys:     dest.RecipientKeys,
   982  			RoutingKeys:       dest.RoutingKeys,
   983  			MediaTypeProfiles: dest.MediaTypeProfiles,
   984  			TheirLabel:        inviterLabel,
   985  			Namespace:         findNamespace(InvitationMsgType),
   986  		}
   987  	}
   988  
   989  	if e := s.connectionRecorder.SaveConnectionRecordWithMappings(connRecord); e != nil {
   990  		return "", fmt.Errorf("failed to save new connection record for implicit invitation: %w", e)
   991  	}
   992  
   993  	invitation := &Invitation{
   994  		ID:    uuid.New().String(),
   995  		Label: inviterLabel,
   996  		DID:   inviterDID,
   997  		Type:  InvitationMsgType,
   998  	}
   999  
  1000  	msg, err := createDIDCommMsg(invitation)
  1001  	if err != nil {
  1002  		return "", fmt.Errorf("failed to create DIDCommMsg for implicit invitation: %w", err)
  1003  	}
  1004  
  1005  	next := &requested{}
  1006  	internalMsg := &message{
  1007  		Msg:           msg.Clone(),
  1008  		ThreadID:      thID,
  1009  		NextStateName: next.Name(),
  1010  		ConnRecord:    connRecord,
  1011  	}
  1012  	internalMsg.Options = &options{publicDID: inviteeDID, label: inviteeLabel, routerConnections: routerConnections}
  1013  
  1014  	go func(msg *message, aEvent chan<- service.DIDCommAction) {
  1015  		if err = s.handle(msg, aEvent); err != nil {
  1016  			logger.Errorf("error from handle for implicit invitation: %s", err)
  1017  		}
  1018  	}(internalMsg, s.ActionEvent())
  1019  
  1020  	return connRecord.ConnectionID, nil
  1021  }
  1022  
  1023  func createDIDCommMsg(invitation *Invitation) (service.DIDCommMsg, error) {
  1024  	payload, err := json.Marshal(invitation)
  1025  	if err != nil {
  1026  		return nil, fmt.Errorf("marshal invitation: %w", err)
  1027  	}
  1028  
  1029  	return service.ParseDIDCommMsgMap(payload)
  1030  }