github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/dispatcher/inbound/inbound_message_handler.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  Copyright Avast Software. All Rights Reserved.
     4  
     5  SPDX-License-Identifier: Apache-2.0
     6  */
     7  
     8  package inbound
     9  
    10  import (
    11  	"encoding/json"
    12  	"errors"
    13  	"fmt"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/btcsuite/btcutil/base58"
    18  	"github.com/cenkalti/backoff/v4"
    19  
    20  	"github.com/hyperledger/aries-framework-go/pkg/common/log"
    21  	"github.com/hyperledger/aries-framework-go/pkg/crypto"
    22  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/middleware"
    23  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    24  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher"
    25  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/didexchange"
    26  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/legacyconnection"
    27  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
    28  	"github.com/hyperledger/aries-framework-go/pkg/doc/did"
    29  	"github.com/hyperledger/aries-framework-go/pkg/framework/aries/api"
    30  	vdrapi "github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr"
    31  	didstore "github.com/hyperledger/aries-framework-go/pkg/store/did"
    32  	"github.com/hyperledger/aries-framework-go/pkg/vdr/fingerprint"
    33  )
    34  
    35  var logger = log.New("dispatcher/inbound")
    36  
    37  const (
    38  	kaIdentifier = "#"
    39  )
    40  
    41  // MessageHandler handles inbound envelopes, processing then dispatching to a protocol service based on the
    42  // message type.
    43  type MessageHandler struct {
    44  	didConnectionStore     didstore.ConnectionStore
    45  	didcommV2Handler       *middleware.DIDCommMessageMiddleware
    46  	msgSvcProvider         api.MessageServiceProvider
    47  	services               []dispatcher.ProtocolService
    48  	getDIDsBackOffDuration time.Duration
    49  	getDIDsMaxRetries      uint64
    50  	messenger              service.InboundMessenger
    51  	vdr                    vdrapi.Registry
    52  	initialized            bool
    53  }
    54  
    55  type provider interface {
    56  	DIDConnectionStore() didstore.ConnectionStore
    57  	MessageServiceProvider() api.MessageServiceProvider
    58  	AllServices() []dispatcher.ProtocolService
    59  	GetDIDsBackOffDuration() time.Duration
    60  	GetDIDsMaxRetries() uint64
    61  	InboundMessenger() service.InboundMessenger
    62  	DIDRotator() *middleware.DIDCommMessageMiddleware
    63  	VDRegistry() vdrapi.Registry
    64  }
    65  
    66  // NewInboundMessageHandler creates an inbound message handler, that processes inbound message Envelopes,
    67  // and dispatches them to the appropriate ProtocolService.
    68  func NewInboundMessageHandler(p provider) *MessageHandler {
    69  	h := MessageHandler{}
    70  	h.Initialize(p)
    71  
    72  	return &h
    73  }
    74  
    75  // Initialize initializes the MessageHandler. Any call beyond the first is a no-op.
    76  func (handler *MessageHandler) Initialize(p provider) {
    77  	if handler.initialized {
    78  		return
    79  	}
    80  
    81  	handler.didConnectionStore = p.DIDConnectionStore()
    82  	handler.msgSvcProvider = p.MessageServiceProvider()
    83  	handler.services = p.AllServices()
    84  	handler.getDIDsBackOffDuration = p.GetDIDsBackOffDuration()
    85  	handler.getDIDsMaxRetries = p.GetDIDsMaxRetries()
    86  	handler.messenger = p.InboundMessenger()
    87  	handler.didcommV2Handler = p.DIDRotator()
    88  	handler.vdr = p.VDRegistry()
    89  
    90  	handler.initialized = true
    91  }
    92  
    93  // HandlerFunc returns the MessageHandler's transport.InboundMessageHandler function.
    94  func (handler *MessageHandler) HandlerFunc() transport.InboundMessageHandler {
    95  	return func(envelope *transport.Envelope) error {
    96  		return handler.HandleInboundEnvelope(envelope)
    97  	}
    98  }
    99  
   100  // HandleInboundEnvelope handles an inbound envelope, dispatching it to the appropriate ProtocolService.
   101  func (handler *MessageHandler) HandleInboundEnvelope(envelope *transport.Envelope, // nolint:funlen,gocognit,gocyclo
   102  ) error {
   103  	var (
   104  		msg service.DIDCommMsgMap
   105  		err error
   106  	)
   107  
   108  	msg, err = service.ParseDIDCommMsgMap(envelope.Message)
   109  	if err != nil {
   110  		return err
   111  	}
   112  
   113  	isDIDEx := (&didexchange.Service{}).Accept(msg.Type())
   114  	isLegacyConn := (&legacyconnection.Service{}).Accept(msg.Type())
   115  
   116  	isV2, err := service.IsDIDCommV2(&msg)
   117  	if err != nil {
   118  		return err
   119  	}
   120  
   121  	var (
   122  		myDID, theirDID string
   123  		gotDIDs         bool
   124  	)
   125  
   126  	// handle inbound peer DID initial state
   127  	err = handler.didcommV2Handler.HandleInboundPeerDID(msg)
   128  	if err != nil {
   129  		return fmt.Errorf("handling inbound peer DID: %w", err)
   130  	}
   131  
   132  	// if msg is not a didexchange and legacy-connection message, do additional handling
   133  	if !isDIDEx && !isLegacyConn {
   134  		myDID, theirDID, err = handler.getDIDs(envelope, msg)
   135  		if err != nil {
   136  			return fmt.Errorf("get DIDs for message: %w", err)
   137  		}
   138  
   139  		gotDIDs = true
   140  
   141  		err = handler.didcommV2Handler.HandleInboundMessage(msg, theirDID, myDID)
   142  		if err != nil {
   143  			return fmt.Errorf("handle rotation: %w", err)
   144  		}
   145  	}
   146  
   147  	var foundService dispatcher.ProtocolService
   148  
   149  	// find the service which accepts the message type
   150  	for _, svc := range handler.services {
   151  		if svc.Accept(msg.Type()) {
   152  			foundService = svc
   153  			break
   154  		}
   155  	}
   156  
   157  	if foundService != nil {
   158  		props := make(map[string]interface{})
   159  
   160  		switch foundService.Name() {
   161  		// perf: DID exchange doesn't require myDID and theirDID
   162  		case didexchange.DIDExchange:
   163  		// perf: legacy-connection requires envelope.ToKey when sending Connection Response (it will sign connection
   164  		// data with this key)
   165  		case legacyconnection.LegacyConnection:
   166  			// When type of envelope.Message is connections/request, the key which was used to decrypt message is the
   167  			// same key which was sent during invitation. If ParentThreadID is missed (Interop issues), that key will be
   168  			// used to sign connection-data while sending connection response
   169  			if msg.Type() == legacyconnection.RequestMsgType && msg.ParentThreadID() == "" {
   170  				props[legacyconnection.InvitationRecipientKey] = base58.Encode(envelope.ToKey)
   171  			}
   172  		default:
   173  			if !gotDIDs {
   174  				myDID, theirDID, err = handler.getDIDs(envelope, msg)
   175  				if err != nil {
   176  					return fmt.Errorf("inbound message handler: %w", err)
   177  				}
   178  			}
   179  		}
   180  
   181  		_, err = foundService.HandleInbound(msg, service.NewDIDCommContext(myDID, theirDID, props))
   182  
   183  		return err
   184  	}
   185  
   186  	if !isV2 { // nolint:nestif
   187  		h := struct {
   188  			Purpose []string `json:"~purpose"`
   189  		}{}
   190  		err = msg.Decode(&h)
   191  
   192  		if err != nil {
   193  			return err
   194  		}
   195  
   196  		// in case of no services are registered for given message type, and message is didcomm v1,
   197  		// find generic inbound services registered for given message header
   198  		var foundMessageService dispatcher.MessageService
   199  
   200  		for _, svc := range handler.msgSvcProvider.Services() {
   201  			if svc.Accept(msg.Type(), h.Purpose) {
   202  				foundMessageService = svc
   203  			}
   204  		}
   205  
   206  		if foundMessageService != nil {
   207  			if !gotDIDs {
   208  				myDID, theirDID, err = handler.getDIDs(envelope, msg)
   209  				if err != nil {
   210  					return fmt.Errorf("inbound message handler: %w", err)
   211  				}
   212  			}
   213  
   214  			return handler.tryToHandle(foundMessageService, msg, service.NewDIDCommContext(myDID, theirDID, nil))
   215  		}
   216  	}
   217  
   218  	return fmt.Errorf("no message handlers found for the message type: %s", msg.Type())
   219  }
   220  
   221  func (handler *MessageHandler) getDIDs( // nolint:funlen,gocyclo,gocognit
   222  	envelope *transport.Envelope, message service.DIDCommMsgMap,
   223  ) (string, string, error) {
   224  	var (
   225  		myDID    string
   226  		theirDID string
   227  		err      error
   228  	)
   229  
   230  	myDID, err = handler.getDIDGivenKey(envelope.ToKey)
   231  	if err != nil {
   232  		return myDID, theirDID, err
   233  	}
   234  
   235  	theirDID, err = handler.getDIDGivenKey(envelope.FromKey)
   236  	if err != nil {
   237  		return myDID, theirDID, err
   238  	}
   239  
   240  	if len(envelope.FromKey) == 0 && message != nil && theirDID == "" {
   241  		if from, ok := message["from"].(string); ok {
   242  			didURL, e := did.ParseDIDURL(from)
   243  			if e == nil {
   244  				theirDID = didURL.DID.String()
   245  			}
   246  		}
   247  	}
   248  
   249  	return myDID, theirDID, backoff.Retry(func() error {
   250  		var notFound bool
   251  
   252  		if myDID == "" {
   253  			myDID, err = handler.didConnectionStore.GetDID(base58.Encode(envelope.ToKey))
   254  			if errors.Is(err, didstore.ErrNotFound) {
   255  				// try did:key
   256  				// CreateDIDKey below is for Ed25519 keys only, use the more general CreateDIDKeyByCode if other key
   257  				// types will be used. Currently, did:key is for legacy packers only, so only support Ed25519 keys.
   258  				didKey, _ := fingerprint.CreateDIDKey(envelope.ToKey)
   259  				myDID, err = handler.didConnectionStore.GetDID(didKey)
   260  			}
   261  
   262  			if errors.Is(err, didstore.ErrNotFound) {
   263  				notFound = true
   264  			} else if err != nil {
   265  				myDID = ""
   266  				return fmt.Errorf("failed to get my did: %w", err)
   267  			}
   268  		}
   269  
   270  		if envelope.FromKey == nil {
   271  			return nil
   272  		}
   273  
   274  		if theirDID == "" {
   275  			theirDID, err = handler.didConnectionStore.GetDID(base58.Encode(envelope.FromKey))
   276  			if errors.Is(err, didstore.ErrNotFound) {
   277  				// try did:key
   278  				// CreateDIDKey below is for Ed25519 keys only, use the more general CreateDIDKeyByCode if other key
   279  				// types will be used. Currently, did:key is for legacy packers, so only support Ed25519 keys.
   280  				didKey, _ := fingerprint.CreateDIDKey(envelope.FromKey)
   281  				theirDID, err = handler.didConnectionStore.GetDID(didKey)
   282  			}
   283  
   284  			if err == nil {
   285  				return nil
   286  			}
   287  
   288  			if notFound && errors.Is(err, didstore.ErrNotFound) {
   289  				// if neither DID is found, using either base58 key or did:key as lookup key
   290  				return nil
   291  			}
   292  
   293  			theirDID = ""
   294  			return fmt.Errorf("failed to get their did: %w", err)
   295  		}
   296  
   297  		return nil
   298  	}, backoff.WithMaxRetries(backoff.NewConstantBackOff(handler.getDIDsBackOffDuration), handler.getDIDsMaxRetries))
   299  }
   300  
   301  // getDIDGivenKey returns a did:key if the input key is a JWK. If the input key is not a JWK, returns the empty string.
   302  // An error is returned if the key is a JWK but fails to be converted to a did:key.
   303  func (handler *MessageHandler) getDIDGivenKey(key []byte) (string, error) {
   304  	var (
   305  		err    error
   306  		retDID string
   307  	)
   308  
   309  	// nolint: gocritic
   310  	if strings.Index(string(key), kaIdentifier) > 0 &&
   311  		strings.Index(string(key), "\"kid\":\"did:") > 0 {
   312  		retDID, err = pubKeyToDID(key)
   313  		if err != nil {
   314  			return "", fmt.Errorf("getDID: %w", err)
   315  		}
   316  
   317  		logger.Debugf("envelope Key as DID: %v", retDID)
   318  
   319  		return retDID, nil
   320  	}
   321  
   322  	return "", nil
   323  }
   324  
   325  func pubKeyToDID(key []byte) (string, error) {
   326  	toKey := &crypto.PublicKey{}
   327  
   328  	err := json.Unmarshal(key, toKey)
   329  	if err != nil {
   330  		return "", fmt.Errorf("pubKeyToDID: unmarshal key: %w", err)
   331  	}
   332  
   333  	return toKey.KID[:strings.Index(toKey.KID, kaIdentifier)], nil
   334  }
   335  
   336  func (handler *MessageHandler) tryToHandle(
   337  	svc service.InboundHandler, msg service.DIDCommMsgMap, ctx service.DIDCommContext) error {
   338  	if err := handler.messenger.HandleInbound(msg, ctx); err != nil {
   339  		return fmt.Errorf("messenger HandleInbound: %w", err)
   340  	}
   341  
   342  	_, err := svc.HandleInbound(msg, ctx)
   343  
   344  	return err
   345  }