github.com/hyperledger/aries-framework-go@v0.3.2/pkg/client/messaging/client.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package messaging
     8  
     9  import (
    10  	"context"
    11  	"encoding/json"
    12  	"errors"
    13  	"fmt"
    14  
    15  	"github.com/google/uuid"
    16  
    17  	"github.com/hyperledger/aries-framework-go/pkg/common/log"
    18  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    19  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher"
    20  	"github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr"
    21  	"github.com/hyperledger/aries-framework-go/pkg/kms"
    22  	"github.com/hyperledger/aries-framework-go/pkg/store/connection"
    23  	"github.com/hyperledger/aries-framework-go/pkg/vdr/fingerprint"
    24  	"github.com/hyperledger/aries-framework-go/spi/storage"
    25  )
    26  
    27  const (
    28  	// errors.
    29  	errMsgDestinationMissing = "missing message destination"
    30  )
    31  
// logger is the package-level logger used by the messaging client, created
// with the module name "aries-framework/client/messaging".
var logger = log.New("aries-framework/client/messaging")
    33  
    34  // provider contains dependencies for the message client and is typically created by using aries.Context().
    35  type provider interface {
    36  	VDRegistry() vdr.Registry
    37  	Messenger() service.Messenger
    38  	ProtocolStateStorageProvider() storage.Provider
    39  	StorageProvider() storage.Provider
    40  	KMS() kms.KeyManager
    41  }
    42  
    43  // MessageHandler maintains registered message services
    44  // and it allows dynamic registration of message services.
    45  type MessageHandler interface {
    46  	// Services returns list of available message services in this message handler
    47  	Services() []dispatcher.MessageService
    48  	// Register registers given message services to this message handler
    49  	Register(msgSvcs ...dispatcher.MessageService) error
    50  	// Unregister unregisters message service with given name from this message handler
    51  	Unregister(name string) error
    52  }
    53  
// Notifier represents a notification dispatcher.
//
// The Client hands its Notifier to every message service created through
// RegisterService. Notify is presumably invoked with the service topic and the
// raw bytes of an incoming message — confirm against newMessageService.
type Notifier interface {
	Notify(topic string, message []byte) error
}
    58  
// sendMsgOpts collects the settings applied by SendMessageOpions for a single
// Send call. Send picks the destination in this order of precedence:
// connectionID, then theirDID, then destination.
type sendMsgOpts struct {
	// connectionID selects the connection record the message is sent over.
	// It takes precedence over all the other destination parameters.
	connectionID string

	// theirDID is the recipient DID. If no connection exists for it, a
	// destination is derived from the DID instead.
	// It takes precedence over the destination parameter.
	theirDID string

	// destination is an explicit service endpoint destination.
	// It can be used to send messages outside of a connection.
	destination *service.Destination

	// responseMsgType, when non-empty, makes Send wait for a reply of this
	// message type whose thread ID matches the sent message's ID.
	responseMsgType string

	// waitForResponseCtx bounds the wait for a reply. It must be non-nil when
	// responseMsgType is set, because the wait selects on its Done channel.
	waitForResponseCtx context.Context
}
    79  
// SendMessageOpions configures a Send call: which destination the message is
// sent to and whether Send waits for a reply.
// (The misspelled name is part of the exported API and is kept for compatibility.)
type SendMessageOpions func(opts *sendMsgOpts)
    82  
    83  // SendByConnectionID option to choose message destination by connection ID.
    84  func SendByConnectionID(connectionID string) SendMessageOpions {
    85  	return func(opts *sendMsgOpts) {
    86  		opts.connectionID = connectionID
    87  	}
    88  }
    89  
    90  // SendByTheirDID option to choose message destination by connection ID.
    91  func SendByTheirDID(theirDID string) SendMessageOpions {
    92  	return func(opts *sendMsgOpts) {
    93  		opts.theirDID = theirDID
    94  	}
    95  }
    96  
    97  // SendByDestination option to set message destination.
    98  func SendByDestination(destination *service.Destination) SendMessageOpions {
    99  	return func(opts *sendMsgOpts) {
   100  		opts.destination = destination
   101  	}
   102  }
   103  
   104  // WaitForResponse option to set message response type.
   105  // Message reply will wait for the response of this message type and matching thread ID.
   106  func WaitForResponse(ctx context.Context, responseType string) SendMessageOpions {
   107  	return func(opts *sendMsgOpts) {
   108  		opts.waitForResponseCtx = ctx
   109  		opts.responseMsgType = responseType
   110  	}
   111  }
   112  
// messageDispatcher is a deferred send action. Invoking it performs the actual
// send and returns an error if the send fails. It does not return a message ID.
type messageDispatcher func() error
   115  
// Client enables access to messaging features: registering message services
// and sending or replying to DIDComm messages.
type Client struct {
	// ctx supplies the framework dependencies (messenger, VDR, KMS, storage).
	ctx provider
	// msgRegistrar holds the registered message services, including the
	// temporary ones used while waiting for replies.
	msgRegistrar MessageHandler
	// notifier is given to every service added through RegisterService.
	notifier Notifier
	// connectionLookup resolves connection records by connection ID or their DID.
	connectionLookup *connection.Lookup
}
   123  
   124  // New return new instance of message client.
   125  func New(ctx provider, registrar MessageHandler, notifier Notifier) (*Client, error) {
   126  	connectionLookup, err := connection.NewLookup(ctx)
   127  	if err != nil {
   128  		return nil, fmt.Errorf("failed to initialize connection lookup : %w", err)
   129  	}
   130  
   131  	c := &Client{
   132  		ctx:              ctx,
   133  		msgRegistrar:     registrar,
   134  		connectionLookup: connectionLookup,
   135  		notifier:         notifier,
   136  	}
   137  
   138  	return c, nil
   139  }
   140  
   141  // RegisterService registers new message service to message handler registrar.
   142  func (c *Client) RegisterService(name, msgType string, purpose ...string) error {
   143  	return c.msgRegistrar.Register(newMessageService(name, msgType, purpose, c.notifier))
   144  }
   145  
// UnregisterService removes the message service registered under name from the
// client's message handler. It returns any error the handler reports.
func (c *Client) UnregisterService(name string) error {
	return c.msgRegistrar.Unregister(name)
}
   150  
   151  // Services returns list of registered service names.
   152  func (c *Client) Services() []string {
   153  	names := []string{}
   154  	for _, svc := range c.msgRegistrar.Services() {
   155  		names = append(names, svc.Name())
   156  	}
   157  
   158  	return names
   159  }
   160  
   161  // Send sends new message based on destination options provided.
   162  func (c *Client) Send(msg json.RawMessage, opts ...SendMessageOpions) (json.RawMessage, error) {
   163  	sendOpts := &sendMsgOpts{}
   164  
   165  	for _, opt := range opts {
   166  		opt(sendOpts)
   167  	}
   168  
   169  	var action messageDispatcher
   170  
   171  	didCommMsg, err := prepareMessage(msg)
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  
   176  	switch {
   177  	case sendOpts.connectionID != "":
   178  		action, err = c.sendToConnection(didCommMsg, sendOpts.connectionID)
   179  	case sendOpts.theirDID != "":
   180  		action, err = c.sendToTheirDID(didCommMsg, sendOpts.theirDID)
   181  	case sendOpts.destination != nil:
   182  		action, err = c.sendToDestination(didCommMsg, sendOpts.destination)
   183  	default:
   184  		return nil, fmt.Errorf(errMsgDestinationMissing)
   185  	}
   186  
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  
   191  	return c.sendAndWaitForReply(sendOpts.waitForResponseCtx, action, didCommMsg.ID(), sendOpts.responseMsgType)
   192  }
   193  
   194  // Reply sends reply to existing message.
   195  func (c *Client) Reply(ctx context.Context, msg json.RawMessage, msgID string, startNewThread bool,
   196  	waitForResponse string) (json.RawMessage, error) {
   197  	var action messageDispatcher
   198  
   199  	didCommMsg, err := prepareMessage(msg)
   200  	if err != nil {
   201  		return nil, err
   202  	}
   203  
   204  	if startNewThread {
   205  		action = func() error {
   206  			return c.ctx.Messenger().ReplyToNested(didCommMsg, &service.NestedReplyOpts{MsgID: msgID})
   207  		}
   208  
   209  		return c.sendAndWaitForReply(ctx, action, didCommMsg.ID(), waitForResponse)
   210  	}
   211  
   212  	action = func() error {
   213  		return c.ctx.Messenger().ReplyTo(msgID, didCommMsg) // nolint: staticcheck
   214  	}
   215  
   216  	return c.sendAndWaitForReply(ctx, action, "", waitForResponse)
   217  }
   218  
   219  func (c *Client) sendToConnection(msg service.DIDCommMsgMap, connectionID string) (messageDispatcher, error) {
   220  	conn, err := c.connectionLookup.GetConnectionRecord(connectionID)
   221  	if err != nil {
   222  		return nil, err
   223  	}
   224  
   225  	return func() error {
   226  		return c.ctx.Messenger().Send(msg, conn.MyDID, conn.TheirDID)
   227  	}, nil
   228  }
   229  
   230  func (c *Client) sendToTheirDID(msg service.DIDCommMsgMap, theirDID string) (messageDispatcher, error) {
   231  	conn, err := c.connectionLookup.GetConnectionRecordByTheirDID(theirDID)
   232  	if err == nil {
   233  		return func() error {
   234  			return c.ctx.Messenger().Send(msg, conn.MyDID, conn.TheirDID)
   235  		}, nil
   236  	} else if !errors.Is(err, storage.ErrDataNotFound) {
   237  		return nil, err
   238  	}
   239  
   240  	dest, err := service.GetDestination(theirDID, c.ctx.VDRegistry())
   241  	if err != nil {
   242  		return nil, err
   243  	}
   244  
   245  	return c.sendToDestination(msg, dest)
   246  }
   247  
   248  func (c *Client) sendToDestination(msg service.DIDCommMsgMap, dest *service.Destination) (messageDispatcher, error) {
   249  	_, sigPubKey, err := c.ctx.KMS().CreateAndExportPubKeyBytes(kms.ED25519Type)
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  
   254  	didKey, _ := fingerprint.CreateDIDKey(sigPubKey)
   255  
   256  	return func() error {
   257  		return c.ctx.Messenger().SendToDestination(msg, didKey, dest)
   258  	}, nil
   259  }
   260  
   261  func (c *Client) sendAndWaitForReply(ctx context.Context, action messageDispatcher, thID string,
   262  	replyType string) (json.RawMessage, error) {
   263  	var notificationCh chan NotificationPayload
   264  
   265  	if replyType != "" {
   266  		topic := uuid.New().String()
   267  		notificationCh = make(chan NotificationPayload)
   268  
   269  		err := c.msgRegistrar.Register(newMessageService(topic, replyType, nil,
   270  			NewNotifier(notificationCh, func(topic string, msgBytes []byte) bool {
   271  				var message struct {
   272  					Message service.DIDCommMsgMap `json:"message"`
   273  				}
   274  
   275  				err := json.Unmarshal(msgBytes, &message)
   276  				if err != nil {
   277  					logger.Debugf("failed to unmarshal incoming message reply: %s", err)
   278  					return false
   279  				}
   280  
   281  				msgThID, err := message.Message.ThreadID()
   282  				if err != nil {
   283  					logger.Debugf("failed to read incoming message reply thread ID: %s", err)
   284  					return false
   285  				}
   286  
   287  				return thID == "" || thID == msgThID
   288  			})))
   289  		if err != nil {
   290  			return nil, err
   291  		}
   292  
   293  		defer func() {
   294  			e := c.msgRegistrar.Unregister(topic)
   295  			if e != nil {
   296  				logger.Warnf("Failed to unregister wait for reply notifier: %w", e)
   297  			}
   298  		}()
   299  	}
   300  
   301  	err := action()
   302  	if err != nil {
   303  		return nil, err
   304  	}
   305  
   306  	if notificationCh != nil {
   307  		return waitForResponse(ctx, notificationCh)
   308  	}
   309  
   310  	return json.RawMessage{}, nil
   311  }
   312  
   313  func waitForResponse(ctx context.Context, notificationCh chan NotificationPayload) (json.RawMessage, error) {
   314  	select {
   315  	case payload := <-notificationCh:
   316  		return json.RawMessage(payload.Raw), nil
   317  
   318  	case <-ctx.Done():
   319  		return nil, fmt.Errorf("failed to get reply, context deadline exceeded")
   320  	}
   321  }
   322  
   323  func prepareMessage(msg json.RawMessage) (service.DIDCommMsgMap, error) {
   324  	didCommMsg, err := service.ParseDIDCommMsgMap(msg)
   325  	if err != nil {
   326  		return nil, err
   327  	}
   328  
   329  	if didCommMsg.ID() == "" {
   330  		didCommMsg.SetID(uuid.New().String())
   331  	}
   332  
   333  	return didCommMsg, nil
   334  }