github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/protocol/messagepickup/service.go (about)

     1  /*
     2  Copyright Scoir Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package messagepickup
     8  
     9  import (
    10  	"encoding/json"
    11  	"fmt"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/google/uuid"
    16  	"github.com/pkg/errors"
    17  
    18  	"github.com/hyperledger/aries-framework-go/pkg/common/log"
    19  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    20  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher"
    21  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator"
    22  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
    23  	"github.com/hyperledger/aries-framework-go/pkg/store/connection"
    24  	"github.com/hyperledger/aries-framework-go/spi/storage"
    25  )
    26  
    27  const (
    28  	// MessagePickup defines the protocol name.
    29  	MessagePickup = "messagepickup"
    30  	// Spec defines the protocol spec.
    31  	Spec = "https://didcomm.org/messagepickup/1.0/"
    32  	// StatusMsgType defines the protocol propose-credential message type.
    33  	StatusMsgType = Spec + "status"
    34  	// StatusRequestMsgType defines the protocol propose-credential message type.
    35  	StatusRequestMsgType = Spec + "status-request"
    36  	// BatchPickupMsgType defines the protocol offer-credential message type.
    37  	BatchPickupMsgType = Spec + "batch-pickup"
    38  	// BatchMsgType defines the protocol offer-credential message type.
    39  	BatchMsgType = Spec + "batch"
    40  	// NoopMsgType defines the protocol request-credential message type.
    41  	NoopMsgType = Spec + "noop"
    42  )
    43  
    44  const (
    45  	updateTimeout = 50 * time.Second
    46  
    47  	// Namespace is namespace of messagepickup store name.
    48  	Namespace = "mailbox"
    49  )
    50  
    51  // ErrConnectionNotFound connection not found error.
    52  var (
    53  	ErrConnectionNotFound = errors.New("connection not found")
    54  	logger                = log.New("aries-framework/messagepickup")
    55  )
    56  
    57  type provider interface {
    58  	OutboundDispatcher() dispatcher.Outbound
    59  	StorageProvider() storage.Provider
    60  	ProtocolStateStorageProvider() storage.Provider
    61  	InboundMessageHandler() transport.InboundMessageHandler
    62  	Packager() transport.Packager
    63  }
    64  
    65  type connections interface {
    66  	GetConnectionRecord(string) (*connection.Record, error)
    67  }
    68  
// Service for the messagepickup protocol.
//
// It answers status-request / batch-pickup messages from peers using inboxes
// filled by AddMessage, and it can itself send StatusRequest, BatchPickup and
// Noop requests over an existing connection. Use New (or Initialize) before use.
type Service struct {
	service.Action
	service.Message
	// connectionLookup resolves connection IDs to connection records (see getConnection).
	connectionLookup connections
	// outbound sends protocol messages to the other party's DID.
	outbound         dispatcher.Outbound
	// msgStore holds one JSON-encoded inbox per recipient DID, keyed by that DID.
	msgStore         storage.Store
	// packager unpacks messages received inside a batch (see handle).
	packager         transport.Packager
	// msgHandler receives each unpacked batch message (see handle).
	msgHandler       transport.InboundMessageHandler
	// batchMap maps an outstanding batch-pickup request ID to the channel its
	// BatchPickup caller waits on; guarded by batchMapLock.
	batchMap         map[string]chan Batch
	batchMapLock     sync.RWMutex
	// statusMap maps an outstanding status-request ID to the channel its
	// StatusRequest caller waits on; guarded by statusMapLock.
	statusMap        map[string]chan Status
	statusMapLock    sync.RWMutex
	// inboxLock serializes read-modify-write access to inboxes in msgStore.
	inboxLock        sync.Mutex
	// initialized is set once Initialize has completed successfully.
	initialized      bool
}
    85  
    86  // New returns the messagepickup service.
    87  func New(prov provider) (*Service, error) {
    88  	svc := Service{}
    89  
    90  	err := svc.Initialize(prov)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	return &svc, nil
    96  }
    97  
    98  // Initialize initializes the Service. If Initialize succeeds, any further call is a no-op.
    99  func (s *Service) Initialize(p interface{}) error {
   100  	if s.initialized {
   101  		return nil
   102  	}
   103  
   104  	prov, ok := p.(provider)
   105  	if !ok {
   106  		return fmt.Errorf("expected provider of type `%T`, got type `%T`", provider(nil), p)
   107  	}
   108  
   109  	store, err := prov.StorageProvider().OpenStore(Namespace)
   110  	if err != nil {
   111  		return fmt.Errorf("open mailbox store : %w", err)
   112  	}
   113  
   114  	connectionLookup, err := connection.NewLookup(prov)
   115  	if err != nil {
   116  		return err
   117  	}
   118  
   119  	s.outbound = prov.OutboundDispatcher()
   120  	s.msgStore = store
   121  	s.connectionLookup = connectionLookup
   122  	s.packager = prov.Packager()
   123  	s.msgHandler = prov.InboundMessageHandler()
   124  	s.batchMap = make(map[string]chan Batch)
   125  	s.statusMap = make(map[string]chan Status)
   126  
   127  	s.initialized = true
   128  
   129  	return nil
   130  }
   131  
   132  // HandleInbound handles inbound message pick up messages.
   133  func (s *Service) HandleInbound(msg service.DIDCommMsg, ctx service.DIDCommContext) (string, error) {
   134  	// perform action asynchronously
   135  	go func() {
   136  		var err error
   137  
   138  		switch msg.Type() {
   139  		case StatusMsgType:
   140  			err = s.handleStatus(msg)
   141  		case StatusRequestMsgType:
   142  			err = s.handleStatusRequest(msg, ctx.MyDID(), ctx.TheirDID())
   143  		case BatchPickupMsgType:
   144  			err = s.handleBatchPickup(msg, ctx.MyDID(), ctx.TheirDID())
   145  		case BatchMsgType:
   146  			err = s.handleBatch(msg)
   147  		case NoopMsgType:
   148  			err = s.handleNoop(msg)
   149  		}
   150  
   151  		if err != nil {
   152  			logger.Errorf("Error handling message: (%w)\n", err)
   153  		}
   154  	}()
   155  
   156  	return msg.ID(), nil
   157  }
   158  
   159  // HandleOutbound adherence to dispatcher.ProtocolService.
   160  func (s *Service) HandleOutbound(_ service.DIDCommMsg, _, _ string) (string, error) {
   161  	return "", errors.New("not implemented")
   162  }
   163  
   164  // Accept checks whether the service can handle the message type.
   165  func (s *Service) Accept(msgType string) bool {
   166  	switch msgType {
   167  	case BatchPickupMsgType, BatchMsgType, StatusRequestMsgType, StatusMsgType, NoopMsgType:
   168  		return true
   169  	}
   170  
   171  	return false
   172  }
   173  
   174  // Name of the service.
   175  func (s *Service) Name() string {
   176  	return MessagePickup
   177  }
   178  
   179  func (s *Service) handleStatus(msg service.DIDCommMsg) error {
   180  	// unmarshal the payload
   181  	statusMsg := &Status{}
   182  
   183  	err := msg.Decode(statusMsg)
   184  	if err != nil {
   185  		return fmt.Errorf("status message unmarshal: %w", err)
   186  	}
   187  
   188  	// check if there are any channels registered for the message ID
   189  	statusCh := s.getStatusCh(statusMsg.ID)
   190  	if statusCh != nil {
   191  		// invoke the channel for the incoming message
   192  		statusCh <- *statusMsg
   193  	}
   194  
   195  	return nil
   196  }
   197  
   198  func (s *Service) handleStatusRequest(msg service.DIDCommMsg, myDID, theirDID string) error {
   199  	s.inboxLock.Lock()
   200  	defer s.inboxLock.Unlock()
   201  
   202  	// unmarshal the payload
   203  	request := &StatusRequest{}
   204  
   205  	err := msg.Decode(request)
   206  	if err != nil {
   207  		return fmt.Errorf("status request message unmarshal: %w", err)
   208  	}
   209  
   210  	logger.Debugf("retrieving stored messages for %s\n", theirDID)
   211  
   212  	outbox, err := s.getInbox(theirDID)
   213  	if err != nil {
   214  		return fmt.Errorf("error in status request getting inbox: %w", err)
   215  	}
   216  
   217  	resp := &Status{
   218  		Type:              StatusMsgType,
   219  		ID:                msg.ID(),
   220  		MessageCount:      outbox.MessageCount,
   221  		DurationWaited:    int(time.Since(outbox.LastDeliveredTime).Seconds()),
   222  		LastAddedTime:     outbox.LastAddedTime,
   223  		LastDeliveredTime: outbox.LastDeliveredTime,
   224  		LastRemovedTime:   outbox.LastRemovedTime,
   225  		TotalSize:         outbox.TotalSize,
   226  		Thread: &decorator.Thread{
   227  			PID: request.Thread.ID,
   228  		},
   229  	}
   230  
   231  	msgBytes, err := json.Marshal(resp)
   232  	if err != nil {
   233  		return fmt.Errorf("marshal batch: %w", err)
   234  	}
   235  
   236  	msgMap, err := service.ParseDIDCommMsgMap(msgBytes)
   237  	if err != nil {
   238  		return fmt.Errorf("parse batch into didcomm msg map: %w", err)
   239  	}
   240  
   241  	return s.outbound.SendToDID(msgMap, myDID, theirDID)
   242  }
   243  
   244  func (s *Service) handleBatchPickup(msg service.DIDCommMsg, myDID, theirDID string) error {
   245  	s.inboxLock.Lock()
   246  	defer s.inboxLock.Unlock()
   247  
   248  	// unmarshal the payload
   249  	request := &BatchPickup{}
   250  
   251  	err := msg.Decode(request)
   252  	if err != nil {
   253  		return fmt.Errorf("batch pickup message unmarshal : %w", err)
   254  	}
   255  
   256  	outbox, err := s.getInbox(theirDID)
   257  	if err != nil {
   258  		return fmt.Errorf("batch pickup get inbox: %w", err)
   259  	}
   260  
   261  	msgs, err := outbox.DecodeMessages()
   262  	if err != nil {
   263  		return fmt.Errorf("batch pickup decode : %w", err)
   264  	}
   265  
   266  	end := len(msgs)
   267  	if request.BatchSize < end {
   268  		end = request.BatchSize
   269  	}
   270  
   271  	outbox.LastDeliveredTime = time.Now()
   272  	outbox.LastRemovedTime = time.Now()
   273  
   274  	err = outbox.EncodeMessages(msgs[end:])
   275  	if err != nil {
   276  		return fmt.Errorf("batch pickup encode: %w", err)
   277  	}
   278  
   279  	err = s.putInbox(theirDID, outbox)
   280  	if err != nil {
   281  		return fmt.Errorf("batch pick up put inbox: %w", err)
   282  	}
   283  
   284  	msgs = msgs[0:end]
   285  
   286  	batch := Batch{
   287  		Type:     BatchMsgType,
   288  		ID:       msg.ID(),
   289  		Messages: msgs,
   290  	}
   291  
   292  	msgBytes, err := json.Marshal(batch)
   293  	if err != nil {
   294  		return fmt.Errorf("marshal batch: %w", err)
   295  	}
   296  
   297  	msgMap, err := service.ParseDIDCommMsgMap(msgBytes)
   298  	if err != nil {
   299  		return fmt.Errorf("parse batch into didcomm msg map: %w", err)
   300  	}
   301  
   302  	return s.outbound.SendToDID(msgMap, myDID, theirDID)
   303  }
   304  
   305  func (s *Service) handleBatch(msg service.DIDCommMsg) error {
   306  	// unmarshal the payload
   307  	batchMsg := &Batch{}
   308  
   309  	err := msg.Decode(batchMsg)
   310  	if err != nil {
   311  		return fmt.Errorf("batch message unmarshal : %w", err)
   312  	}
   313  
   314  	// check if there are any channels registered for the message ID
   315  	batchCh := s.getBatchCh(batchMsg.ID)
   316  
   317  	if batchCh != nil {
   318  		// invoke the channel for the incoming message
   319  		batchCh <- *batchMsg
   320  	}
   321  
   322  	return nil
   323  }
   324  
   325  func (s *Service) handleNoop(msg service.DIDCommMsg) error {
   326  	// unmarshal the payload
   327  	request := &Noop{}
   328  
   329  	err := msg.Decode(request)
   330  	if err != nil {
   331  		return fmt.Errorf("noop message unmarshal : %w", err)
   332  	}
   333  
   334  	return nil
   335  }
   336  
// inbox is the persisted record of messages waiting to be picked up by one DID.
// It is stored as JSON in msgStore under that DID (see getInbox / putInbox).
type inbox struct {
	// DID is the DID whose messages this inbox holds.
	DID               string          `json:"DID"`
	// MessageCount is the number of encoded messages; maintained by EncodeMessages.
	MessageCount      int             `json:"message_count"`
	// NOTE(review): encoding/json never treats a struct (such as time.Time) as
	// empty, so omitempty on the time fields below has no effect.
	LastAddedTime     time.Time       `json:"last_added_time,omitempty"`
	LastDeliveredTime time.Time       `json:"last_delivered_time,omitempty"`
	LastRemovedTime   time.Time       `json:"last_removed_time,omitempty"`
	// TotalSize is the byte length of the encoded Messages; maintained by EncodeMessages.
	TotalSize         int             `json:"total_size,omitempty"`
	// Messages is the JSON-encoded message list; access it via DecodeMessages / EncodeMessages.
	Messages          json.RawMessage `json:"messages"`
}
   346  
   347  // DecodeMessages Messages.
   348  func (r *inbox) DecodeMessages() ([]*Message, error) {
   349  	var out []*Message
   350  
   351  	var err error
   352  
   353  	if r.Messages != nil {
   354  		err = json.Unmarshal(r.Messages, &out)
   355  	}
   356  
   357  	return out, err
   358  }
   359  
   360  // EncodeMessages Messages.
   361  func (r *inbox) EncodeMessages(msg []*Message) error {
   362  	d, err := json.Marshal(msg)
   363  	if err != nil {
   364  		return fmt.Errorf("unable to marshal: %w", err)
   365  	}
   366  
   367  	r.Messages = d
   368  	r.MessageCount = len(msg)
   369  	r.TotalSize = len(d)
   370  
   371  	return nil
   372  }
   373  
   374  // AddMessage add message to inbox.
   375  func (s *Service) AddMessage(message []byte, theirDID string) error {
   376  	s.inboxLock.Lock()
   377  	defer s.inboxLock.Unlock()
   378  
   379  	outbox, err := s.createInbox(theirDID)
   380  	if err != nil {
   381  		return fmt.Errorf("unable to pull messages: %w", err)
   382  	}
   383  
   384  	msgs, err := outbox.DecodeMessages()
   385  	if err != nil {
   386  		return fmt.Errorf("unable to decode messages: %w", err)
   387  	}
   388  
   389  	m := Message{
   390  		ID:        uuid.New().String(),
   391  		AddedTime: time.Now(),
   392  		Message:   message,
   393  	}
   394  
   395  	msgs = append(msgs, &m)
   396  
   397  	outbox.LastDeliveredTime = time.Now()
   398  	outbox.LastRemovedTime = outbox.LastDeliveredTime
   399  
   400  	err = outbox.EncodeMessages(msgs)
   401  	if err != nil {
   402  		return fmt.Errorf("unable to encode messages: %w", err)
   403  	}
   404  
   405  	err = s.putInbox(theirDID, outbox)
   406  	if err != nil {
   407  		return fmt.Errorf("unable to put messages: %w", err)
   408  	}
   409  
   410  	return nil
   411  }
   412  
   413  func (s *Service) createInbox(theirDID string) (*inbox, error) {
   414  	msgs, err := s.getInbox(theirDID)
   415  	if err != nil && errors.Is(err, storage.ErrDataNotFound) {
   416  		msgs = &inbox{DID: theirDID}
   417  
   418  		msgBytes, e := json.Marshal(msgs)
   419  		if e != nil {
   420  			return nil, e
   421  		}
   422  
   423  		e = s.msgStore.Put(theirDID, msgBytes)
   424  		if e != nil {
   425  			return nil, e
   426  		}
   427  
   428  		return msgs, nil
   429  	}
   430  
   431  	return msgs, err
   432  }
   433  
   434  func (s *Service) getInbox(theirDID string) (*inbox, error) {
   435  	msgs := &inbox{DID: theirDID}
   436  
   437  	b, err := s.msgStore.Get(theirDID)
   438  	if err != nil {
   439  		return nil, err
   440  	}
   441  
   442  	err = json.Unmarshal(b, msgs)
   443  	if err != nil {
   444  		return nil, err
   445  	}
   446  
   447  	return msgs, nil
   448  }
   449  
   450  func (s *Service) putInbox(theirDID string, o *inbox) error {
   451  	b, err := json.Marshal(o)
   452  	if err != nil {
   453  		return err
   454  	}
   455  
   456  	return s.msgStore.Put(theirDID, b)
   457  }
   458  
   459  // StatusRequest request a status message.
   460  func (s *Service) StatusRequest(connectionID string) (*Status, error) {
   461  	// get the connection record for the ID to fetch DID information
   462  	conn, err := s.getConnection(connectionID)
   463  	if err != nil {
   464  		return nil, err
   465  	}
   466  
   467  	// generate message ID
   468  	msgID := uuid.New().String()
   469  
   470  	// register chan for callback processing
   471  	statusCh := make(chan Status)
   472  	s.setStatusCh(msgID, statusCh)
   473  
   474  	defer s.setStatusCh(msgID, nil)
   475  
   476  	// create request message
   477  	req := &StatusRequest{
   478  		Type: StatusRequestMsgType,
   479  		ID:   msgID,
   480  		Thread: &decorator.Thread{
   481  			PID: uuid.New().String(),
   482  		},
   483  	}
   484  
   485  	// send message to the router
   486  	if err := s.outbound.SendToDID(req, conn.MyDID, conn.TheirDID); err != nil {
   487  		return nil, fmt.Errorf("send route request: %w", err)
   488  	}
   489  
   490  	// callback processing (to make this function look like a sync function)
   491  	var sts *Status
   492  	select {
   493  	case s := <-statusCh:
   494  		sts = &s
   495  		// TODO https://github.com/hyperledger/aries-framework-go/issues/1134 configure this timeout at decorator level
   496  	case <-time.After(updateTimeout):
   497  		return nil, errors.New("timeout waiting for status request")
   498  	}
   499  
   500  	return sts, nil
   501  }
   502  
   503  // BatchPickup a request to have multiple waiting messages sent inside a batch message.
   504  func (s *Service) BatchPickup(connectionID string, size int) (int, error) {
   505  	// get the connection record for the ID to fetch DID information
   506  	conn, err := s.getConnection(connectionID)
   507  	if err != nil {
   508  		return -1, err
   509  	}
   510  
   511  	// generate message ID
   512  	msgID := uuid.New().String()
   513  
   514  	// register chan for callback processing
   515  	batchCh := make(chan Batch)
   516  	s.setBatchCh(msgID, batchCh)
   517  
   518  	defer s.setBatchCh(msgID, nil)
   519  
   520  	// create request message
   521  	req := &BatchPickup{
   522  		Type:      BatchPickupMsgType,
   523  		ID:        msgID,
   524  		BatchSize: size,
   525  	}
   526  
   527  	msgBytes, err := json.Marshal(req)
   528  	if err != nil {
   529  		return -1, fmt.Errorf("marshal req: %w", err)
   530  	}
   531  
   532  	msgMap, err := service.ParseDIDCommMsgMap(msgBytes)
   533  	if err != nil {
   534  		return -1, fmt.Errorf("parse req into didcomm msg map: %w", err)
   535  	}
   536  
   537  	// send message to the router
   538  	if err := s.outbound.SendToDID(msgMap, conn.MyDID, conn.TheirDID); err != nil {
   539  		return -1, fmt.Errorf("send batch pickup request: %w", err)
   540  	}
   541  
   542  	// callback processing (to make this function look like a sync function)
   543  	var processed int
   544  	select {
   545  	case batchResp := <-batchCh:
   546  		for _, msg := range batchResp.Messages {
   547  			err := s.handle(msg)
   548  			if err != nil {
   549  				logger.Errorf("error handling batch message %s: %w", msg.ID, err)
   550  
   551  				continue
   552  			}
   553  			processed++
   554  		}
   555  	// TODO https://github.com/hyperledger/aries-framework-go/issues/1134 configure this timeout at decorator level
   556  	case <-time.After(updateTimeout):
   557  		return -1, errors.New("timeout waiting for batch")
   558  	}
   559  
   560  	return processed, nil
   561  }
   562  
   563  // Noop a noop message.
   564  func (s *Service) Noop(connectionID string) error {
   565  	// get the connection record for the ID to fetch DID information
   566  	conn, err := s.getConnection(connectionID)
   567  	if err != nil {
   568  		return err
   569  	}
   570  
   571  	noop := &Noop{ID: uuid.New().String(), Type: NoopMsgType}
   572  
   573  	msgBytes, err := json.Marshal(noop)
   574  	if err != nil {
   575  		return fmt.Errorf("marshal noop: %w", err)
   576  	}
   577  
   578  	msgMap, err := service.ParseDIDCommMsgMap(msgBytes)
   579  	if err != nil {
   580  		return fmt.Errorf("parse noop into didcomm msg map: %w", err)
   581  	}
   582  
   583  	if err := s.outbound.SendToDID(msgMap, conn.MyDID, conn.TheirDID); err != nil {
   584  		return fmt.Errorf("send noop request: %w", err)
   585  	}
   586  
   587  	return nil
   588  }
   589  
   590  func (s *Service) getConnection(routerConnID string) (*connection.Record, error) {
   591  	conn, err := s.connectionLookup.GetConnectionRecord(routerConnID)
   592  	if err != nil {
   593  		if errors.Is(err, storage.ErrDataNotFound) {
   594  			return nil, ErrConnectionNotFound
   595  		}
   596  
   597  		return nil, fmt.Errorf("fetch connection record from store : %w", err)
   598  	}
   599  
   600  	return conn, nil
   601  }
   602  
   603  func (s *Service) getBatchCh(msgID string) chan Batch {
   604  	s.batchMapLock.RLock()
   605  	defer s.batchMapLock.RUnlock()
   606  
   607  	return s.batchMap[msgID]
   608  }
   609  
   610  func (s *Service) setBatchCh(msgID string, batchCh chan Batch) {
   611  	s.batchMapLock.Lock()
   612  	defer s.batchMapLock.Unlock()
   613  
   614  	if batchCh == nil {
   615  		delete(s.batchMap, msgID)
   616  	} else {
   617  		s.batchMap[msgID] = batchCh
   618  	}
   619  }
   620  
   621  func (s *Service) getStatusCh(msgID string) chan Status {
   622  	s.statusMapLock.RLock()
   623  	defer s.statusMapLock.RUnlock()
   624  
   625  	return s.statusMap[msgID]
   626  }
   627  
   628  func (s *Service) setStatusCh(msgID string, statusCh chan Status) {
   629  	s.statusMapLock.Lock()
   630  	defer s.statusMapLock.Unlock()
   631  
   632  	if statusCh == nil {
   633  		delete(s.statusMap, msgID)
   634  	} else {
   635  		s.statusMap[msgID] = statusCh
   636  	}
   637  }
   638  
   639  func (s *Service) handle(msg *Message) error {
   640  	unpackMsg, err := s.packager.UnpackMessage(msg.Message)
   641  	if err != nil {
   642  		return fmt.Errorf("failed to unpack msg: %w", err)
   643  	}
   644  
   645  	trans := &decorator.Transport{}
   646  	err = json.Unmarshal(unpackMsg.Message, trans)
   647  
   648  	if err != nil {
   649  		return fmt.Errorf("unmarshal transport decorator : %w", err)
   650  	}
   651  
   652  	messageHandler := s.msgHandler
   653  
   654  	err = messageHandler(unpackMsg)
   655  	if err != nil {
   656  		return fmt.Errorf("incoming msg processing failed: %w", err)
   657  	}
   658  
   659  	return nil
   660  }