github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/protocol/legacyconnection/service.go (about)

     1  /*
     2  Copyright Avast Software. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package legacyconnection
     8  
     9  import (
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  
    14  	"github.com/google/uuid"
    15  
    16  	"github.com/hyperledger/aries-framework-go/pkg/common/log"
    17  	"github.com/hyperledger/aries-framework-go/pkg/common/model"
    18  	"github.com/hyperledger/aries-framework-go/pkg/crypto"
    19  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    20  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher"
    21  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator"
    22  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/mediator"
    23  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
    24  	"github.com/hyperledger/aries-framework-go/pkg/doc/did"
    25  	vdrapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr"
    26  	"github.com/hyperledger/aries-framework-go/pkg/internal/logutil"
    27  	"github.com/hyperledger/aries-framework-go/pkg/kms"
    28  	"github.com/hyperledger/aries-framework-go/pkg/store/connection"
    29  	didstore "github.com/hyperledger/aries-framework-go/pkg/store/did"
    30  	"github.com/hyperledger/aries-framework-go/pkg/vdr"
    31  	"github.com/hyperledger/aries-framework-go/spi/storage"
    32  )
    33  
// logger is the package-level logger shared by every function in this file.
var logger = log.New("aries-framework/legacyconnection/service")
    35  
// Protocol name, protocol identifier URI and the message type URIs accepted by this service.
const (
	// LegacyConnection connection protocol.
	LegacyConnection = "legacyconnection"
	// PIURI is the connection protocol identifier URI.
	PIURI = "https://didcomm.org/connections/1.0"
	// InvitationMsgType defines the legacy-connection invite message type.
	InvitationMsgType = PIURI + "/invitation"
	// RequestMsgType defines the legacy-connection request message type.
	RequestMsgType = PIURI + "/request"
	// ResponseMsgType defines the legacy-connection response message type.
	ResponseMsgType = PIURI + "/response"
	// AckMsgType defines the legacy-connection ack message type.
	AckMsgType             = "https://didcomm.org/notification/1.0/ack"
	routerConnsMetadataKey = "routerConnections" // message metadata key read by retrievingRouterConnections
)
    51  
// Namespace prefixes used when building namespaced thread-ID keys (see findNamespace), plus the
// context key under which an invitation recipient key is looked up by requestMsgRecord.
const (
	myNSPrefix = "my"
	// TODO: https://github.com/hyperledger/aries-framework-go/issues/556 It will not be constant, this namespace
	//  will need to be figured with verification key
	theirNSPrefix = "their"
	// InvitationRecipientKey is map key constant.
	InvitationRecipientKey = "invRecipientKey"
)
    60  
// message type to store data for eventing. This is retrieved during callback.
// It is serialized with encoding/json by storeEventProtocolStateData and restored by
// getEventProtocolStateData; being unexported, the err field is not part of that serialized form.
type message struct {
	Msg           service.DIDCommMsgMap
	ThreadID      string
	Options       *options
	NextStateName string
	ConnRecord    *connection.Record
	// err is used to determine whether callback was stopped
	// e.g the user received an action event and executes Stop(err) function
	// in that case `err` is equal to `err` which was passing to Stop function
	err error
}
    73  
// provider contains dependencies for the Connection protocol and is typically created by using aries.Context().
// Initialize type-asserts its argument to this interface and fails if the assertion does not hold.
type provider interface {
	OutboundDispatcher() dispatcher.Outbound
	StorageProvider() storage.Provider
	ProtocolStateStorageProvider() storage.Provider
	DIDConnectionStore() didstore.ConnectionStore
	Crypto() crypto.Crypto
	KMS() kms.KeyManager
	VDRegistry() vdrapi.Registry
	Service(id string) (interface{}, error)
	KeyType() kms.KeyType
	KeyAgreementType() kms.KeyType
	MediaTypeProfiles() []string
}
    88  
// stateMachineMsg is an internal struct used to pass data to state machine.
// handle builds one per state execution from the message payload, its connection record and
// the options currently attached to the message.
type stateMachineMsg struct {
	service.DIDCommMsg
	connRecord *connection.Record
	options    *options
}
    95  
// options carries the client-supplied settings attached to a message being processed.
// Values come from the opts passed to an action event's Continue callback, from the arguments
// of accept / CreateImplicitInvitation, or (routerConnections only) from inbound message metadata.
type options struct {
	publicDID         string
	routerConnections []string
	label             string
}
   101  
// Service for Connection protocol.
// ctx, callbackChannel, connectionRecorder and connectionStore are populated by Initialize;
// callbackChannel is drained by startInternalListener, and initialized is set to true once
// Initialize has completed so that later calls return immediately.
type Service struct {
	service.Action
	service.Message
	ctx                *context
	callbackChannel    chan *message
	connectionRecorder *connection.Recorder
	connectionStore    didstore.ConnectionStore
	initialized        bool
}
   112  
// context bundles the dependencies handed to each state's ExecuteInbound call (see handle).
// Initialize fills every field except doACAPyInterop, which is not assigned anywhere in the
// code visible here — NOTE(review): confirm where it is set.
type context struct {
	outboundDispatcher dispatcher.Outbound
	crypto             crypto.Crypto
	kms                kms.KeyManager
	connectionRecorder *connection.Recorder
	connectionStore    didstore.ConnectionStore
	vdRegistry         vdrapi.Registry
	routeSvc           mediator.ProtocolService
	doACAPyInterop     bool
	keyType            kms.KeyType
	keyAgreementType   kms.KeyType
	mediaTypeProfiles  []string
}
   126  
// opts are used to provide client properties to Connection service.
// A value implementing opts may be passed to the Continue callback of an action event; its
// three getters are then copied into the message's options (see sendActionEvent).
type opts interface {
	// PublicDID allows for setting public DID
	PublicDID() string

	// Label allows for setting label
	Label() string

	// RouterConnections allows for setting router connections
	RouterConnections() []string
}
   138  
   139  // New return connection service.
   140  func New(prov provider) (*Service, error) {
   141  	svc := Service{}
   142  
   143  	err := svc.Initialize(prov)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  
   148  	return &svc, nil
   149  }
   150  
   151  // Initialize initializes the Service. If Initialize succeeds, any further call is a no-op.
   152  func (s *Service) Initialize(p interface{}) error {
   153  	if s.initialized {
   154  		return nil
   155  	}
   156  
   157  	prov, ok := p.(provider)
   158  	if !ok {
   159  		return fmt.Errorf("expected provider of type `%T`, got type `%T`", provider(nil), p)
   160  	}
   161  
   162  	connRecorder, err := connection.NewRecorder(prov)
   163  	if err != nil {
   164  		return fmt.Errorf("failed to initialize connection recorder: %w", err)
   165  	}
   166  
   167  	routeSvcBase, err := prov.Service(mediator.Coordination)
   168  	if err != nil {
   169  		return err
   170  	}
   171  
   172  	routeSvc, ok := routeSvcBase.(mediator.ProtocolService)
   173  	if !ok {
   174  		return errors.New("cast service to Route Service failed")
   175  	}
   176  
   177  	const callbackChannelSize = 10
   178  
   179  	keyType := kms.ED25519Type
   180  
   181  	keyAgreementType := kms.X25519ECDHKWType
   182  
   183  	mediaTypeProfiles := []string{transport.LegacyDIDCommV1Profile}
   184  
   185  	s.ctx = &context{
   186  		outboundDispatcher: prov.OutboundDispatcher(),
   187  		crypto:             prov.Crypto(),
   188  		kms:                prov.KMS(),
   189  		vdRegistry:         prov.VDRegistry(),
   190  		connectionRecorder: connRecorder,
   191  		connectionStore:    prov.DIDConnectionStore(),
   192  		routeSvc:           routeSvc,
   193  		keyType:            keyType,
   194  		keyAgreementType:   keyAgreementType,
   195  		mediaTypeProfiles:  mediaTypeProfiles,
   196  	}
   197  
   198  	// TODO channel size - https://github.com/hyperledger/aries-framework-go/issues/246
   199  	s.callbackChannel = make(chan *message, callbackChannelSize)
   200  	s.connectionRecorder = connRecorder
   201  	s.connectionStore = prov.DIDConnectionStore()
   202  
   203  	// start the listener
   204  	go s.startInternalListener()
   205  
   206  	s.initialized = true
   207  
   208  	return nil
   209  }
   210  
   211  func retrievingRouterConnections(msg service.DIDCommMsg) []string {
   212  	raw, found := msg.Metadata()[routerConnsMetadataKey]
   213  	if !found {
   214  		return nil
   215  	}
   216  
   217  	connections, ok := raw.([]string)
   218  	if !ok {
   219  		return nil
   220  	}
   221  
   222  	return connections
   223  }
   224  
   225  // HandleInbound handles inbound connection messages.
   226  func (s *Service) HandleInbound(msg service.DIDCommMsg, ctx service.DIDCommContext) (string, error) {
   227  	logger.Debugf("receive inbound message : %s", msg)
   228  
   229  	// fetch the thread id
   230  	thID, err := msg.ThreadID()
   231  	if err != nil {
   232  		return "", err
   233  	}
   234  
   235  	// valid state transition and get the next state
   236  	next, err := s.nextState(msg.Type(), thID)
   237  	if err != nil {
   238  		return "", fmt.Errorf("handle inbound - next state : %w", err)
   239  	}
   240  
   241  	// connection record
   242  	connRecord, err := s.connectionRecord(msg, ctx)
   243  	if err != nil {
   244  		return "", fmt.Errorf("failed to fetch connection record : %w", err)
   245  	}
   246  
   247  	logger.Debugf("connection record: %+v", connRecord)
   248  
   249  	internalMsg := &message{
   250  		Options:       &options{routerConnections: retrievingRouterConnections(msg)},
   251  		Msg:           msg.Clone(),
   252  		ThreadID:      thID,
   253  		NextStateName: next.Name(),
   254  		ConnRecord:    connRecord,
   255  	}
   256  
   257  	go func(msg *message, aEvent chan<- service.DIDCommAction) {
   258  		if err = s.handle(msg, aEvent); err != nil {
   259  			logutil.LogError(logger, LegacyConnection, "processMessage", err.Error(),
   260  				logutil.CreateKeyValueString("msgType", msg.Msg.Type()),
   261  				logutil.CreateKeyValueString("msgID", msg.Msg.ID()),
   262  				logutil.CreateKeyValueString("connectionID", msg.ConnRecord.ConnectionID))
   263  		}
   264  
   265  		logutil.LogDebug(logger, LegacyConnection, "processMessage", "success",
   266  			logutil.CreateKeyValueString("msgType", msg.Msg.Type()),
   267  			logutil.CreateKeyValueString("msgID", msg.Msg.ID()),
   268  			logutil.CreateKeyValueString("connectionID", msg.ConnRecord.ConnectionID))
   269  	}(internalMsg, s.ActionEvent())
   270  
   271  	logutil.LogDebug(logger, LegacyConnection, "handleInbound", "success",
   272  		logutil.CreateKeyValueString("msgType", msg.Type()),
   273  		logutil.CreateKeyValueString("msgID", msg.ID()),
   274  		logutil.CreateKeyValueString("connectionID", internalMsg.ConnRecord.ConnectionID))
   275  
   276  	return connRecord.ConnectionID, nil
   277  }
   278  
// Name returns the service name, which is the LegacyConnection protocol name constant.
func (s *Service) Name() string {
	return LegacyConnection
}
   283  
   284  func findNamespace(msgType string) string {
   285  	namespace := theirNSPrefix
   286  	if msgType == InvitationMsgType || msgType == ResponseMsgType {
   287  		namespace = myNSPrefix
   288  	}
   289  
   290  	return namespace
   291  }
   292  
   293  // Accept msg checks the msg type.
   294  func (s *Service) Accept(msgType string) bool {
   295  	return msgType == InvitationMsgType ||
   296  		msgType == RequestMsgType ||
   297  		msgType == ResponseMsgType ||
   298  		msgType == AckMsgType
   299  }
   300  
   301  // HandleOutbound handles outbound connection messages.
   302  func (s *Service) HandleOutbound(_ service.DIDCommMsg, _, _ string) (string, error) {
   303  	return "", errors.New("not implemented")
   304  }
   305  
   306  func (s *Service) nextState(msgType, thID string) (state, error) {
   307  	logger.Debugf("msgType=%s thID=%s", msgType, thID)
   308  
   309  	nsThID, err := connection.CreateNamespaceKey(findNamespace(msgType), thID)
   310  	if err != nil {
   311  		return nil, err
   312  	}
   313  
   314  	current, err := s.currentState(nsThID)
   315  	if err != nil {
   316  		return nil, err
   317  	}
   318  
   319  	logger.Debugf("retrieved current state [%s] using nsThID [%s]", current.Name(), nsThID)
   320  
   321  	next, err := stateFromMsgType(msgType)
   322  	if err != nil {
   323  		return nil, err
   324  	}
   325  
   326  	logger.Debugf("check if current state [%s] can transition to [%s]", current.Name(), next.Name())
   327  
   328  	if !current.CanTransitionTo(next) {
   329  		return nil, fmt.Errorf("invalid state transition: %s -> %s", current.Name(), next.Name())
   330  	}
   331  
   332  	return next, nil
   333  }
   334  
// handle runs the state machine for msg, starting at the state named by msg.NextStateName.
//
// For each state it emits a PreState message event, executes the state, stores the resulting
// connection record, runs the state's action, and emits a PostState message event. The loop
// ends when the follow-up state is a no-op, or right after an action event has been dispatched
// (the remaining states then run later via the callback or the accept APIs).
// aEvent may be nil, in which case no action event is sent on a channel (see sendActionEvent).
// The first failure is returned wrapped with the name of the state being processed.
func (s *Service) handle(msg *message, aEvent chan<- service.DIDCommAction) error { //nolint:funlen,gocyclo
	logger.Debugf("handling msg: %+v", msg)

	next, err := stateFromName(msg.NextStateName)
	if err != nil {
		return fmt.Errorf("invalid state name: %w", err)
	}

	for !isNoOp(next) {
		// Pre-state event carries the connection/invitation IDs of the record attached to msg.
		s.sendMsgEvents(&service.StateMsg{
			ProtocolName: LegacyConnection,
			Type:         service.PreState,
			Msg:          msg.Msg.Clone(),
			StateID:      next.Name(),
			Properties:   createEventProperties(msg.ConnRecord.ConnectionID, msg.ConnRecord.InvitationID),
		})
		logger.Debugf("sent pre event for state %s", next.Name())

		var (
			action           stateAction
			followup         state
			connectionRecord *connection.Record
		)

		// Executing the state yields the updated record, the state to run next, and a deferred action.
		connectionRecord, followup, action, err = next.ExecuteInbound(
			&stateMachineMsg{
				DIDCommMsg: msg.Msg,
				connRecord: msg.ConnRecord,
				options:    msg.Options,
			},
			msg.ThreadID,
			s.ctx)

		if err != nil {
			return fmt.Errorf("failed to execute state '%s': %w", next.Name(), err)
		}

		connectionRecord.State = next.Name()
		logger.Debugf("finished execute state: %s", next.Name())

		// The record is persisted before the action runs.
		if err = s.update(msg.Msg.Type(), connectionRecord); err != nil {
			return fmt.Errorf("failed to persist state '%s': %w", next.Name(), err)
		}

		// On completion, their DID is additionally saved in the DID connection store.
		if connectionRecord.State == StateIDCompleted {
			err = s.connectionStore.SaveDIDByResolving(connectionRecord.TheirDID, connectionRecord.RecipientKeys...)
			if err != nil {
				return fmt.Errorf("save theirDID: %w", err)
			}
		}

		if err = action(); err != nil {
			return fmt.Errorf("failed to execute state action '%s': %w", next.Name(), err)
		}

		logger.Debugf("finish execute state action: '%s'", next.Name())

		// prev is the state just executed; next becomes the follow-up for the following iteration.
		prev := next
		next = followup
		haltExecution := false

		// trigger action event based on message type for inbound messages
		if canTriggerActionEvents(connectionRecord.State, connectionRecord.Namespace) {
			logger.Debugf("action event triggered for msg type: %s", msg.Msg.Type())

			// Record where processing resumes before the message is stored by sendActionEvent.
			msg.NextStateName = next.Name()
			if err = s.sendActionEvent(msg, aEvent); err != nil {
				return fmt.Errorf("handle inbound: %w", err)
			}

			haltExecution = true
		}

		// The post-state event is sent even when execution is about to halt.
		s.sendMsgEvents(&service.StateMsg{
			ProtocolName: LegacyConnection,
			Type:         service.PostState,
			Msg:          msg.Msg.Clone(),
			StateID:      prev.Name(),
			Properties:   createEventProperties(connectionRecord.ConnectionID, connectionRecord.InvitationID),
		})
		logger.Debugf("sent post event for state %s", prev.Name())

		if haltExecution {
			logger.Debugf("halted execution before state=%s", msg.NextStateName)

			break
		}
	}

	return nil
}
   426  
   427  func (s *Service) handleWithoutAction(msg *message) error {
   428  	return s.handle(msg, nil)
   429  }
   430  
   431  func createEventProperties(connectionID, invitationID string) *connectionEvent {
   432  	return &connectionEvent{
   433  		connectionID: connectionID,
   434  		invitationID: invitationID,
   435  	}
   436  }
   437  
   438  // sendActionEvent triggers the action event. This function stores the state of current processing and passes a callback
   439  // function in the event message.
   440  func (s *Service) sendActionEvent(internalMsg *message, aEvent chan<- service.DIDCommAction) error {
   441  	// save data to support AcceptConnectionRequest APIs (when client will not be able to invoke the callback function)
   442  	err := s.storeEventProtocolStateData(internalMsg)
   443  	if err != nil {
   444  		return fmt.Errorf("send action event : %w", err)
   445  	}
   446  
   447  	if aEvent != nil {
   448  		// trigger action event
   449  		aEvent <- service.DIDCommAction{
   450  			ProtocolName: LegacyConnection,
   451  			Message:      internalMsg.Msg.Clone(),
   452  			Continue: func(args interface{}) {
   453  				switch v := args.(type) {
   454  				case opts:
   455  					internalMsg.Options = &options{
   456  						publicDID:         v.PublicDID(),
   457  						label:             v.Label(),
   458  						routerConnections: v.RouterConnections(),
   459  					}
   460  				default:
   461  					// nothing to do
   462  				}
   463  
   464  				s.processCallback(internalMsg)
   465  			},
   466  			Stop: func(err error) {
   467  				// sets an error to the message
   468  				internalMsg.err = err
   469  				s.processCallback(internalMsg)
   470  			},
   471  			Properties: createEventProperties(internalMsg.ConnRecord.ConnectionID, internalMsg.ConnRecord.InvitationID),
   472  		}
   473  
   474  		logger.Debugf("dispatched action for msg: %+v", internalMsg.Msg)
   475  	}
   476  
   477  	return nil
   478  }
   479  
   480  // sendEvent triggers the message events.
   481  func (s *Service) sendMsgEvents(msg *service.StateMsg) {
   482  	// trigger the message events
   483  	for _, handler := range s.MsgEvents() {
   484  		handler <- *msg
   485  
   486  		logger.Debugf("sent msg event to handler: %+v", msg)
   487  	}
   488  }
   489  
   490  // startInternalListener listens to messages in gochannel for callback messages from clients.
   491  func (s *Service) startInternalListener() {
   492  	for msg := range s.callbackChannel {
   493  		// TODO https://github.com/hyperledger/aries-framework-go/issues/242 - retry logic
   494  		// if no error - do handle
   495  		if msg.err == nil {
   496  			msg.err = s.handleWithoutAction(msg)
   497  		}
   498  
   499  		// no error - continue
   500  		if msg.err == nil {
   501  			continue
   502  		}
   503  	}
   504  }
   505  
   506  // AcceptInvitation accepts/approves connection invitation.
   507  func (s *Service) AcceptInvitation(connectionID, publicDID, label string, routerConnections []string) error {
   508  	return s.accept(connectionID, publicDID, label, StateIDInvited,
   509  		"accept connection invitation", routerConnections)
   510  }
   511  
   512  // AcceptConnectionRequest accepts/approves connection request.
   513  func (s *Service) AcceptConnectionRequest(connectionID, publicDID, label string, routerConnections []string) error {
   514  	return s.accept(connectionID, publicDID, label, StateIDRequested,
   515  		"accept exchange request", routerConnections)
   516  }
   517  
   518  func (s *Service) accept(connectionID, publicDID, label, stateID, errMsg string, routerConnections []string) error {
   519  	msg, err := s.getEventProtocolStateData(connectionID)
   520  	if err != nil {
   521  		return fmt.Errorf("failed to accept invitation for connectionID=%s : %s : %w", connectionID, errMsg, err)
   522  	}
   523  
   524  	connRecord, err := s.connectionRecorder.GetConnectionRecord(connectionID)
   525  	if err != nil {
   526  		return fmt.Errorf("%s : %w", errMsg, err)
   527  	}
   528  
   529  	if connRecord.State != stateID {
   530  		return fmt.Errorf("current state (%s) is different from "+
   531  			"expected state (%s)", connRecord.State, stateID)
   532  	}
   533  
   534  	msg.Options = &options{publicDID: publicDID, label: label, routerConnections: routerConnections}
   535  
   536  	return s.handleWithoutAction(msg)
   537  }
   538  
   539  func (s *Service) storeEventProtocolStateData(msg *message) error {
   540  	bytes, err := json.Marshal(msg)
   541  	if err != nil {
   542  		return fmt.Errorf("store protocol state data : %w", err)
   543  	}
   544  
   545  	return s.connectionRecorder.SaveEvent(msg.ConnRecord.ConnectionID, bytes)
   546  }
   547  
   548  func (s *Service) getEventProtocolStateData(connectionID string) (*message, error) {
   549  	val, err := s.connectionRecorder.GetEvent(connectionID)
   550  	if err != nil {
   551  		return nil, fmt.Errorf("get protocol state data : %w", err)
   552  	}
   553  
   554  	msg := &message{}
   555  
   556  	err = json.Unmarshal(val, msg)
   557  	if err != nil {
   558  		return nil, fmt.Errorf("get protocol state data : %w", err)
   559  	}
   560  
   561  	return msg, nil
   562  }
   563  
// processCallback queues msg on the internal callback channel, which is drained by
// startInternalListener. It is invoked from the Continue and Stop callbacks of an action event.
func (s *Service) processCallback(msg *message) {
	// pass the callback data to internal channel. This is created to unblock consumer go routine and wrap the callback
	// channel internally.
	s.callbackChannel <- msg
}
   569  
   570  func isNoOp(s state) bool {
   571  	_, ok := s.(*noOp)
   572  	return ok
   573  }
   574  
   575  func (s *Service) currentState(nsThID string) (state, error) {
   576  	connRec, err := s.connectionRecorder.GetConnectionRecordByNSThreadID(nsThID)
   577  	if err != nil {
   578  		if errors.Is(err, storage.ErrDataNotFound) {
   579  			return &null{}, nil
   580  		}
   581  
   582  		return nil, fmt.Errorf("cannot fetch state from store: thID=%s err=%w", nsThID, err)
   583  	}
   584  
   585  	return stateFromName(connRec.State)
   586  }
   587  
   588  func (s *Service) update(msgType string, record *connection.Record) error {
   589  	if (msgType == RequestMsgType && record.State == StateIDRequested) ||
   590  		(msgType == InvitationMsgType && record.State == StateIDInvited) {
   591  		return s.connectionRecorder.SaveConnectionRecordWithMappings(record)
   592  	}
   593  
   594  	return s.connectionRecorder.SaveConnectionRecord(record)
   595  }
   596  
   597  // CreateConnection saves the record to the connection store and maps TheirDID to their recipient keys in
   598  // the did connection store.
   599  func (s *Service) CreateConnection(record *connection.Record, theirDID *did.Doc) error {
   600  	logger.Debugf("creating connection using record [%+v] and theirDID [%+v]", record, theirDID)
   601  
   602  	didMethod, err := vdr.GetDidMethod(theirDID.ID)
   603  	if err != nil {
   604  		return err
   605  	}
   606  
   607  	_, err = s.ctx.vdRegistry.Create(didMethod, theirDID, vdrapi.WithOption("store", true))
   608  	if err != nil {
   609  		return fmt.Errorf("vdr failed to store theirDID : %w", err)
   610  	}
   611  
   612  	err = s.connectionStore.SaveDIDFromDoc(theirDID)
   613  	if err != nil {
   614  		return fmt.Errorf("failed to save theirDID to the did.ConnectionStore: %w", err)
   615  	}
   616  
   617  	err = s.connectionStore.SaveDIDByResolving(record.MyDID)
   618  	if err != nil {
   619  		return fmt.Errorf("failed to save myDID to the did.ConnectionStore: %w", err)
   620  	}
   621  
   622  	if isDIDCommV2(record.MediaTypeProfiles) {
   623  		record.DIDCommVersion = service.V2
   624  	} else {
   625  		record.DIDCommVersion = service.V1
   626  	}
   627  
   628  	return s.connectionRecorder.SaveConnectionRecord(record)
   629  }
   630  
   631  func (s *Service) connectionRecord(msg service.DIDCommMsg, ctx service.DIDCommContext) (*connection.Record, error) {
   632  	switch msg.Type() {
   633  	case InvitationMsgType:
   634  		return s.invitationMsgRecord(msg)
   635  	case RequestMsgType:
   636  		return s.requestMsgRecord(msg, ctx)
   637  	case ResponseMsgType:
   638  		return s.responseMsgRecord(msg)
   639  	case AckMsgType:
   640  		return s.fetchConnectionRecord(theirNSPrefix, msg)
   641  	}
   642  
   643  	return nil, errors.New("invalid message type")
   644  }
   645  
   646  func (s *Service) invitationMsgRecord(msg service.DIDCommMsg) (*connection.Record, error) {
   647  	thID, msgErr := msg.ThreadID()
   648  	if msgErr != nil {
   649  		return nil, msgErr
   650  	}
   651  
   652  	invitation := &Invitation{}
   653  
   654  	err := msg.Decode(invitation)
   655  	if err != nil {
   656  		return nil, err
   657  	}
   658  
   659  	recKey, err := s.ctx.getInvitationRecipientKey(invitation)
   660  	if err != nil {
   661  		return nil, err
   662  	}
   663  
   664  	connRecord := &connection.Record{
   665  		ConnectionID:    generateRandomID(),
   666  		ThreadID:        thID,
   667  		State:           stateNameNull,
   668  		InvitationID:    invitation.ID,
   669  		InvitationDID:   invitation.DID,
   670  		ServiceEndPoint: model.NewDIDCommV1Endpoint(invitation.ServiceEndpoint),
   671  		RecipientKeys:   []string{recKey},
   672  		TheirLabel:      invitation.Label,
   673  		Namespace:       findNamespace(msg.Type()),
   674  		DIDCommVersion:  service.V1,
   675  	}
   676  
   677  	if err := s.connectionRecorder.SaveConnectionRecord(connRecord); err != nil {
   678  		return nil, err
   679  	}
   680  
   681  	return connRecord, nil
   682  }
   683  
   684  func (s *Service) requestMsgRecord(msg service.DIDCommMsg, ctx service.DIDCommContext) (*connection.Record, error) {
   685  	var (
   686  		request          Request
   687  		invitationRecKey []string
   688  	)
   689  
   690  	err := msg.Decode(&request)
   691  	if err != nil {
   692  		return nil, fmt.Errorf("unmarshalling failed: %w", err)
   693  	}
   694  
   695  	invitationID := msg.ParentThreadID()
   696  	if invitationID == "" {
   697  		// try to retrieve key which was assigned at HandleInboundEnvelope method of inbound handler
   698  		if verKey, ok := ctx.All()[InvitationRecipientKey].(string); ok && verKey != "" {
   699  			invitationRecKey = append(invitationRecKey, verKey)
   700  		} else {
   701  			return nil, fmt.Errorf("missing parent thread ID and invitation recipient key"+
   702  				" on connection request with @id=%s", request.ID)
   703  		}
   704  	}
   705  
   706  	if request.Connection == nil {
   707  		return nil, fmt.Errorf("missing connection field on connection request with @id=%s", request.ID)
   708  	}
   709  
   710  	connRecord := &connection.Record{
   711  		TheirLabel:              request.Label,
   712  		ConnectionID:            generateRandomID(),
   713  		ThreadID:                request.ID,
   714  		State:                   stateNameNull,
   715  		TheirDID:                request.Connection.DID,
   716  		InvitationID:            invitationID,
   717  		InvitationRecipientKeys: invitationRecKey,
   718  		Namespace:               theirNSPrefix,
   719  		DIDCommVersion:          service.V1,
   720  	}
   721  
   722  	if err := s.connectionRecorder.SaveConnectionRecord(connRecord); err != nil {
   723  		return nil, err
   724  	}
   725  
   726  	return connRecord, nil
   727  }
   728  
   729  func (s *Service) responseMsgRecord(payload service.DIDCommMsg) (*connection.Record, error) {
   730  	return s.fetchConnectionRecord(myNSPrefix, payload)
   731  }
   732  
   733  func (s *Service) fetchConnectionRecord(nsPrefix string, payload service.DIDCommMsg) (*connection.Record, error) {
   734  	msg := &struct {
   735  		Thread decorator.Thread `json:"~thread,omitempty"`
   736  	}{}
   737  
   738  	err := payload.Decode(msg)
   739  	if err != nil {
   740  		return nil, err
   741  	}
   742  
   743  	key, err := connection.CreateNamespaceKey(nsPrefix, msg.Thread.ID)
   744  	if err != nil {
   745  		return nil, err
   746  	}
   747  
   748  	return s.connectionRecorder.GetConnectionRecordByNSThreadID(key)
   749  }
   750  
// generateRandomID returns the string form of a newly generated UUID; it is used for
// connection IDs and for the thread ID of implicit invitations.
func generateRandomID() string {
	return uuid.New().String()
}
   754  
   755  // canTriggerActionEvents true based on role and state.
   756  // 1. Role is invitee and state is invited.
   757  // 2. Role is inviter and state is requested.
   758  func canTriggerActionEvents(stateID, ns string) bool {
   759  	return (stateID == StateIDInvited && ns == myNSPrefix) || (stateID == StateIDRequested && ns == theirNSPrefix)
   760  }
   761  
   762  // CreateImplicitInvitation creates implicit invitation. Inviter DID is required, invitee DID is optional.
   763  // If invitee DID is not provided new peer DID will be created for implicit invitation connection request.
   764  //nolint:funlen
   765  func (s *Service) CreateImplicitInvitation(inviterLabel, inviterDID,
   766  	inviteeLabel, inviteeDID string, routerConnections []string) (string, error) {
   767  	logger.Debugf("implicit invitation requested inviterDID[%s] inviteeDID[%s]", inviterDID, inviteeDID)
   768  
   769  	docResolution, err := s.ctx.vdRegistry.Resolve(inviterDID)
   770  	if err != nil {
   771  		return "", fmt.Errorf("resolve public did[%s]: %w", inviterDID, err)
   772  	}
   773  
   774  	dest, err := service.CreateDestination(docResolution.DIDDocument)
   775  	if err != nil {
   776  		return "", err
   777  	}
   778  
   779  	thID := generateRandomID()
   780  
   781  	var connRecord *connection.Record
   782  
   783  	if accept, e := dest.ServiceEndpoint.Accept(); e == nil && isDIDCommV2(accept) {
   784  		connRecord = &connection.Record{
   785  			ConnectionID:    generateRandomID(),
   786  			ThreadID:        thID,
   787  			State:           stateNameNull,
   788  			InvitationDID:   inviterDID,
   789  			Implicit:        true,
   790  			ServiceEndPoint: dest.ServiceEndpoint,
   791  			RecipientKeys:   dest.RecipientKeys,
   792  			TheirLabel:      inviterLabel,
   793  			Namespace:       findNamespace(InvitationMsgType),
   794  		}
   795  	} else {
   796  		connRecord = &connection.Record{
   797  			ConnectionID:      generateRandomID(),
   798  			ThreadID:          thID,
   799  			State:             stateNameNull,
   800  			InvitationDID:     inviterDID,
   801  			Implicit:          true,
   802  			ServiceEndPoint:   dest.ServiceEndpoint,
   803  			RecipientKeys:     dest.RecipientKeys,
   804  			RoutingKeys:       dest.RoutingKeys,
   805  			MediaTypeProfiles: dest.MediaTypeProfiles,
   806  			TheirLabel:        inviterLabel,
   807  			Namespace:         findNamespace(InvitationMsgType),
   808  		}
   809  	}
   810  
   811  	if e := s.connectionRecorder.SaveConnectionRecordWithMappings(connRecord); e != nil {
   812  		return "", fmt.Errorf("failed to save new connection record for implicit invitation: %w", e)
   813  	}
   814  
   815  	invitation := &Invitation{
   816  		ID:    uuid.New().String(),
   817  		Label: inviterLabel,
   818  		DID:   inviterDID,
   819  		Type:  InvitationMsgType,
   820  	}
   821  
   822  	msg, err := createDIDCommMsg(invitation)
   823  	if err != nil {
   824  		return "", fmt.Errorf("failed to create DIDCommMsg for implicit invitation: %w", err)
   825  	}
   826  
   827  	next := &requested{}
   828  	internalMsg := &message{
   829  		Msg:           msg.Clone(),
   830  		ThreadID:      thID,
   831  		NextStateName: next.Name(),
   832  		ConnRecord:    connRecord,
   833  	}
   834  	internalMsg.Options = &options{publicDID: inviteeDID, label: inviteeLabel, routerConnections: routerConnections}
   835  
   836  	go func(msg *message, aEvent chan<- service.DIDCommAction) {
   837  		if err = s.handle(msg, aEvent); err != nil {
   838  			logger.Errorf("error from handle for implicit invitation: %s", err)
   839  		}
   840  	}(internalMsg, s.ActionEvent())
   841  
   842  	return connRecord.ConnectionID, nil
   843  }
   844  
   845  func createDIDCommMsg(invitation *Invitation) (service.DIDCommMsg, error) {
   846  	payload, err := json.Marshal(invitation)
   847  	if err != nil {
   848  		return nil, fmt.Errorf("marshal invitation: %w", err)
   849  	}
   850  
   851  	return service.ParseDIDCommMsgMap(payload)
   852  }