/*
Copyright SecureKey Technologies Inc. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/

package outofband

import (
	"errors"
	"fmt"

	"github.com/google/uuid"

	"github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service"
	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/didexchange"
	"github.com/hyperledger/aries-framework-go/pkg/store/connection"
)

const (
	// StateNameInitial is the initial state.
	StateNameInitial = "initial"
	// StateNameAwaitResponse is the state where a sender or a receiver are awaiting a response.
	StateNameAwaitResponse = "await-response"
	// StateNamePrepareResponse is the state where a receiver is preparing a response to the sender.
	StateNamePrepareResponse = "prepare-response"
	// StateNameDone is the final state.
	StateNameDone = "done"

	// connectionRecordCompletedState is the connection record state that an inbound
	// handshake-reuse message requires of the existing connection before it is accepted.
	connectionRecordCompletedState = "completed"
)

// finisher is the follow-up action a state returns from Execute. The finishers built in this
// file either do nothing (noAction) or send a message through the given messenger.
// NOTE(review): the code that invokes the finisher lives outside this file.
type finisher func(service.Messenger) error

// noAction is the finisher used by transitions that have nothing to send; it always returns nil.
func noAction(service.Messenger) error {
	return nil
}

// dependencies bundles the collaborators handed to every state's Execute.
type dependencies struct {
	// connections looks up connection IDs and connection records.
	connections connectionRecorder
	// didSvc responds to an inbound invitation and yields the new connection's ID.
	didSvc didExchSvc
	// saveAttchStateFunc persists the attachment handling state of an invitation that carries requests.
	saveAttchStateFunc func(*attachmentHandlingState) error
	// dispatchAttachmntFunc is called with (invitation ID, my DID, their DID) once a
	// handshake-reuse has been accepted for an invitation that carries requests.
	dispatchAttachmntFunc func(string, string, string) error
}

// The outofband protocol's state.
47 type state interface { 48 Name() string 49 Execute(*context, *dependencies) (state, finisher, bool, error) 50 } 51 52 func stateFromName(n string) (state, error) { 53 states := []state{ 54 &stateInitial{}, 55 &stateAwaitResponse{}, 56 &statePrepareResponse{}, 57 &stateDone{}, 58 } 59 60 for i := range states { 61 if states[i].Name() == n { 62 return states[i], nil 63 } 64 } 65 66 return nil, fmt.Errorf("unrecognized state name: %s", n) 67 } 68 69 func requiresApproval(msg service.DIDCommMsg) bool { 70 switch msg.Type() { 71 case InvitationMsgType, HandshakeReuseMsgType: 72 return true 73 } 74 75 return false 76 } 77 78 type stateInitial struct{} 79 80 func (s *stateInitial) Name() string { 81 return StateNameInitial 82 } 83 84 func (s *stateInitial) Execute(ctx *context, _ *dependencies) (state, finisher, bool, error) { 85 if ctx.Inbound { // inbound invitation 86 return &statePrepareResponse{}, noAction, false, nil 87 } 88 89 // outbound invitation 90 return &stateAwaitResponse{}, func(m service.Messenger) error { 91 return m.Send(ctx.Msg, ctx.MyDID, ctx.TheirDID) 92 }, true, nil 93 } 94 95 type stateAwaitResponse struct{} 96 97 func (s *stateAwaitResponse) Name() string { 98 return StateNameAwaitResponse 99 } 100 101 func (s *stateAwaitResponse) Execute(ctx *context, deps *dependencies) (state, finisher, bool, error) { 102 if !ctx.Inbound { 103 return nil, nil, true, fmt.Errorf("cannot execute '%s' for outbound messages", s.Name()) 104 } 105 106 // inbound HandshakeReuse or HandshakeReuseAccepted 107 if ctx.Msg.Type() == HandshakeReuseMsgType { 108 return s.handleHandshakeReuse(ctx, deps) 109 } 110 111 return s.handleHandshakeReuseAccepted(ctx, deps) 112 } 113 114 func (s *stateAwaitResponse) handleHandshakeReuse(ctx *context, deps *dependencies) (state, finisher, bool, error) { 115 // incoming HandshakeReuse 116 logger.Debugf("handling %s with context: %+v", ctx.Msg.Type(), ctx) 117 118 connID, err := deps.connections.GetConnectionIDByDIDs(ctx.MyDID, 
ctx.TheirDID) 119 if err != nil { 120 return nil, nil, true, fmt.Errorf( 121 "failed to fetch connection ID [myDID=%s theirDID=%s]: %w", 122 ctx.MyDID, ctx.TheirDID, err, 123 ) 124 } 125 126 record, err := deps.connections.GetConnectionRecord(connID) 127 if err != nil { 128 return nil, nil, true, fmt.Errorf("failed to fetch connection record [connID=%s]: %w", connID, err) 129 } 130 131 if record.State != connectionRecordCompletedState { 132 return nil, nil, true, fmt.Errorf( 133 "unexpected state for connection with ID=%s: expected '%s' got '%s'", 134 connID, connectionRecordCompletedState, record.State, 135 ) 136 } 137 138 return &stateDone{}, func(m service.Messenger) error { 139 return m.ReplyToMsg( 140 ctx.Msg, 141 service.NewDIDCommMsgMap(&HandshakeReuseAccepted{ 142 ID: uuid.New().String(), 143 Type: HandshakeReuseAcceptedMsgType, 144 }), 145 ctx.MyDID, 146 ctx.TheirDID, 147 ) 148 }, false, nil 149 } 150 151 func (s *stateAwaitResponse) handleHandshakeReuseAccepted( 152 ctx *context, deps *dependencies) (state, finisher, bool, error) { 153 logger.Debugf("handling %s with context: %+v", ctx.Msg.Type(), ctx) 154 155 if len(ctx.Invitation.Requests) > 0 { 156 go func() { 157 logger.Debugf("dispatching invitation attachment...") 158 159 err := deps.dispatchAttachmntFunc(ctx.Invitation.ID, ctx.MyDID, ctx.TheirDID) 160 if err != nil { 161 logger.Errorf("failed to dispatch attachment: %s", err.Error()) 162 } 163 }() 164 } 165 166 return &stateDone{}, noAction, false, nil 167 } 168 169 type statePrepareResponse struct{} 170 171 func (s *statePrepareResponse) Name() string { 172 return StateNamePrepareResponse 173 } 174 175 func (s *statePrepareResponse) Execute(ctx *context, deps *dependencies) (state, finisher, bool, error) { 176 logger.Debugf("handling %s with context: %+v", ctx.Msg.Type(), ctx) 177 178 // incoming Invitation 179 if ctx.ReuseConnection != "" || ctx.ReuseAnyConnection { 180 return s.connectionReuse(ctx, deps) 181 } 182 183 logger.Debugf("creating new 
connection using context: %+v", ctx) 184 185 connID, err := deps.didSvc.RespondTo(ctx.DIDExchangeInv, ctx.RouterConnections) 186 if err != nil { 187 return nil, nil, true, fmt.Errorf("didexchange service failed to handle inbound invitation: %w", err) 188 } 189 190 ctx.ConnectionID = connID 191 192 if len(ctx.Invitation.Requests) > 0 { 193 callbackState := &attachmentHandlingState{ 194 ID: ctx.Invitation.ID, 195 ConnectionID: connID, 196 Invitation: ctx.Invitation, 197 } 198 199 err = deps.saveAttchStateFunc(callbackState) 200 if err != nil { 201 return nil, nil, true, fmt.Errorf("failed to save attachment handling state: %w", err) 202 } 203 } 204 205 return &stateDone{}, noAction, false, nil 206 } 207 208 func (s *statePrepareResponse) connectionReuse(ctx *context, deps *dependencies) (state, finisher, bool, error) { 209 logger.Debugf("reusing connection using context: %+v", ctx) 210 211 // TODO query needs to be improved: https://github.com/hyperledger/aries-framework-go/issues/2732 212 records, err := deps.connections.QueryConnectionRecords() 213 if err != nil { 214 return nil, nil, true, fmt.Errorf("connectionReuse: failed to fetch connection records: %w", err) 215 } 216 217 inv := ctx.Invitation 218 219 var ( 220 record *connection.Record 221 found bool 222 ) 223 224 if ctx.ReuseAnyConnection { 225 for i := range inv.Services { 226 if s, ok := inv.Services[i].(string); ok { 227 record, found = findConnectionRecord(records, s) 228 if found { 229 break 230 } 231 } 232 } 233 } else { 234 record, found = findConnectionRecord(records, ctx.ReuseConnection) 235 } 236 237 if !found { 238 return nil, nil, true, errors.New("connectionReuse: no existing connection record found for the invitation") 239 } 240 241 ctx.ConnectionID = record.ConnectionID 242 ctx.MyDID = record.MyDID 243 ctx.TheirDID = record.TheirDID 244 245 if len(ctx.Invitation.Requests) > 0 { 246 callbackState := &attachmentHandlingState{ 247 ID: ctx.Invitation.ID, 248 ConnectionID: record.ConnectionID, 249 
Invitation: ctx.Invitation, 250 } 251 252 err = deps.saveAttchStateFunc(callbackState) 253 if err != nil { 254 return nil, nil, true, fmt.Errorf("failed to save attachment handling state: %w", err) 255 } 256 } 257 258 return &stateAwaitResponse{}, func(m service.Messenger) error { 259 return m.ReplyToMsg( 260 ctx.Msg, 261 service.NewDIDCommMsgMap(&HandshakeReuse{ 262 ID: uuid.New().String(), 263 Type: HandshakeReuseMsgType, 264 }), 265 ctx.MyDID, 266 ctx.TheirDID, 267 ) 268 }, true, nil 269 } 270 271 type stateDone struct{} 272 273 func (s *stateDone) Name() string { 274 return StateNameDone 275 } 276 277 func (s *stateDone) Execute(*context, *dependencies) (state, finisher, bool, error) { 278 return &stateDone{}, noAction, true, nil 279 } 280 281 func findConnectionRecord(records []*connection.Record, theirDID string) (*connection.Record, bool) { 282 for i := range records { 283 record := records[i] 284 285 if record.State != didexchange.StateIDCompleted { 286 continue 287 } 288 289 // we may recognize their DID by either: 290 // - having received an invitation with their "public" DID (record.InvitationDID) 291 // - them providing a "ledger-less" DID during a prior DID-Exchange 292 if record.InvitationDID == theirDID || record.TheirDID == theirDID { 293 return record, true 294 } 295 } 296 297 return nil, false 298 }