github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/protocol/presentproof/service.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package presentproof
     8  
     9  import (
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  	"strings"
    14  
    15  	"github.com/google/uuid"
    16  
    17  	"github.com/hyperledger/aries-framework-go/pkg/common/log"
    18  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    19  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator"
    20  	"github.com/hyperledger/aries-framework-go/pkg/doc/verifiable"
    21  	"github.com/hyperledger/aries-framework-go/spi/storage"
    22  )
    23  
// Protocol identifiers. Every message type below is formed by appending the
// message name to its spec URI: the *V2 types use SpecV2 (present-proof 2.0)
// and the *V3 types use SpecV3 (present-proof 3.0).
const (
	// Name defines the protocol name.
	Name = "present-proof"
	// SpecV2 defines the protocol spec.
	SpecV2 = "https://didcomm.org/present-proof/2.0/"
	// ProposePresentationMsgTypeV2 defines the protocol propose-presentation message type.
	ProposePresentationMsgTypeV2 = SpecV2 + "propose-presentation"
	// RequestPresentationMsgTypeV2 defines the protocol request-presentation message type.
	RequestPresentationMsgTypeV2 = SpecV2 + "request-presentation"
	// PresentationMsgTypeV2 defines the protocol presentation message type.
	PresentationMsgTypeV2 = SpecV2 + "presentation"
	// AckMsgTypeV2 defines the protocol ack message type.
	AckMsgTypeV2 = SpecV2 + "ack"
	// ProblemReportMsgTypeV2 defines the protocol problem-report message type.
	ProblemReportMsgTypeV2 = SpecV2 + "problem-report"
	// PresentationPreviewMsgTypeV2 defines the protocol presentation-preview inner object type.
	PresentationPreviewMsgTypeV2 = SpecV2 + "presentation-preview"

	// SpecV3 defines the protocol spec.
	SpecV3 = "https://didcomm.org/present-proof/3.0/"
	// ProposePresentationMsgTypeV3 defines the protocol propose-presentation message type.
	ProposePresentationMsgTypeV3 = SpecV3 + "propose-presentation"
	// RequestPresentationMsgTypeV3 defines the protocol request-presentation message type.
	RequestPresentationMsgTypeV3 = SpecV3 + "request-presentation"
	// PresentationMsgTypeV3 defines the protocol presentation message type.
	PresentationMsgTypeV3 = SpecV3 + "presentation"
	// AckMsgTypeV3 defines the protocol ack message type.
	AckMsgTypeV3 = SpecV3 + "ack"
	// ProblemReportMsgTypeV3 defines the protocol problem-report message type.
	ProblemReportMsgTypeV3 = SpecV3 + "problem-report"
	// PresentationPreviewMsgTypeV3 defines the protocol presentation-preview inner object type.
	PresentationPreviewMsgTypeV3 = SpecV3 + "presentation-preview"
)
    57  
// Keys used by the service when talking to its store.
const (
	// internalDataKey is prefixed to a PIID to form the key under which internalData is saved.
	internalDataKey = "internal_data_"
	// transitionalPayloadKey is a fmt format string (taking the PIID) for transitional payload
	// keys. The same literal is also used as the tag name that Actions queries on.
	transitionalPayloadKey = "transitionalPayload_%s"
)
    62  
// version identifies the present-proof protocol version of a protocol instance.
// getCurrentInternalDataAndPIID selects version3 for DIDComm V2 messages and
// version2 otherwise.
type version string

const (
	// version2 is present-proof 2.0.
	version2 = version("present-proof V2")
	// version3 is present-proof 3.0.
	version3 = version("present-proof V3")
)
    69  
// nolint:gochecknoglobals
var (
	// logger is the package-level logger for this service.
	logger = log.New("aries-framework/presentproof/service")
	// initialHandler is the innermost middleware handler: it does nothing and always
	// returns nil. Use wraps the supplied middlewares around it.
	initialHandler = HandlerFunc(func(_ Metadata) error {
		return nil
	})
	// errProtocolStopped is recorded when ActionStop or an action's Stop is given a nil error.
	errProtocolStopped = errors.New("protocol was stopped")
)
    78  
// customError is a wrapper to determine custom error against internal error.
// Errors supplied by the user through ActionStop or an action's Stop callback are
// wrapped in it before being stored in metaData.err.
type customError struct{ error }
    81  
// transitionalPayload keeps payload needed for Continue function to proceed with the action.
//
// It is JSON-encoded into the store while an inbound message waits for a Continue/Stop
// decision. Because Action is embedded, its fields are promoted to the top level of the
// JSON object, which is what lets Actions decode stored payloads directly into Action.
type transitionalPayload struct {
	Action
	// StateName is the name of the state to resume from.
	StateName string
	// AckRequired is carried over from the persisted internalData.
	AckRequired bool
	// Direction records whether the message was inbound or outbound.
	Direction messageDirection
	// ProtocolVersion is the present-proof version of the protocol instance.
	ProtocolVersion version
	// Properties holds the event properties of the target state.
	Properties map[string]interface{}
}
    91  
// messageDirection tells nextState whether a message was received or is being sent.
type messageDirection string

const (
	// inboundMessage marks messages handled by HandleInbound.
	inboundMessage = messageDirection("InboundMessage")
	// outboundMessage marks messages handled by HandleOutbound.
	outboundMessage = messageDirection("OutboundMessage")
)
    98  
// metaData type to store data for internal usage.
//
// A metaData travels through the state machine and the middleware chain. Only the
// embedded transitionalPayload is persisted (by saveTransitionalPayload); every other
// field lives in memory only.
type metaData struct {
	transitionalPayload
	// state is the state currently being executed (assigned by execute).
	state state
	// presentationNames are names supplied through WithFriendlyNames.
	presentationNames []string
	// properties are the event properties returned by Properties().
	properties map[string]interface{}
	// msgClone is a clone of the message being processed.
	msgClone service.DIDCommMsg
	// presentation, proposePresentation and request hold present-proof 2.0 messages
	// supplied through options; the *V3 fields hold their 3.0 counterparts.
	presentation          *PresentationV2
	proposePresentation   *ProposePresentationV2
	request               *RequestPresentationV2
	presentationV3        *PresentationV3
	proposePresentationV3 *ProposePresentationV3
	requestV3             *RequestPresentationV3

	// addProofFn is the signing callback supplied through WithAddProofFn.
	addProofFn func(presentation *verifiable.Presentation) error
	// err is used to determine whether callback was stopped
	// e.g the user received an action event and executes Stop(err) function
	// in that case `err` is equal to `err` which was passing to Stop function
	err error
}
   119  
// Message returns the clone of the message being processed.
func (md *metaData) Message() service.DIDCommMsg {
	return md.msgClone
}

// Presentation returns the present-proof 2.0 presentation supplied through options, if any.
func (md *metaData) Presentation() *PresentationV2 {
	return md.presentation
}

// PresentationV3 returns the present-proof 3.0 presentation supplied through options, if any.
func (md *metaData) PresentationV3() *PresentationV3 {
	return md.presentationV3
}

// ProposePresentation returns the present-proof 2.0 proposal supplied through options, if any.
func (md *metaData) ProposePresentation() *ProposePresentationV2 {
	return md.proposePresentation
}

// ProposePresentationV3 returns the present-proof 3.0 proposal supplied through options, if any.
func (md *metaData) ProposePresentationV3() *ProposePresentationV3 {
	return md.proposePresentationV3
}

// RequestPresentation returns the present-proof 2.0 request supplied through options, if any.
func (md *metaData) RequestPresentation() *RequestPresentationV2 {
	return md.request
}

// RequestPresentationV3 returns the present-proof 3.0 request supplied through options, if any.
func (md *metaData) RequestPresentationV3() *RequestPresentationV3 {
	return md.requestV3
}

// PresentationNames returns the names supplied through WithFriendlyNames.
func (md *metaData) PresentationNames() []string {
	return md.presentationNames
}

// StateName returns the name of the state currently assigned to md.
func (md *metaData) StateName() string {
	return md.state.Name()
}

// Properties returns the in-memory event properties.
func (md *metaData) Properties() map[string]interface{} {
	return md.properties
}

// GetAddProofFn returns the signing callback supplied through WithAddProofFn (nil if none).
func (md *metaData) GetAddProofFn() func(presentation *verifiable.Presentation) error {
	return md.addProofFn
}
   163  
// Action contains helpful information about action.
// Actions returns a list of these for inbound messages still awaiting Continue/Stop.
type Action struct {
	// Protocol instance ID
	PIID string
	// Msg is the message that triggered the action.
	Msg service.DIDCommMsgMap
	// MyDID is our DID for the connection the message arrived on.
	MyDID string
	// TheirDID is the other party's DID for that connection.
	TheirDID string
}
   172  
// Opt describes option signature for the Continue function.
// An Opt mutates the metadata of a pending action before processing resumes.
type Opt func(md *metaData)
   175  
   176  // WithPresentation allows providing Presentation message
   177  // USAGE: This message can be provided after receiving an Invitation message.
   178  func WithPresentation(pp *PresentationParams) Opt {
   179  	return func(md *metaData) {
   180  		switch md.ProtocolVersion {
   181  		default:
   182  			fallthrough
   183  		case version2:
   184  			md.presentation = &PresentationV2{
   185  				Type:                PresentationMsgTypeV2,
   186  				Comment:             pp.Comment,
   187  				Formats:             pp.Formats,
   188  				PresentationsAttach: decorator.GenericAttachmentsToV1(pp.Attachments),
   189  			}
   190  		case version3:
   191  			md.presentationV3 = &PresentationV3{
   192  				Type: PresentationMsgTypeV3,
   193  				Body: PresentationV3Body{
   194  					GoalCode: pp.GoalCode,
   195  					Comment:  pp.Comment,
   196  				},
   197  				Attachments: decorator.GenericAttachmentsToV2(pp.Attachments),
   198  			}
   199  		}
   200  	}
   201  }
   202  
// WithAddProofFn allows providing function that will sign the Presentation.
// USAGE: This fn can be provided after receiving an Invitation message.
// The function is stored on the metadata and exposed through GetAddProofFn.
func WithAddProofFn(addProof func(presentation *verifiable.Presentation) error) Opt {
	return func(md *metaData) {
		md.addProofFn = addProof
	}
}
   210  
   211  // WithMultiOptions allows combining several options into one.
   212  func WithMultiOptions(opts ...Opt) Opt {
   213  	return func(md *metaData) {
   214  		for _, opt := range opts {
   215  			opt(md)
   216  		}
   217  	}
   218  }
   219  
   220  // WithProposePresentation allows providing ProposePresentation message
   221  // USAGE: This message can be provided after receiving an Invitation message.
   222  func WithProposePresentation(pp *ProposePresentationParams) Opt {
   223  	return func(md *metaData) {
   224  		switch md.ProtocolVersion {
   225  		default:
   226  			fallthrough
   227  		case version2:
   228  			md.proposePresentation = &ProposePresentationV2{
   229  				Type:            ProposePresentationMsgTypeV2,
   230  				Comment:         pp.Comment,
   231  				Formats:         pp.Formats,
   232  				ProposalsAttach: decorator.GenericAttachmentsToV1(pp.Attachments),
   233  			}
   234  		case version3:
   235  			md.proposePresentationV3 = &ProposePresentationV3{
   236  				Type: ProposePresentationMsgTypeV3,
   237  				Body: ProposePresentationV3Body{
   238  					GoalCode: pp.GoalCode,
   239  					Comment:  pp.Comment,
   240  				},
   241  				Attachments: decorator.GenericAttachmentsToV2(pp.Attachments),
   242  			}
   243  		}
   244  	}
   245  }
   246  
   247  // WithRequestPresentation allows providing RequestPresentation message
   248  // USAGE: This message can be provided after receiving a propose message.
   249  func WithRequestPresentation(msg *RequestPresentationParams) Opt {
   250  	return func(md *metaData) {
   251  		switch md.ProtocolVersion {
   252  		default:
   253  			fallthrough
   254  		case version2:
   255  			md.request = &RequestPresentationV2{
   256  				ID:                         uuid.New().String(),
   257  				Type:                       RequestPresentationMsgTypeV2,
   258  				Comment:                    msg.Comment,
   259  				WillConfirm:                msg.WillConfirm,
   260  				Formats:                    msg.Formats,
   261  				RequestPresentationsAttach: decorator.GenericAttachmentsToV1(msg.Attachments),
   262  			}
   263  		case version3:
   264  			md.requestV3 = &RequestPresentationV3{
   265  				ID:   uuid.New().String(),
   266  				Type: RequestPresentationMsgTypeV3,
   267  				Body: RequestPresentationV3Body{
   268  					GoalCode:    msg.GoalCode,
   269  					Comment:     msg.Comment,
   270  					WillConfirm: msg.WillConfirm,
   271  				},
   272  				Attachments: decorator.GenericAttachmentsToV2(msg.Attachments),
   273  			}
   274  		}
   275  	}
   276  }
   277  
// WithFriendlyNames allows providing names for the presentations.
// The names slice is stored as given (it is not copied) and exposed through PresentationNames.
func WithFriendlyNames(names ...string) Opt {
	return func(md *metaData) {
		md.presentationNames = names
	}
}
   284  
   285  // WithProperties allows providing custom properties.
   286  func WithProperties(props map[string]interface{}) Opt {
   287  	return func(md *metaData) {
   288  		if len(md.properties) == 0 {
   289  			md.properties = props
   290  
   291  			return
   292  		}
   293  
   294  		for k, v := range props {
   295  			md.properties[k] = v
   296  		}
   297  	}
   298  }
   299  
// Provider contains dependencies for the protocol and is typically created by using aries.Context().
type Provider interface {
	// Messenger is handed to state actions for sending protocol messages.
	Messenger() service.Messenger
	// StorageProvider is used to open and configure the protocol store.
	StorageProvider() storage.Provider
}
   305  
// Service for the presentproof protocol.
type Service struct {
	service.Action
	service.Message
	// store persists internal data and transitional payloads.
	store storage.Store
	// callbacks carries Continue/Stop results to startInternalListener.
	callbacks chan *metaData
	// messenger is passed to each state action.
	messenger service.Messenger
	// middleware wraps state execution; see Use.
	middleware Handler
	// initialized makes repeated Initialize calls no-ops.
	initialized bool
}
   316  
   317  // New returns the presentproof service.
   318  func New(p Provider) (*Service, error) {
   319  	svc := Service{}
   320  
   321  	err := svc.Initialize(p)
   322  	if err != nil {
   323  		return nil, err
   324  	}
   325  
   326  	return &svc, nil
   327  }
   328  
   329  // Initialize initializes the Service. If Initialize succeeds, any further call is a no-op.
   330  func (s *Service) Initialize(prov interface{}) error {
   331  	if s.initialized {
   332  		return nil
   333  	}
   334  
   335  	p, ok := prov.(Provider)
   336  	if !ok {
   337  		return fmt.Errorf("expected provider of type `%T`, got type `%T`", Provider(nil), p)
   338  	}
   339  
   340  	store, err := p.StorageProvider().OpenStore(Name)
   341  	if err != nil {
   342  		return err
   343  	}
   344  
   345  	err = p.StorageProvider().SetStoreConfig(Name, storage.StoreConfiguration{TagNames: []string{transitionalPayloadKey}})
   346  	if err != nil {
   347  		return fmt.Errorf("failed to set store configuration: %w", err)
   348  	}
   349  
   350  	s.messenger = p.Messenger()
   351  	s.store = store
   352  	s.callbacks = make(chan *metaData)
   353  	s.middleware = initialHandler
   354  
   355  	// start the listener
   356  	go s.startInternalListener()
   357  
   358  	s.initialized = true
   359  
   360  	return nil
   361  }
   362  
   363  // Use allows providing middlewares.
   364  func (s *Service) Use(items ...Middleware) {
   365  	var handler Handler = initialHandler
   366  	for i := len(items) - 1; i >= 0; i-- {
   367  		handler = items[i](handler)
   368  	}
   369  
   370  	s.middleware = handler
   371  }
   372  
   373  // HandleInbound handles inbound message (presentproof protocol).
   374  func (s *Service) HandleInbound(msg service.DIDCommMsg, ctx service.DIDCommContext) (string, error) {
   375  	logger.Debugf("service.HandleInbound() input: msg=%+v myDID=%s theirDID=%s", msg, ctx.MyDID(), ctx.TheirDID())
   376  
   377  	msgMap := msg.Clone()
   378  
   379  	aEvent := s.ActionEvent()
   380  
   381  	if aEvent == nil {
   382  		// throw error if there is no action event registered for inbound messages
   383  		return "", errors.New("no clients are registered to handle the message")
   384  	}
   385  
   386  	md, err := s.buildMetaData(msgMap, inboundMessage)
   387  	if err != nil {
   388  		return "", fmt.Errorf("buildMetaData: %w", err)
   389  	}
   390  
   391  	md.MyDID = ctx.MyDID()
   392  	md.TheirDID = ctx.TheirDID()
   393  
   394  	// trigger action event based on message type for inbound messages
   395  	if canTriggerActionEvents(msgMap) {
   396  		err = s.saveTransitionalPayload(md.PIID, &(md.transitionalPayload))
   397  		if err != nil {
   398  			return "", fmt.Errorf("save transitional payload: %w", err)
   399  		}
   400  		aEvent <- s.newDIDCommActionMsg(md)
   401  
   402  		return "", nil
   403  	}
   404  
   405  	thid, err := msgMap.ThreadID()
   406  	if err != nil {
   407  		return "", fmt.Errorf("failed to obtain the message's threadID : %w", err)
   408  	}
   409  
   410  	// if no action event is triggered, continue the execution
   411  	return thid, s.handle(md)
   412  }
   413  
   414  // HandleOutbound handles outbound message (presentproof protocol).
   415  func (s *Service) HandleOutbound(msg service.DIDCommMsg, myDID, theirDID string) (string, error) {
   416  	logger.Debugf("service.HandleOutbound() input: msg=%+v myDID=%s theirDID=%s", msg, myDID, theirDID)
   417  
   418  	msgMap := msg.Clone()
   419  
   420  	md, err := s.buildMetaData(msgMap, outboundMessage)
   421  	if err != nil {
   422  		return "", fmt.Errorf("buildMetaData: %w", err)
   423  	}
   424  
   425  	md.MyDID = myDID
   426  	md.TheirDID = theirDID
   427  
   428  	thid, err := msgMap.ThreadID()
   429  	if err != nil {
   430  		return "", fmt.Errorf("failed to obtain the message's threadID : %w", err)
   431  	}
   432  
   433  	// if no action event is triggered, continue the execution
   434  	return thid, s.handle(md)
   435  }
   436  
   437  func (s *Service) getCurrentInternalDataAndPIID(msg service.DIDCommMsgMap) (string, *internalData, error) {
   438  	var protocolVersion version
   439  
   440  	isV2, err := service.IsDIDCommV2(&msg)
   441  	if err != nil {
   442  		return "", nil, fmt.Errorf("checking message version: %w", err)
   443  	}
   444  
   445  	if isV2 {
   446  		protocolVersion = version3
   447  	} else {
   448  		protocolVersion = version2
   449  	}
   450  
   451  	piID, err := getPIID(msg)
   452  	if errors.Is(err, service.ErrThreadIDNotFound) {
   453  		msg.SetID(uuid.New().String(), service.WithVersion(getDIDVersion(getVersion(msg.Type()))))
   454  
   455  		return msg.ID(), &internalData{StateName: stateNameStart, ProtocolVersion: protocolVersion}, nil
   456  	}
   457  
   458  	if err != nil {
   459  		return "", nil, fmt.Errorf("piID: %w", err)
   460  	}
   461  
   462  	data, err := s.currentInternalData(piID, protocolVersion)
   463  	if err != nil {
   464  		return "", nil, fmt.Errorf("current internal data: %w", err)
   465  	}
   466  
   467  	return piID, data, nil
   468  }
   469  
   470  func (s *Service) buildMetaData(msg service.DIDCommMsgMap, direction messageDirection) (*metaData, error) {
   471  	piID, data, err := s.getCurrentInternalDataAndPIID(msg)
   472  	if err != nil {
   473  		return nil, fmt.Errorf("current internal data and PIID: %w", err)
   474  	}
   475  
   476  	current := stateFromName(data.StateName, getVersion(msg.Type()))
   477  
   478  	next, err := nextState(msg, direction)
   479  	if err != nil {
   480  		return nil, fmt.Errorf("nextState: %w", err)
   481  	}
   482  
   483  	if !current.CanTransitionTo(next) {
   484  		return nil, fmt.Errorf("invalid state transition: %s -> %s", current.Name(), next.Name())
   485  	}
   486  
   487  	return &metaData{
   488  		transitionalPayload: transitionalPayload{
   489  			StateName:   next.Name(),
   490  			AckRequired: data.AckRequired,
   491  			Action: Action{
   492  				Msg:  msg,
   493  				PIID: piID,
   494  			},
   495  			Direction:       direction,
   496  			ProtocolVersion: data.ProtocolVersion,
   497  			Properties:      next.Properties(),
   498  		},
   499  		properties: next.Properties(),
   500  		state:      next,
   501  		msgClone:   msg.Clone(),
   502  	}, nil
   503  }
   504  
// startInternalListener listens to messages in go channel for callback messages from clients.
//
// A callback without an error is driven through handle. If the callback already
// carried an error (e.g. from Stop), or handle failed, the error is logged. The
// instance is then moved to the abandoned state with codeInternalError, and that
// state is handled as well.
// NOTE(review): nothing in this file closes s.callbacks, so this goroutine appears to
// live as long as the process — confirm.
func (s *Service) startInternalListener() {
	for msg := range s.callbacks {
		// if no error do handle
		if msg.err == nil {
			msg.err = s.handle(msg)
		}

		// no error - continue
		if msg.err == nil {
			continue
		}

		logger.Errorf("failed to handle msgID=%s : %s", msg.Msg.ID(), msg.err)

		msg.state = &abandoned{V: getVersion(msg.Msg.Type()), Code: codeInternalError}

		if err := s.handle(msg); err != nil {
			logger.Errorf("listener handle: %s", err)
		}
	}
}
   527  
   528  func isNoOp(s state) bool {
   529  	_, ok := s.(*noOp)
   530  	return ok
   531  }
   532  
// handle drives the state machine for md, starting at md.state, until a noOp state is reached.
//
// For each state it:
//   1. executes the state via execute (middleware included);
//   2. verifies the returned next state is a legal transition (noOp is always allowed);
//   3. persists the current state's name, AckRequired and ProtocolVersion under md.PIID;
//   4. only then runs the state's action with the service messenger.
//
// The first failure aborts the loop and is returned wrapped.
func (s *Service) handle(md *metaData) error {
	current := md.state

	for !isNoOp(current) {
		next, action, err := s.execute(current, md)
		if err != nil {
			return fmt.Errorf("execute: %w", err)
		}

		if !isNoOp(next) && !current.CanTransitionTo(next) {
			return fmt.Errorf("invalid state transition: %s --> %s", current.Name(), next.Name())
		}

		// WARN: md.ackRequired is being modified by requestSent state
		data := &internalData{
			StateName:       current.Name(),
			AckRequired:     md.AckRequired,
			ProtocolVersion: md.ProtocolVersion,
		}

		if err := s.saveInternalData(md.PIID, data); err != nil {
			return fmt.Errorf("failed to persist state %s: %w", current.Name(), err)
		}

		// md.state was set to current by execute, so this names the current state.
		if err := action(s.messenger); err != nil {
			return fmt.Errorf("action %s: %w", md.state.Name(), err)
		}

		current = next
	}

	return nil
}
   566  
   567  func getPIID(msg service.DIDCommMsg) (string, error) {
   568  	// pthid is needed for problem-report message
   569  	pthID := msg.ParentThreadID()
   570  	if pthID != "" && (msg.Type() == ProblemReportMsgTypeV2 || msg.Type() == ProblemReportMsgTypeV3) {
   571  		return pthID, nil
   572  	}
   573  
   574  	return msg.ThreadID()
   575  }
   576  
// internalData is the per-instance state persisted (as JSON) under internalDataKey+PIID.
type internalData struct {
	// AckRequired is copied from the metadata when the state is saved.
	AckRequired bool
	// StateName is the name of the last state that was saved.
	StateName string
	// ProtocolVersion is the present-proof version of the instance.
	ProtocolVersion version
}
   582  
   583  func (s *Service) saveInternalData(piID string, data *internalData) error {
   584  	src, err := json.Marshal(data)
   585  	if err != nil {
   586  		return err
   587  	}
   588  
   589  	return s.store.Put(internalDataKey+piID, src)
   590  }
   591  
   592  func (s *Service) currentInternalData(piID string, protocolVersion version) (*internalData, error) {
   593  	src, err := s.store.Get(internalDataKey + piID)
   594  	if errors.Is(err, storage.ErrDataNotFound) {
   595  		return &internalData{StateName: stateNameStart, ProtocolVersion: protocolVersion}, nil
   596  	}
   597  
   598  	if err != nil {
   599  		return nil, err
   600  	}
   601  
   602  	var data *internalData
   603  	if err := json.Unmarshal(src, &data); err != nil {
   604  		return nil, err
   605  	}
   606  
   607  	return data, nil
   608  }
   609  
// stateFromName returns the state by given name.
// v is the spec version (SpecV2 or SpecV3) stored on the state; start ignores it.
// Unknown names yield noOp.
func stateFromName(name, v string) state {
	switch name {
	case stateNameStart:
		return &start{}
	case StateNameAbandoned:
		return &abandoned{V: v}
	case StateNameDone:
		return &done{V: v}
	case stateNameRequestSent:
		return &requestSent{V: v}
	case stateNamePresentationReceived:
		return &presentationReceived{V: v}
	case stateNameProposalReceived:
		return &proposalReceived{V: v}
	case stateNameRequestReceived:
		return &requestReceived{V: v}
	case stateNamePresentationSent:
		return &presentationSent{V: v}
	case stateNameProposalSent:
		return &proposalSent{V: v}
	default:
		return &noOp{}
	}
}
   635  
// nextState maps a message type and direction to the state the message leads to.
//
// Request and propose messages depend on direction (received vs. sent). Presentation
// messages always map to presentationReceived, whatever the direction. Problem reports
// and acks carry their web-redirect info as state properties. Any other type is an
// error. So is a request/propose message whose direction is neither inbound nor
// outbound, because it falls out of the inner switch.
func nextState(msg service.DIDCommMsgMap, direction messageDirection) (state, error) {
	switch msg.Type() {
	case RequestPresentationMsgTypeV2, RequestPresentationMsgTypeV3:
		switch direction {
		case inboundMessage:
			return &requestReceived{V: getVersion(msg.Type())}, nil
		case outboundMessage:
			return &requestSent{V: getVersion(msg.Type())}, nil
		}
	case ProposePresentationMsgTypeV2, ProposePresentationMsgTypeV3:
		switch direction {
		case inboundMessage:
			return &proposalReceived{V: getVersion(msg.Type())}, nil
		case outboundMessage:
			return &proposalSent{V: getVersion(msg.Type())}, nil
		}
	case PresentationMsgTypeV2, PresentationMsgTypeV3:
		return &presentationReceived{V: getVersion(msg.Type())}, nil
	case ProblemReportMsgTypeV2, ProblemReportMsgTypeV3:
		return &abandoned{V: getVersion(msg.Type()), properties: redirectInfo(msg)}, nil
	case AckMsgTypeV2, AckMsgTypeV3:
		return &done{V: getVersion(msg.Type()), properties: redirectInfo(msg)}, nil
	}

	return nil, fmt.Errorf("unrecognized msgType: %s", msg.Type())
}
   662  
   663  func getVersion(t string) string {
   664  	if strings.HasPrefix(t, SpecV2) {
   665  		return SpecV2
   666  	}
   667  
   668  	return SpecV3
   669  }
   670  
   671  func redirectInfo(msg service.DIDCommMsgMap) map[string]interface{} {
   672  	if redirectInfo, ok := msg[webRedirect].(map[string]interface{}); ok {
   673  		return redirectInfo
   674  	}
   675  
   676  	if redirectInfo, ok := msg[webRedirectV2].(map[string]interface{}); ok {
   677  		return redirectInfo
   678  	}
   679  
   680  	return map[string]interface{}{}
   681  }
   682  
   683  func getDIDVersion(v string) service.Version {
   684  	if v == SpecV3 {
   685  		return service.V2
   686  	}
   687  
   688  	return service.V1
   689  }
   690  
   691  func (s *Service) saveTransitionalPayload(id string, data *transitionalPayload) error {
   692  	src, err := json.Marshal(*data)
   693  	if err != nil {
   694  		return fmt.Errorf("marshal transitional payload: %w", err)
   695  	}
   696  
   697  	return s.store.Put(fmt.Sprintf(transitionalPayloadKey, id), src, storage.Tag{Name: transitionalPayloadKey})
   698  }
   699  
   700  // canTriggerActionEvents checks if the incoming message can trigger an action event.
   701  func canTriggerActionEvents(msg service.DIDCommMsg) bool {
   702  	return msg.Type() == PresentationMsgTypeV2 ||
   703  		msg.Type() == ProposePresentationMsgTypeV2 ||
   704  		msg.Type() == RequestPresentationMsgTypeV2 ||
   705  		msg.Type() == ProblemReportMsgTypeV2 ||
   706  		msg.Type() == PresentationMsgTypeV3 ||
   707  		msg.Type() == ProposePresentationMsgTypeV3 ||
   708  		msg.Type() == RequestPresentationMsgTypeV3 ||
   709  		msg.Type() == ProblemReportMsgTypeV3
   710  }
   711  
   712  func (s *Service) getTransitionalPayload(id string) (*transitionalPayload, error) {
   713  	src, err := s.store.Get(fmt.Sprintf(transitionalPayloadKey, id))
   714  	if err != nil {
   715  		return nil, fmt.Errorf("store get: %w", err)
   716  	}
   717  
   718  	t := &transitionalPayload{}
   719  
   720  	err = json.Unmarshal(src, t)
   721  	if err != nil {
   722  		return nil, fmt.Errorf("unmarshal transitional payload: %w", err)
   723  	}
   724  
   725  	return t, err
   726  }
   727  
   728  func (s *Service) deleteTransitionalPayload(id string) error {
   729  	return s.store.Delete(fmt.Sprintf(transitionalPayloadKey, id))
   730  }
   731  
   732  // Actions returns actions for the async usage.
   733  func (s *Service) Actions() ([]Action, error) {
   734  	records, err := s.store.Query(transitionalPayloadKey)
   735  	if err != nil {
   736  		return nil, fmt.Errorf("failed to query store: %w", err)
   737  	}
   738  
   739  	defer storage.Close(records, logger)
   740  
   741  	var actions []Action
   742  
   743  	more, err := records.Next()
   744  	if err != nil {
   745  		return nil, fmt.Errorf("failed to get next set of data from records: %w", err)
   746  	}
   747  
   748  	for more {
   749  		value, err := records.Value()
   750  		if err != nil {
   751  			return nil, fmt.Errorf("failed to get value from records: %w", err)
   752  		}
   753  
   754  		var action Action
   755  		if errUnmarshal := json.Unmarshal(value, &action); errUnmarshal != nil {
   756  			return nil, fmt.Errorf("unmarshal: %w", errUnmarshal)
   757  		}
   758  
   759  		actions = append(actions, action)
   760  
   761  		more, err = records.Next()
   762  		if err != nil {
   763  			return nil, fmt.Errorf("failed to get next set of data from records: %w", err)
   764  		}
   765  	}
   766  
   767  	return actions, nil
   768  }
   769  
   770  // ActionContinue allows proceeding with the action by the piID.
   771  func (s *Service) ActionContinue(piID string, opts ...Opt) error {
   772  	tPayload, err := s.getTransitionalPayload(piID)
   773  	if err != nil {
   774  		return fmt.Errorf("get transitional payload: %w", err)
   775  	}
   776  
   777  	md := &metaData{
   778  		transitionalPayload: *tPayload,
   779  		state:               stateFromName(tPayload.StateName, getVersion(tPayload.Msg.Type())),
   780  		msgClone:            tPayload.Msg.Clone(),
   781  		properties:          tPayload.Properties,
   782  	}
   783  
   784  	for _, opt := range opts {
   785  		opt(md)
   786  	}
   787  
   788  	if err := s.deleteTransitionalPayload(md.PIID); err != nil {
   789  		return fmt.Errorf("delete transitional payload: %w", err)
   790  	}
   791  
   792  	s.processCallback(md)
   793  
   794  	return nil
   795  }
   796  
   797  // ActionStop allows stopping the action by the piID.
   798  func (s *Service) ActionStop(piID string, cErr error, opts ...Opt) error {
   799  	tPayload, err := s.getTransitionalPayload(piID)
   800  	if err != nil {
   801  		return fmt.Errorf("get transitional payload: %w", err)
   802  	}
   803  
   804  	md := &metaData{
   805  		transitionalPayload: *tPayload,
   806  		state:               stateFromName(tPayload.StateName, tPayload.Msg.Type()),
   807  		msgClone:            tPayload.Msg.Clone(),
   808  		properties:          tPayload.Properties,
   809  	}
   810  
   811  	for _, opt := range opts {
   812  		opt(md)
   813  	}
   814  
   815  	if err := s.deleteTransitionalPayload(md.PIID); err != nil {
   816  		return fmt.Errorf("delete transitional payload: %w", err)
   817  	}
   818  
   819  	if cErr == nil {
   820  		cErr = errProtocolStopped
   821  	}
   822  
   823  	md.err = customError{error: cErr}
   824  	s.processCallback(md)
   825  
   826  	return nil
   827  }
   828  
// processCallback queues msg for startInternalListener.
// s.callbacks is created unbuffered in Initialize, so this blocks until the listener
// receives msg.
func (s *Service) processCallback(msg *metaData) {
	// pass the callback data to internal channel. This is created to unblock consumer go routine and wrap the callback
	// channel internally.
	s.callbacks <- msg
}
   834  
// newDIDCommActionMsg creates new DIDCommAction message.
//
// The action carries md.msgClone and event properties built from md. Both callbacks
// delete the stored transitional payload, logging rather than returning any error,
// and then queue md for the callback listener.
//   - Continue first applies its argument to md when it is an Opt; other values are ignored.
//   - Stop records cErr (errProtocolStopped when nil) on md as a customError.
func (s *Service) newDIDCommActionMsg(md *metaData) service.DIDCommAction {
	// create the message for the channel
	// trigger the registered action event
	return service.DIDCommAction{
		ProtocolName: Name,
		Message:      md.msgClone,
		Continue: func(opt interface{}) {
			if fn, ok := opt.(Opt); ok {
				fn(md)
			}

			if err := s.deleteTransitionalPayload(md.PIID); err != nil {
				logger.Errorf("continue: delete transitional payload: %v", err)
			}

			s.processCallback(md)
		},
		Stop: func(cErr error) {
			if err := s.deleteTransitionalPayload(md.PIID); err != nil {
				logger.Errorf("stop: delete transitional payload: %v", err)
			}

			if cErr == nil {
				cErr = errProtocolStopped
			}

			md.err = customError{error: cErr}
			s.processCallback(md)
		},
		Properties: newEventProps(md),
	}
}
   868  
// execute runs a single state.
//
// It assigns next to md.state, emits PreState events, refreshes md.properties from
// the event properties, runs the middleware chain, and finally executes the state.
// PostState events are emitted on return via defer, including when the middleware
// or the state fails.
func (s *Service) execute(next state, md *metaData) (state, stateAction, error) {
	md.state = next
	s.sendMsgEvents(md, next.Name(), service.PreState)

	defer s.sendMsgEvents(md, next.Name(), service.PostState)

	md.properties = newEventProps(md).All()

	if err := s.middleware.Handle(md); err != nil {
		return nil, nil, fmt.Errorf("middleware: %w", err)
	}

	return next.Execute(md)
}
   883  
   884  // sendMsgEvents triggers the message events.
   885  func (s *Service) sendMsgEvents(md *metaData, stateID string, stateType service.StateMsgType) {
   886  	// trigger the message events
   887  	for _, handler := range s.MsgEvents() {
   888  		handler <- service.StateMsg{
   889  			ProtocolName: Name,
   890  			Type:         stateType,
   891  			Msg:          md.msgClone,
   892  			StateID:      stateID,
   893  			Properties:   newEventProps(md),
   894  		}
   895  	}
   896  }
   897  
// Name returns service name.
// It is always the protocol Name constant ("present-proof").
func (s *Service) Name() string {
	return Name
}
   902  
   903  // Accept msg checks the msg type.
   904  func (s *Service) Accept(msgType string) bool {
   905  	switch msgType {
   906  	case ProposePresentationMsgTypeV2, RequestPresentationMsgTypeV2,
   907  		PresentationMsgTypeV2, AckMsgTypeV2, ProblemReportMsgTypeV2,
   908  		ProposePresentationMsgTypeV3, RequestPresentationMsgTypeV3,
   909  		PresentationMsgTypeV3, AckMsgTypeV3, ProblemReportMsgTypeV3:
   910  		return true
   911  	}
   912  
   913  	return false
   914  }