github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/protocol/outofband/states.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package outofband
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  
    13  	"github.com/google/uuid"
    14  
    15  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
    16  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/didexchange"
    17  	"github.com/hyperledger/aries-framework-go/pkg/store/connection"
    18  )
    19  
    20  const (
    21  	// StateNameInitial is the initial state.
    22  	StateNameInitial = "initial"
    23  	// StateNameAwaitResponse is the state where a sender or a receiver are awaiting a response.
    24  	StateNameAwaitResponse = "await-response"
    25  	// StateNamePrepareResponse is the state where a receiver is preparing a response to the sender.
    26  	StateNamePrepareResponse = "prepare-response"
    27  	// StateNameDone is the final state.
    28  	StateNameDone = "done"
    29  
    30  	connectionRecordCompletedState = "completed"
    31  )
    32  
    33  type finisher func(service.Messenger) error
    34  
    35  func noAction(service.Messenger) error {
    36  	return nil
    37  }
    38  
// dependencies bundles the collaborators the out-of-band states use while
// executing a transition:
//   - connections: looks up connection IDs (by DID pair), single connection
//     records, and the full list of connection records.
//   - didSvc: responds to an inbound invitation through DID exchange
//     (RespondTo), returning the new connection ID.
//   - saveAttchStateFunc: persists an attachmentHandlingState for an invitation
//     that carries request attachments.
//   - dispatchAttachmntFunc: dispatches an invitation's attachments; the states
//     call it as (invitation ID, my DID, their DID).
type dependencies struct {
	connections           connectionRecorder
	didSvc                didExchSvc
	saveAttchStateFunc    func(*attachmentHandlingState) error
	dispatchAttachmntFunc func(string, string, string) error
}
    45  
// state is a single state of the out-of-band protocol's state machine.
//
// Name returns the state's identifier, one of the StateName* constants.
//
// Execute processes the message held in the context and returns, in order:
//   - the next state;
//   - a finisher to run with the messenger (noAction when nothing is sent);
//   - a bool that the implementations set to true on every error path, on
//     transitions into await-response, and from the done state; presumably it
//     tells the caller to stop executing further states (confirm against the
//     caller);
//   - an error.
type state interface {
	Name() string
	Execute(*context, *dependencies) (state, finisher, bool, error)
}
    51  
    52  func stateFromName(n string) (state, error) {
    53  	states := []state{
    54  		&stateInitial{},
    55  		&stateAwaitResponse{},
    56  		&statePrepareResponse{},
    57  		&stateDone{},
    58  	}
    59  
    60  	for i := range states {
    61  		if states[i].Name() == n {
    62  			return states[i], nil
    63  		}
    64  	}
    65  
    66  	return nil, fmt.Errorf("unrecognized state name: %s", n)
    67  }
    68  
    69  func requiresApproval(msg service.DIDCommMsg) bool {
    70  	switch msg.Type() {
    71  	case InvitationMsgType, HandshakeReuseMsgType:
    72  		return true
    73  	}
    74  
    75  	return false
    76  }
    77  
    78  type stateInitial struct{}
    79  
    80  func (s *stateInitial) Name() string {
    81  	return StateNameInitial
    82  }
    83  
    84  func (s *stateInitial) Execute(ctx *context, _ *dependencies) (state, finisher, bool, error) {
    85  	if ctx.Inbound { // inbound invitation
    86  		return &statePrepareResponse{}, noAction, false, nil
    87  	}
    88  
    89  	// outbound invitation
    90  	return &stateAwaitResponse{}, func(m service.Messenger) error {
    91  		return m.Send(ctx.Msg, ctx.MyDID, ctx.TheirDID)
    92  	}, true, nil
    93  }
    94  
    95  type stateAwaitResponse struct{}
    96  
    97  func (s *stateAwaitResponse) Name() string {
    98  	return StateNameAwaitResponse
    99  }
   100  
   101  func (s *stateAwaitResponse) Execute(ctx *context, deps *dependencies) (state, finisher, bool, error) {
   102  	if !ctx.Inbound {
   103  		return nil, nil, true, fmt.Errorf("cannot execute '%s' for outbound messages", s.Name())
   104  	}
   105  
   106  	// inbound HandshakeReuse or HandshakeReuseAccepted
   107  	if ctx.Msg.Type() == HandshakeReuseMsgType {
   108  		return s.handleHandshakeReuse(ctx, deps)
   109  	}
   110  
   111  	return s.handleHandshakeReuseAccepted(ctx, deps)
   112  }
   113  
   114  func (s *stateAwaitResponse) handleHandshakeReuse(ctx *context, deps *dependencies) (state, finisher, bool, error) {
   115  	// incoming HandshakeReuse
   116  	logger.Debugf("handling %s with context: %+v", ctx.Msg.Type(), ctx)
   117  
   118  	connID, err := deps.connections.GetConnectionIDByDIDs(ctx.MyDID, ctx.TheirDID)
   119  	if err != nil {
   120  		return nil, nil, true, fmt.Errorf(
   121  			"failed to fetch connection ID [myDID=%s theirDID=%s]: %w",
   122  			ctx.MyDID, ctx.TheirDID, err,
   123  		)
   124  	}
   125  
   126  	record, err := deps.connections.GetConnectionRecord(connID)
   127  	if err != nil {
   128  		return nil, nil, true, fmt.Errorf("failed to fetch connection record [connID=%s]: %w", connID, err)
   129  	}
   130  
   131  	if record.State != connectionRecordCompletedState {
   132  		return nil, nil, true, fmt.Errorf(
   133  			"unexpected state for connection with ID=%s: expected '%s' got '%s'",
   134  			connID, connectionRecordCompletedState, record.State,
   135  		)
   136  	}
   137  
   138  	return &stateDone{}, func(m service.Messenger) error {
   139  		return m.ReplyToMsg(
   140  			ctx.Msg,
   141  			service.NewDIDCommMsgMap(&HandshakeReuseAccepted{
   142  				ID:   uuid.New().String(),
   143  				Type: HandshakeReuseAcceptedMsgType,
   144  			}),
   145  			ctx.MyDID,
   146  			ctx.TheirDID,
   147  		)
   148  	}, false, nil
   149  }
   150  
   151  func (s *stateAwaitResponse) handleHandshakeReuseAccepted(
   152  	ctx *context, deps *dependencies) (state, finisher, bool, error) {
   153  	logger.Debugf("handling %s with context: %+v", ctx.Msg.Type(), ctx)
   154  
   155  	if len(ctx.Invitation.Requests) > 0 {
   156  		go func() {
   157  			logger.Debugf("dispatching invitation attachment...")
   158  
   159  			err := deps.dispatchAttachmntFunc(ctx.Invitation.ID, ctx.MyDID, ctx.TheirDID)
   160  			if err != nil {
   161  				logger.Errorf("failed to dispatch attachment: %s", err.Error())
   162  			}
   163  		}()
   164  	}
   165  
   166  	return &stateDone{}, noAction, false, nil
   167  }
   168  
   169  type statePrepareResponse struct{}
   170  
   171  func (s *statePrepareResponse) Name() string {
   172  	return StateNamePrepareResponse
   173  }
   174  
   175  func (s *statePrepareResponse) Execute(ctx *context, deps *dependencies) (state, finisher, bool, error) {
   176  	logger.Debugf("handling %s with context: %+v", ctx.Msg.Type(), ctx)
   177  
   178  	// incoming Invitation
   179  	if ctx.ReuseConnection != "" || ctx.ReuseAnyConnection {
   180  		return s.connectionReuse(ctx, deps)
   181  	}
   182  
   183  	logger.Debugf("creating new connection using context: %+v", ctx)
   184  
   185  	connID, err := deps.didSvc.RespondTo(ctx.DIDExchangeInv, ctx.RouterConnections)
   186  	if err != nil {
   187  		return nil, nil, true, fmt.Errorf("didexchange service failed to handle inbound invitation: %w", err)
   188  	}
   189  
   190  	ctx.ConnectionID = connID
   191  
   192  	if len(ctx.Invitation.Requests) > 0 {
   193  		callbackState := &attachmentHandlingState{
   194  			ID:           ctx.Invitation.ID,
   195  			ConnectionID: connID,
   196  			Invitation:   ctx.Invitation,
   197  		}
   198  
   199  		err = deps.saveAttchStateFunc(callbackState)
   200  		if err != nil {
   201  			return nil, nil, true, fmt.Errorf("failed to save attachment handling state: %w", err)
   202  		}
   203  	}
   204  
   205  	return &stateDone{}, noAction, false, nil
   206  }
   207  
   208  func (s *statePrepareResponse) connectionReuse(ctx *context, deps *dependencies) (state, finisher, bool, error) {
   209  	logger.Debugf("reusing connection using context: %+v", ctx)
   210  
   211  	// TODO query needs to be improved: https://github.com/hyperledger/aries-framework-go/issues/2732
   212  	records, err := deps.connections.QueryConnectionRecords()
   213  	if err != nil {
   214  		return nil, nil, true, fmt.Errorf("connectionReuse: failed to fetch connection records: %w", err)
   215  	}
   216  
   217  	inv := ctx.Invitation
   218  
   219  	var (
   220  		record *connection.Record
   221  		found  bool
   222  	)
   223  
   224  	if ctx.ReuseAnyConnection {
   225  		for i := range inv.Services {
   226  			if s, ok := inv.Services[i].(string); ok {
   227  				record, found = findConnectionRecord(records, s)
   228  				if found {
   229  					break
   230  				}
   231  			}
   232  		}
   233  	} else {
   234  		record, found = findConnectionRecord(records, ctx.ReuseConnection)
   235  	}
   236  
   237  	if !found {
   238  		return nil, nil, true, errors.New("connectionReuse: no existing connection record found for the invitation")
   239  	}
   240  
   241  	ctx.ConnectionID = record.ConnectionID
   242  	ctx.MyDID = record.MyDID
   243  	ctx.TheirDID = record.TheirDID
   244  
   245  	if len(ctx.Invitation.Requests) > 0 {
   246  		callbackState := &attachmentHandlingState{
   247  			ID:           ctx.Invitation.ID,
   248  			ConnectionID: record.ConnectionID,
   249  			Invitation:   ctx.Invitation,
   250  		}
   251  
   252  		err = deps.saveAttchStateFunc(callbackState)
   253  		if err != nil {
   254  			return nil, nil, true, fmt.Errorf("failed to save attachment handling state: %w", err)
   255  		}
   256  	}
   257  
   258  	return &stateAwaitResponse{}, func(m service.Messenger) error {
   259  		return m.ReplyToMsg(
   260  			ctx.Msg,
   261  			service.NewDIDCommMsgMap(&HandshakeReuse{
   262  				ID:   uuid.New().String(),
   263  				Type: HandshakeReuseMsgType,
   264  			}),
   265  			ctx.MyDID,
   266  			ctx.TheirDID,
   267  		)
   268  	}, true, nil
   269  }
   270  
   271  type stateDone struct{}
   272  
   273  func (s *stateDone) Name() string {
   274  	return StateNameDone
   275  }
   276  
   277  func (s *stateDone) Execute(*context, *dependencies) (state, finisher, bool, error) {
   278  	return &stateDone{}, noAction, true, nil
   279  }
   280  
   281  func findConnectionRecord(records []*connection.Record, theirDID string) (*connection.Record, bool) {
   282  	for i := range records {
   283  		record := records[i]
   284  
   285  		if record.State != didexchange.StateIDCompleted {
   286  			continue
   287  		}
   288  
   289  		// we may recognize their DID by either:
   290  		//   - having received an invitation with their "public" DID (record.InvitationDID)
   291  		//   - them providing a "ledger-less" DID during a prior DID-Exchange
   292  		if record.InvitationDID == theirDID || record.TheirDID == theirDID {
   293  			return record, true
   294  		}
   295  	}
   296  
   297  	return nil, false
   298  }