github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/dispatcher/outbound/outbound.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package outbound
     8  
     9  import (
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  	"strings"
    14  
    15  	"github.com/google/uuid"
    16  
    17  	"github.com/hyperledger/aries-framework-go/pkg/common/log"
    18  	commonmodel "github.com/hyperledger/aries-framework-go/pkg/common/model"
    19  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/middleware"
    20  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/model"
    21  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    22  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator"
    23  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
    24  	"github.com/hyperledger/aries-framework-go/pkg/doc/util/kmsdidkey"
    25  	"github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr"
    26  	"github.com/hyperledger/aries-framework-go/pkg/kms"
    27  	"github.com/hyperledger/aries-framework-go/pkg/store/connection"
    28  	"github.com/hyperledger/aries-framework-go/spi/storage"
    29  )
    30  
    31  // provider interface for outbound ctx.
    32  type provider interface {
    33  	Packager() transport.Packager
    34  	OutboundTransports() []transport.OutboundTransport
    35  	TransportReturnRoute() string
    36  	VDRegistry() vdr.Registry
    37  	KMS() kms.KeyManager
    38  	KeyAgreementType() kms.KeyType
    39  	ProtocolStateStorageProvider() storage.Provider
    40  	StorageProvider() storage.Provider
    41  	MediaTypeProfiles() []string
    42  	DIDRotator() *middleware.DIDCommMessageMiddleware
    43  }
    44  
    45  type connectionLookup interface {
    46  	GetConnectionIDByDIDs(myDID, theirDID string) (string, error)
    47  	GetConnectionRecord(string) (*connection.Record, error)
    48  	GetConnectionRecordByDIDs(myDID, theirDID string) (*connection.Record, error)
    49  }
    50  
    51  type connectionRecorder interface {
    52  	connectionLookup
    53  	SaveConnectionRecord(record *connection.Record) error
    54  }
    55  
// Dispatcher dispatches outbound DIDComm messages to their destination.
//
// It resolves DIDs through vdRegistry and packs messages with packager. When
// the destination has routing keys, it wraps them in nested forward messages.
// The packed bytes go to the first entry of outboundTransports that accepts
// the destination.
//
// Other fields:
//   - connections: stores the connection records that SendToDID looks up or creates.
//   - mediaTypeProfiles: default profiles, used for new connection records and
//     as the fallback packing profile.
//   - transportReturnRoute: the framework's return-route option.
//   - didcommV2Handler: applied to DIDCommMsgMap messages in SendToDID.
type Dispatcher struct {
	outboundTransports   []transport.OutboundTransport
	packager             transport.Packager
	transportReturnRoute string
	vdRegistry           vdr.Registry
	kms                  kms.KeyManager
	keyAgreementType     kms.KeyType
	connections          connectionRecorder
	mediaTypeProfiles    []string
	didcommV2Handler     *middleware.DIDCommMessageMiddleware
}
    68  
// legacyForward is the DIDComm V1 routing Forward message as declared in
// https://github.com/hyperledger/aries-rfcs/blob/main/concepts/0094-cross-domain-messaging/README.md
//
// Unlike model.Forward, Msg holds a parsed *model.Envelope rather than raw
// bytes. packForward builds this shape when the wrapped message parses as an
// envelope. Field order determines the JSON key order of the marshaled
// message, so do not reorder the fields.
type legacyForward struct {
	Type string          `json:"@type,omitempty"`
	ID   string          `json:"@id,omitempty"`
	To   string          `json:"to,omitempty"`
	Msg  *model.Envelope `json:"msg,omitempty"`
}

// logger is the package-level logger of the outbound dispatcher.
var logger = log.New("aries-framework/didcomm/dispatcher")
    79  
    80  // NewOutbound return new dispatcher outbound instance.
    81  func NewOutbound(prov provider) (*Dispatcher, error) {
    82  	o := &Dispatcher{
    83  		outboundTransports:   prov.OutboundTransports(),
    84  		packager:             prov.Packager(),
    85  		transportReturnRoute: prov.TransportReturnRoute(),
    86  		vdRegistry:           prov.VDRegistry(),
    87  		kms:                  prov.KMS(),
    88  		keyAgreementType:     prov.KeyAgreementType(),
    89  		mediaTypeProfiles:    prov.MediaTypeProfiles(),
    90  		didcommV2Handler:     prov.DIDRotator(),
    91  	}
    92  
    93  	var err error
    94  
    95  	o.connections, err = connection.NewRecorder(prov)
    96  	if err != nil {
    97  		return nil, fmt.Errorf("failed to init connection recorder: %w", err)
    98  	}
    99  
   100  	return o, nil
   101  }
   102  
   103  // SendToDID sends a message from myDID to the agent who owns theirDID.
   104  func (o *Dispatcher) SendToDID(msg interface{}, myDID, theirDID string) error { // nolint:funlen,gocyclo,gocognit
   105  	myDocResolution, err := o.vdRegistry.Resolve(myDID)
   106  	if err != nil {
   107  		return fmt.Errorf("failed to resolve my DID: %w", err)
   108  	}
   109  
   110  	theirDocResolution, err := o.vdRegistry.Resolve(theirDID)
   111  	if err != nil {
   112  		return fmt.Errorf("failed to resolve their DID: %w", err)
   113  	}
   114  
   115  	var connectionVersion service.Version
   116  
   117  	didcommMsg, isMsgMap := msg.(service.DIDCommMsgMap)
   118  
   119  	var isV2 bool
   120  
   121  	if isMsgMap {
   122  		isV2, err = service.IsDIDCommV2(&didcommMsg)
   123  		if err == nil && isV2 {
   124  			connectionVersion = service.V2
   125  		} else {
   126  			connectionVersion = service.V1
   127  		}
   128  	}
   129  
   130  	connRec, err := o.getOrCreateConnection(myDID, theirDID, connectionVersion)
   131  	if err != nil {
   132  		return fmt.Errorf("failed to fetch connection record: %w", err)
   133  	}
   134  
   135  	var sendWithAnoncrypt bool
   136  
   137  	if isMsgMap { // nolint:nestif
   138  		didcommMsg = o.didcommV2Handler.HandleOutboundMessage(didcommMsg, connRec)
   139  
   140  		if connRec.PeerDIDInitialState != "" {
   141  			// we need to use anoncrypt if myDID is a peer DID being shared with the recipient through this message.
   142  			sendWithAnoncrypt = true
   143  		}
   144  
   145  		// the first message sent using didcomm v2 should contain the invitation ID as pthid
   146  		if connRec.DIDCommVersion == service.V2 && connRec.ParentThreadID != "" && connectionVersion == service.V2 {
   147  			pthid, hasPthid := didcommMsg["pthid"].(string)
   148  
   149  			thid, e := didcommMsg.ThreadID()
   150  			if e == nil && didcommMsg.ID() == thid && (!hasPthid || pthid == "") {
   151  				didcommMsg["pthid"] = connRec.ParentThreadID
   152  			}
   153  		}
   154  
   155  		msg = &didcommMsg
   156  	} else {
   157  		didcommMsgPtr, ok := msg.(*service.DIDCommMsgMap)
   158  		if ok {
   159  			didcommMsg = *didcommMsgPtr
   160  		} else {
   161  			didcommMsg = service.NewDIDCommMsgMap(msg)
   162  			msg = &didcommMsg
   163  		}
   164  	}
   165  
   166  	dest, err := service.CreateDestination(theirDocResolution.DIDDocument)
   167  	if err != nil {
   168  		return fmt.Errorf(
   169  			"outboundDispatcher.SendToDID failed to get didcomm destination for theirDID [%s]: %w", theirDID, err)
   170  	}
   171  
   172  	if len(connRec.MediaTypeProfiles) > 0 {
   173  		dest.MediaTypeProfiles = make([]string, len(connRec.MediaTypeProfiles))
   174  		copy(dest.MediaTypeProfiles, connRec.MediaTypeProfiles)
   175  	}
   176  
   177  	mtp := o.mediaTypeProfile(dest)
   178  	switch mtp {
   179  	case transport.MediaTypeV1PlaintextPayload, transport.MediaTypeV1EncryptedEnvelope,
   180  		transport.MediaTypeRFC0019EncryptedEnvelope, transport.MediaTypeAIP2RFC0019Profile:
   181  		sendWithAnoncrypt = false
   182  	}
   183  
   184  	if sendWithAnoncrypt {
   185  		return o.Send(msg, "", dest)
   186  	}
   187  
   188  	src, err := service.CreateDestination(myDocResolution.DIDDocument)
   189  	if err != nil {
   190  		return fmt.Errorf("outboundDispatcher.SendToDID failed to get didcomm destination for myDID [%s]: %w", myDID, err)
   191  	}
   192  
   193  	// We get at least one recipient key, so we can use the first one
   194  	//  (right now, with only one key type used for sending)
   195  	key := src.RecipientKeys[0]
   196  
   197  	return o.Send(msg, key, dest)
   198  }
   199  
   200  func (o *Dispatcher) defaultMediaTypeProfiles() []string {
   201  	mediaTypes := make([]string, len(o.mediaTypeProfiles))
   202  	copy(mediaTypes, o.mediaTypeProfiles)
   203  
   204  	return mediaTypes
   205  }
   206  
   207  // getOrCreateConnection returns true iff it created a new connection rather than fetching one.
   208  func (o *Dispatcher) getOrCreateConnection(myDID, theirDID string, connectionVersion service.Version,
   209  ) (*connection.Record, error) {
   210  	record, err := o.connections.GetConnectionRecordByDIDs(myDID, theirDID)
   211  	if err == nil {
   212  		return record, nil
   213  	} else if !errors.Is(err, storage.ErrDataNotFound) {
   214  		return nil, fmt.Errorf("failed to check if connection exists: %w", err)
   215  	}
   216  
   217  	// myDID and theirDID never had a connection, create a default connection for OOBless communication.
   218  	logger.Debugf("no connection record found for myDID=%s theirDID=%s, will create", myDID, theirDID)
   219  
   220  	newRecord := connection.Record{
   221  		ConnectionID:   uuid.New().String(),
   222  		MyDID:          myDID,
   223  		TheirDID:       theirDID,
   224  		State:          connection.StateNameCompleted,
   225  		Namespace:      connection.MyNSPrefix,
   226  		DIDCommVersion: connectionVersion,
   227  	}
   228  
   229  	if connectionVersion == service.V2 {
   230  		newRecord.ServiceEndPoint = commonmodel.NewDIDCommV2Endpoint(
   231  			[]commonmodel.DIDCommV2Endpoint{{Accept: o.defaultMediaTypeProfiles()}})
   232  	} else {
   233  		newRecord.MediaTypeProfiles = o.defaultMediaTypeProfiles()
   234  	}
   235  
   236  	err = o.connections.SaveConnectionRecord(&newRecord)
   237  	if err != nil {
   238  		return nil, fmt.Errorf("failed to save new connection: %w", err)
   239  	}
   240  
   241  	return &newRecord, nil
   242  }
   243  
   244  // Send sends the message after packing with the sender key and recipient keys.
   245  func (o *Dispatcher) Send(msg interface{}, senderKey string, des *service.Destination) error { // nolint:funlen,gocyclo
   246  	// check if outbound accepts routing keys, else use recipient keys
   247  	keys := des.RecipientKeys
   248  	if routingKeys, err := des.ServiceEndpoint.RoutingKeys(); err == nil && len(routingKeys) > 0 { // DIDComm V2
   249  		keys = routingKeys
   250  	} else if len(des.RoutingKeys) > 0 { // DIDComm V1
   251  		keys = routingKeys
   252  	}
   253  
   254  	var outboundTransport transport.OutboundTransport
   255  
   256  	for _, v := range o.outboundTransports {
   257  		uri, err := des.ServiceEndpoint.URI()
   258  		if err != nil {
   259  			logger.Debugf("destination ServiceEndpoint empty: %w, it will not be checked", err)
   260  		}
   261  
   262  		if v.AcceptRecipient(keys) || v.Accept(uri) {
   263  			outboundTransport = v
   264  			break
   265  		}
   266  	}
   267  
   268  	if outboundTransport == nil {
   269  		return fmt.Errorf("outboundDispatcher.Send: no transport found for destination: %+v", des)
   270  	}
   271  
   272  	req, err := json.Marshal(msg)
   273  	if err != nil {
   274  		return fmt.Errorf("outboundDispatcher.Send: failed marshal to bytes: %w", err)
   275  	}
   276  
   277  	// update the outbound message with transport return route option [all or thread]
   278  	req, err = o.addTransportRouteOptions(req, des)
   279  	if err != nil {
   280  		return fmt.Errorf("outboundDispatcher.Send: failed to add transport route options: %w", err)
   281  	}
   282  
   283  	mtp := o.mediaTypeProfile(des)
   284  
   285  	var fromKey []byte
   286  
   287  	if len(senderKey) > 0 {
   288  		fromKey = []byte(senderKey)
   289  	}
   290  
   291  	packedMsg, err := o.packager.PackMessage(&transport.Envelope{
   292  		MediaTypeProfile: mtp,
   293  		Message:          req,
   294  		FromKey:          fromKey,
   295  		ToKeys:           des.RecipientKeys,
   296  	})
   297  	if err != nil {
   298  		return fmt.Errorf("outboundDispatcher.Send: failed to pack msg: %w", err)
   299  	}
   300  
   301  	// set the return route option
   302  	des.TransportReturnRoute = o.transportReturnRoute
   303  
   304  	packedMsg, err = o.createForwardMessage(packedMsg, des)
   305  	if err != nil {
   306  		return fmt.Errorf("outboundDispatcher.Send: failed to create forward msg: %w", err)
   307  	}
   308  
   309  	_, err = outboundTransport.Send(packedMsg, des)
   310  	if err != nil {
   311  		return fmt.Errorf("outboundDispatcher.Send: failed to send msg using outbound transport: %w", err)
   312  	}
   313  
   314  	return nil
   315  }
   316  
   317  // Forward forwards the message without packing to the destination.
   318  func (o *Dispatcher) Forward(msg interface{}, des *service.Destination) error {
   319  	var (
   320  		uri string
   321  		err error
   322  	)
   323  
   324  	uri, err = des.ServiceEndpoint.URI()
   325  	if err != nil {
   326  		logger.Debugf("destination serviceEndpoint forward URI is not set: %w, will skip value", err)
   327  	}
   328  
   329  	for _, v := range o.outboundTransports {
   330  		if !v.AcceptRecipient(des.RecipientKeys) {
   331  			if !v.Accept(uri) {
   332  				continue
   333  			}
   334  		}
   335  
   336  		req, err := json.Marshal(msg)
   337  		if err != nil {
   338  			return fmt.Errorf("outboundDispatcher.Forward: failed marshal to bytes: %w", err)
   339  		}
   340  
   341  		_, err = v.Send(req, des)
   342  		if err != nil {
   343  			return fmt.Errorf("outboundDispatcher.Forward: failed to send msg using outbound transport: %w", err)
   344  		}
   345  
   346  		return nil
   347  	}
   348  
   349  	return fmt.Errorf("outboundDispatcher.Forward: no transport found for serviceEndpoint: %s", uri)
   350  }
   351  
   352  func (o *Dispatcher) createForwardMessage(msg []byte, des *service.Destination) ([]byte, error) {
   353  	mtProfile := o.mediaTypeProfile(des)
   354  
   355  	var (
   356  		forwardMsgType string
   357  		err            error
   358  	)
   359  
   360  	switch mtProfile {
   361  	case transport.MediaTypeV2EncryptedEnvelopeV1PlaintextPayload, transport.MediaTypeV2EncryptedEnvelope,
   362  		transport.MediaTypeAIP2RFC0587Profile, transport.MediaTypeV2PlaintextPayload, transport.MediaTypeDIDCommV2Profile:
   363  		// for DIDComm V2, do not set senderKey to force Anoncrypt packing. Only set the V2 forwardMsgType.
   364  		forwardMsgType = service.ForwardMsgTypeV2
   365  	default: // default is DIDComm V1
   366  		forwardMsgType = service.ForwardMsgType
   367  	}
   368  
   369  	routingKeys, err := des.ServiceEndpoint.RoutingKeys()
   370  	if err != nil {
   371  		logger.Debugf("des.ServiceEndpoint.RoutingKeys() (didcomm v2) returned an error %w, "+
   372  			"will check routinKeys (didcomm v1) array", err)
   373  	}
   374  
   375  	if len(routingKeys) == 0 {
   376  		if len(des.RoutingKeys) == 0 {
   377  			return msg, nil
   378  		}
   379  
   380  		routingKeys = des.RoutingKeys
   381  	}
   382  
   383  	fwdKeys := append([]string{des.RecipientKeys[0]}, routingKeys...)
   384  
   385  	packedMsg, err := o.createPackedNestedForwards(msg, fwdKeys, forwardMsgType, mtProfile)
   386  	if err != nil {
   387  		return nil, fmt.Errorf("failed to create packed nested forwards: %w", err)
   388  	}
   389  
   390  	return packedMsg, nil
   391  }
   392  
   393  func (o *Dispatcher) createPackedNestedForwards(msg []byte, routingKeys []string, fwdMsgType, mtProfile string) ([]byte, error) { //nolint: lll
   394  	for i, key := range routingKeys {
   395  		if i+1 >= len(routingKeys) {
   396  			break
   397  		}
   398  		// create forward message
   399  		forward := model.Forward{
   400  			Type: fwdMsgType,
   401  			ID:   uuid.New().String(),
   402  			To:   key,
   403  			Msg:  msg,
   404  		}
   405  
   406  		var err error
   407  
   408  		msg, err = o.packForward(forward, []string{routingKeys[i+1]}, mtProfile)
   409  		if err != nil {
   410  			return nil, fmt.Errorf("failed to pack forward msg: %w", err)
   411  		}
   412  	}
   413  
   414  	return msg, nil
   415  }
   416  
   417  func (o *Dispatcher) packForward(fwd model.Forward, toKeys []string, mtProfile string) ([]byte, error) {
   418  	env := &model.Envelope{}
   419  
   420  	var (
   421  		forward interface{}
   422  		err     error
   423  		req     []byte
   424  	)
   425  	// try to convert msg to Envelope
   426  	err = json.Unmarshal(fwd.Msg, env)
   427  	if err == nil {
   428  		// Convert did:key to base58 to support legacy profile type
   429  		if strings.HasPrefix(fwd.To, "did:key") && mtProfile == transport.LegacyDIDCommV1Profile {
   430  			fwd.To, err = kmsdidkey.GetBase58PubKeyFromDIDKey(fwd.To)
   431  			if err != nil {
   432  				return nil, err
   433  			}
   434  		}
   435  		// create legacy forward
   436  		forward = legacyForward{
   437  			Type: fwd.Type,
   438  			ID:   fwd.ID,
   439  			To:   fwd.To,
   440  			Msg:  env,
   441  		}
   442  	} else {
   443  		forward = fwd
   444  	}
   445  	// convert forward message to bytes
   446  	req, err = json.Marshal(forward)
   447  	if err != nil {
   448  		return nil, fmt.Errorf("failed marshal to bytes: %w", err)
   449  	}
   450  
   451  	var packedMsg []byte
   452  	packedMsg, err = o.packager.PackMessage(&transport.Envelope{
   453  		MediaTypeProfile: mtProfile,
   454  		Message:          req,
   455  		FromKey:          []byte{},
   456  		ToKeys:           toKeys,
   457  	})
   458  
   459  	if err != nil {
   460  		return nil, fmt.Errorf("failed to pack forward msg: %w", err)
   461  	}
   462  
   463  	return packedMsg, nil
   464  }
   465  
   466  func (o *Dispatcher) addTransportRouteOptions(req []byte, des *service.Destination) ([]byte, error) {
   467  	// don't add transport route options for forward messages
   468  	if routingKeys, err := des.ServiceEndpoint.RoutingKeys(); err == nil && len(routingKeys) > 0 {
   469  		return req, nil
   470  	}
   471  
   472  	if o.transportReturnRoute == decorator.TransportReturnRouteAll ||
   473  		o.transportReturnRoute == decorator.TransportReturnRouteThread {
   474  		// create the decorator with the option set in the framework
   475  		transportDec := &decorator.Transport{ReturnRoute: &decorator.ReturnRoute{Value: o.transportReturnRoute}}
   476  
   477  		transportDecJSON, jsonErr := json.Marshal(transportDec)
   478  		if jsonErr != nil {
   479  			return nil, fmt.Errorf("json marshal : %w", jsonErr)
   480  		}
   481  
   482  		request := string(req)
   483  		index := strings.Index(request, "{")
   484  
   485  		// add transport route option decorator to the original request
   486  		req = []byte(request[:index+1] + string(transportDecJSON)[1:len(string(transportDecJSON))-1] + "," +
   487  			request[index+1:])
   488  	}
   489  
   490  	return req, nil
   491  }
   492  
   493  func (o *Dispatcher) mediaTypeProfile(des *service.Destination) string {
   494  	var (
   495  		mt     string
   496  		accept []string
   497  		err    error
   498  	)
   499  
   500  	if accept, err = des.ServiceEndpoint.Accept(); err != nil || len(accept) == 0 { // didcomm v2
   501  		accept = des.MediaTypeProfiles // didcomm v1
   502  	}
   503  
   504  	if len(accept) > 0 {
   505  		for _, mtp := range accept {
   506  			switch mtp {
   507  			case transport.MediaTypeV1PlaintextPayload, transport.MediaTypeRFC0019EncryptedEnvelope,
   508  				transport.MediaTypeAIP2RFC0019Profile, transport.MediaTypeProfileDIDCommAIP1,
   509  				transport.LegacyDIDCommV1Profile:
   510  				// overridable with higher priority media type.
   511  				if mt == "" {
   512  					mt = mtp
   513  				}
   514  			case transport.MediaTypeV1EncryptedEnvelope, transport.MediaTypeV2EncryptedEnvelopeV1PlaintextPayload,
   515  				transport.MediaTypeAIP2RFC0587Profile:
   516  				mt = mtp
   517  			case transport.MediaTypeV2EncryptedEnvelope, transport.MediaTypeV2PlaintextPayload,
   518  				transport.MediaTypeDIDCommV2Profile:
   519  				// V2 is the highest priority, if found use it directly.
   520  				return mtp
   521  			}
   522  		}
   523  	}
   524  
   525  	if mt == "" {
   526  		return o.defaultMediaTypeProfiles()[0]
   527  	}
   528  
   529  	return mt
   530  }