github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/protocol/messagepickup/service.go (about) 1 /* 2 Copyright Scoir Inc. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package messagepickup 8 9 import ( 10 "encoding/json" 11 "fmt" 12 "sync" 13 "time" 14 15 "github.com/google/uuid" 16 "github.com/pkg/errors" 17 18 "github.com/hyperledger/aries-framework-go/pkg/common/log" 19 "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/service" 20 "github.com/hyperledger/aries-framework-go/pkg/didcomm/dispatcher" 21 "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" 22 "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" 23 "github.com/hyperledger/aries-framework-go/pkg/store/connection" 24 "github.com/hyperledger/aries-framework-go/spi/storage" 25 ) 26 27 const ( 28 // MessagePickup defines the protocol name. 29 MessagePickup = "messagepickup" 30 // Spec defines the protocol spec. 31 Spec = "https://didcomm.org/messagepickup/1.0/" 32 // StatusMsgType defines the protocol propose-credential message type. 33 StatusMsgType = Spec + "status" 34 // StatusRequestMsgType defines the protocol propose-credential message type. 35 StatusRequestMsgType = Spec + "status-request" 36 // BatchPickupMsgType defines the protocol offer-credential message type. 37 BatchPickupMsgType = Spec + "batch-pickup" 38 // BatchMsgType defines the protocol offer-credential message type. 39 BatchMsgType = Spec + "batch" 40 // NoopMsgType defines the protocol request-credential message type. 41 NoopMsgType = Spec + "noop" 42 ) 43 44 const ( 45 updateTimeout = 50 * time.Second 46 47 // Namespace is namespace of messagepickup store name. 48 Namespace = "mailbox" 49 ) 50 51 // ErrConnectionNotFound connection not found error. 52 var ( 53 ErrConnectionNotFound = errors.New("connection not found") 54 logger = log.New("aries-framework/messagepickup") 55 ) 56 57 type provider interface { 58 OutboundDispatcher() dispatcher.Outbound 59 StorageProvider() storage.Provider 60 ProtocolStateStorageProvider() storage.Provider 61 InboundMessageHandler() transport.InboundMessageHandler 62 Packager() transport.Packager 63 } 64 65 type connections interface { 66 GetConnectionRecord(string) (*connection.Record, error) 67 } 68 69 // Service for the messagepickup protocol. 70 type Service struct { 71 service.Action 72 service.Message 73 connectionLookup connections 74 outbound dispatcher.Outbound 75 msgStore storage.Store 76 packager transport.Packager 77 msgHandler transport.InboundMessageHandler 78 batchMap map[string]chan Batch 79 batchMapLock sync.RWMutex 80 statusMap map[string]chan Status 81 statusMapLock sync.RWMutex 82 inboxLock sync.Mutex 83 initialized bool 84 } 85 86 // New returns the messagepickup service. 87 func New(prov provider) (*Service, error) { 88 svc := Service{} 89 90 err := svc.Initialize(prov) 91 if err != nil { 92 return nil, err 93 } 94 95 return &svc, nil 96 } 97 98 // Initialize initializes the Service. If Initialize succeeds, any further call is a no-op. 99 func (s *Service) Initialize(p interface{}) error { 100 if s.initialized { 101 return nil 102 } 103 104 prov, ok := p.(provider) 105 if !ok { 106 return fmt.Errorf("expected provider of type `%T`, got type `%T`", provider(nil), p) 107 } 108 109 store, err := prov.StorageProvider().OpenStore(Namespace) 110 if err != nil { 111 return fmt.Errorf("open mailbox store : %w", err) 112 } 113 114 connectionLookup, err := connection.NewLookup(prov) 115 if err != nil { 116 return err 117 } 118 119 s.outbound = prov.OutboundDispatcher() 120 s.msgStore = store 121 s.connectionLookup = connectionLookup 122 s.packager = prov.Packager() 123 s.msgHandler = prov.InboundMessageHandler() 124 s.batchMap = make(map[string]chan Batch) 125 s.statusMap = make(map[string]chan Status) 126 127 s.initialized = true 128 129 return nil 130 } 131 132 // HandleInbound handles inbound message pick up messages. 133 func (s *Service) HandleInbound(msg service.DIDCommMsg, ctx service.DIDCommContext) (string, error) { 134 // perform action asynchronously 135 go func() { 136 var err error 137 138 switch msg.Type() { 139 case StatusMsgType: 140 err = s.handleStatus(msg) 141 case StatusRequestMsgType: 142 err = s.handleStatusRequest(msg, ctx.MyDID(), ctx.TheirDID()) 143 case BatchPickupMsgType: 144 err = s.handleBatchPickup(msg, ctx.MyDID(), ctx.TheirDID()) 145 case BatchMsgType: 146 err = s.handleBatch(msg) 147 case NoopMsgType: 148 err = s.handleNoop(msg) 149 } 150 151 if err != nil { 152 logger.Errorf("Error handling message: (%w)\n", err) 153 } 154 }() 155 156 return msg.ID(), nil 157 } 158 159 // HandleOutbound adherence to dispatcher.ProtocolService. 160 func (s *Service) HandleOutbound(_ service.DIDCommMsg, _, _ string) (string, error) { 161 return "", errors.New("not implemented") 162 } 163 164 // Accept checks whether the service can handle the message type. 165 func (s *Service) Accept(msgType string) bool { 166 switch msgType { 167 case BatchPickupMsgType, BatchMsgType, StatusRequestMsgType, StatusMsgType, NoopMsgType: 168 return true 169 } 170 171 return false 172 } 173 174 // Name of the service. 175 func (s *Service) Name() string { 176 return MessagePickup 177 } 178 179 func (s *Service) handleStatus(msg service.DIDCommMsg) error { 180 // unmarshal the payload 181 statusMsg := &Status{} 182 183 err := msg.Decode(statusMsg) 184 if err != nil { 185 return fmt.Errorf("status message unmarshal: %w", err) 186 } 187 188 // check if there are any channels registered for the message ID 189 statusCh := s.getStatusCh(statusMsg.ID) 190 if statusCh != nil { 191 // invoke the channel for the incoming message 192 statusCh <- *statusMsg 193 } 194 195 return nil 196 } 197 198 func (s *Service) handleStatusRequest(msg service.DIDCommMsg, myDID, theirDID string) error { 199 s.inboxLock.Lock() 200 defer s.inboxLock.Unlock() 201 202 // unmarshal the payload 203 request := &StatusRequest{} 204 205 err := msg.Decode(request) 206 if err != nil { 207 return fmt.Errorf("status request message unmarshal: %w", err) 208 } 209 210 logger.Debugf("retrieving stored messages for %s\n", theirDID) 211 212 outbox, err := s.getInbox(theirDID) 213 if err != nil { 214 return fmt.Errorf("error in status request getting inbox: %w", err) 215 } 216 217 resp := &Status{ 218 Type: StatusMsgType, 219 ID: msg.ID(), 220 MessageCount: outbox.MessageCount, 221 DurationWaited: int(time.Since(outbox.LastDeliveredTime).Seconds()), 222 LastAddedTime: outbox.LastAddedTime, 223 LastDeliveredTime: outbox.LastDeliveredTime, 224 LastRemovedTime: outbox.LastRemovedTime, 225 TotalSize: outbox.TotalSize, 226 Thread: &decorator.Thread{ 227 PID: request.Thread.ID, 228 }, 229 } 230 231 msgBytes, err := json.Marshal(resp) 232 if err != nil { 233 return fmt.Errorf("marshal batch: %w", err) 234 } 235 236 msgMap, err := service.ParseDIDCommMsgMap(msgBytes) 237 if err != nil { 238 return fmt.Errorf("parse batch into didcomm msg map: %w", err) 239 } 240 241 return s.outbound.SendToDID(msgMap, myDID, theirDID) 242 } 243 244 func (s *Service) handleBatchPickup(msg service.DIDCommMsg, myDID, theirDID string) error { 245 s.inboxLock.Lock() 246 defer s.inboxLock.Unlock() 247 248 // unmarshal the payload 249 request := &BatchPickup{} 250 251 err := msg.Decode(request) 252 if err != nil { 253 return fmt.Errorf("batch pickup message unmarshal : %w", err) 254 } 255 256 outbox, err := s.getInbox(theirDID) 257 if err != nil { 258 return fmt.Errorf("batch pickup get inbox: %w", err) 259 } 260 261 msgs, err := outbox.DecodeMessages() 262 if err != nil { 263 return fmt.Errorf("batch pickup decode : %w", err) 264 } 265 266 end := len(msgs) 267 if request.BatchSize < end { 268 end = request.BatchSize 269 } 270 271 outbox.LastDeliveredTime = time.Now() 272 outbox.LastRemovedTime = time.Now() 273 274 err = outbox.EncodeMessages(msgs[end:]) 275 if err != nil { 276 return fmt.Errorf("batch pickup encode: %w", err) 277 } 278 279 err = s.putInbox(theirDID, outbox) 280 if err != nil { 281 return fmt.Errorf("batch pick up put inbox: %w", err) 282 } 283 284 msgs = msgs[0:end] 285 286 batch := Batch{ 287 Type: BatchMsgType, 288 ID: msg.ID(), 289 Messages: msgs, 290 } 291 292 msgBytes, err := json.Marshal(batch) 293 if err != nil { 294 return fmt.Errorf("marshal batch: %w", err) 295 } 296 297 msgMap, err := service.ParseDIDCommMsgMap(msgBytes) 298 if err != nil { 299 return fmt.Errorf("parse batch into didcomm msg map: %w", err) 300 } 301 302 return s.outbound.SendToDID(msgMap, myDID, theirDID) 303 } 304 305 func (s *Service) handleBatch(msg service.DIDCommMsg) error { 306 // unmarshal the payload 307 batchMsg := &Batch{} 308 309 err := msg.Decode(batchMsg) 310 if err != nil { 311 return fmt.Errorf("batch message unmarshal : %w", err) 312 } 313 314 // check if there are any channels registered for the message ID 315 batchCh := s.getBatchCh(batchMsg.ID) 316 317 if batchCh != nil { 318 // invoke the channel for the incoming message 319 batchCh <- *batchMsg 320 } 321 322 return nil 323 } 324 325 func (s *Service) handleNoop(msg service.DIDCommMsg) error { 326 // unmarshal the payload 327 request := &Noop{} 328 329 err := msg.Decode(request) 330 if err != nil { 331 return fmt.Errorf("noop message unmarshal : %w", err) 332 } 333 334 return nil 335 } 336 337 type inbox struct { 338 DID string `json:"DID"` 339 MessageCount int `json:"message_count"` 340 LastAddedTime time.Time `json:"last_added_time,omitempty"` 341 LastDeliveredTime time.Time `json:"last_delivered_time,omitempty"` 342 LastRemovedTime time.Time `json:"last_removed_time,omitempty"` 343 TotalSize int `json:"total_size,omitempty"` 344 Messages json.RawMessage `json:"messages"` 345 } 346 347 // DecodeMessages Messages. 348 func (r *inbox) DecodeMessages() ([]*Message, error) { 349 var out []*Message 350 351 var err error 352 353 if r.Messages != nil { 354 err = json.Unmarshal(r.Messages, &out) 355 } 356 357 return out, err 358 } 359 360 // EncodeMessages Messages. 361 func (r *inbox) EncodeMessages(msg []*Message) error { 362 d, err := json.Marshal(msg) 363 if err != nil { 364 return fmt.Errorf("unable to marshal: %w", err) 365 } 366 367 r.Messages = d 368 r.MessageCount = len(msg) 369 r.TotalSize = len(d) 370 371 return nil 372 } 373 374 // AddMessage add message to inbox. 375 func (s *Service) AddMessage(message []byte, theirDID string) error { 376 s.inboxLock.Lock() 377 defer s.inboxLock.Unlock() 378 379 outbox, err := s.createInbox(theirDID) 380 if err != nil { 381 return fmt.Errorf("unable to pull messages: %w", err) 382 } 383 384 msgs, err := outbox.DecodeMessages() 385 if err != nil { 386 return fmt.Errorf("unable to decode messages: %w", err) 387 } 388 389 m := Message{ 390 ID: uuid.New().String(), 391 AddedTime: time.Now(), 392 Message: message, 393 } 394 395 msgs = append(msgs, &m) 396 397 outbox.LastDeliveredTime = time.Now() 398 outbox.LastRemovedTime = outbox.LastDeliveredTime 399 400 err = outbox.EncodeMessages(msgs) 401 if err != nil { 402 return fmt.Errorf("unable to encode messages: %w", err) 403 } 404 405 err = s.putInbox(theirDID, outbox) 406 if err != nil { 407 return fmt.Errorf("unable to put messages: %w", err) 408 } 409 410 return nil 411 } 412 413 func (s *Service) createInbox(theirDID string) (*inbox, error) { 414 msgs, err := s.getInbox(theirDID) 415 if err != nil && errors.Is(err, storage.ErrDataNotFound) { 416 msgs = &inbox{DID: theirDID} 417 418 msgBytes, e := json.Marshal(msgs) 419 if e != nil { 420 return nil, e 421 } 422 423 e = s.msgStore.Put(theirDID, msgBytes) 424 if e != nil { 425 return nil, e 426 } 427 428 return msgs, nil 429 } 430 431 return msgs, err 432 } 433 434 func (s *Service) getInbox(theirDID string) (*inbox, error) { 435 msgs := &inbox{DID: theirDID} 436 437 b, err := s.msgStore.Get(theirDID) 438 if err != nil { 439 return nil, err 440 } 441 442 err = json.Unmarshal(b, msgs) 443 if err != nil { 444 return nil, err 445 } 446 447 return msgs, nil 448 } 449 450 func (s *Service) putInbox(theirDID string, o *inbox) error { 451 b, err := json.Marshal(o) 452 if err != nil { 453 return err 454 } 455 456 return s.msgStore.Put(theirDID, b) 457 } 458 459 // StatusRequest request a status message. 460 func (s *Service) StatusRequest(connectionID string) (*Status, error) { 461 // get the connection record for the ID to fetch DID information 462 conn, err := s.getConnection(connectionID) 463 if err != nil { 464 return nil, err 465 } 466 467 // generate message ID 468 msgID := uuid.New().String() 469 470 // register chan for callback processing 471 statusCh := make(chan Status) 472 s.setStatusCh(msgID, statusCh) 473 474 defer s.setStatusCh(msgID, nil) 475 476 // create request message 477 req := &StatusRequest{ 478 Type: StatusRequestMsgType, 479 ID: msgID, 480 Thread: &decorator.Thread{ 481 PID: uuid.New().String(), 482 }, 483 } 484 485 // send message to the router 486 if err := s.outbound.SendToDID(req, conn.MyDID, conn.TheirDID); err != nil { 487 return nil, fmt.Errorf("send route request: %w", err) 488 } 489 490 // callback processing (to make this function look like a sync function) 491 var sts *Status 492 select { 493 case s := <-statusCh: 494 sts = &s 495 // TODO https://github.com/hyperledger/aries-framework-go/issues/1134 configure this timeout at decorator level 496 case <-time.After(updateTimeout): 497 return nil, errors.New("timeout waiting for status request") 498 } 499 500 return sts, nil 501 } 502 503 // BatchPickup a request to have multiple waiting messages sent inside a batch message. 504 func (s *Service) BatchPickup(connectionID string, size int) (int, error) { 505 // get the connection record for the ID to fetch DID information 506 conn, err := s.getConnection(connectionID) 507 if err != nil { 508 return -1, err 509 } 510 511 // generate message ID 512 msgID := uuid.New().String() 513 514 // register chan for callback processing 515 batchCh := make(chan Batch) 516 s.setBatchCh(msgID, batchCh) 517 518 defer s.setBatchCh(msgID, nil) 519 520 // create request message 521 req := &BatchPickup{ 522 Type: BatchPickupMsgType, 523 ID: msgID, 524 BatchSize: size, 525 } 526 527 msgBytes, err := json.Marshal(req) 528 if err != nil { 529 return -1, fmt.Errorf("marshal req: %w", err) 530 } 531 532 msgMap, err := service.ParseDIDCommMsgMap(msgBytes) 533 if err != nil { 534 return -1, fmt.Errorf("parse req into didcomm msg map: %w", err) 535 } 536 537 // send message to the router 538 if err := s.outbound.SendToDID(msgMap, conn.MyDID, conn.TheirDID); err != nil { 539 return -1, fmt.Errorf("send batch pickup request: %w", err) 540 } 541 542 // callback processing (to make this function look like a sync function) 543 var processed int 544 select { 545 case batchResp := <-batchCh: 546 for _, msg := range batchResp.Messages { 547 err := s.handle(msg) 548 if err != nil { 549 logger.Errorf("error handling batch message %s: %w", msg.ID, err) 550 551 continue 552 } 553 processed++ 554 } 555 // TODO https://github.com/hyperledger/aries-framework-go/issues/1134 configure this timeout at decorator level 556 case <-time.After(updateTimeout): 557 return -1, errors.New("timeout waiting for batch") 558 } 559 560 return processed, nil 561 } 562 563 // Noop a noop message. 564 func (s *Service) Noop(connectionID string) error { 565 // get the connection record for the ID to fetch DID information 566 conn, err := s.getConnection(connectionID) 567 if err != nil { 568 return err 569 } 570 571 noop := &Noop{ID: uuid.New().String(), Type: NoopMsgType} 572 573 msgBytes, err := json.Marshal(noop) 574 if err != nil { 575 return fmt.Errorf("marshal noop: %w", err) 576 } 577 578 msgMap, err := service.ParseDIDCommMsgMap(msgBytes) 579 if err != nil { 580 return fmt.Errorf("parse noop into didcomm msg map: %w", err) 581 } 582 583 if err := s.outbound.SendToDID(msgMap, conn.MyDID, conn.TheirDID); err != nil { 584 return fmt.Errorf("send noop request: %w", err) 585 } 586 587 return nil 588 } 589 590 func (s *Service) getConnection(routerConnID string) (*connection.Record, error) { 591 conn, err := s.connectionLookup.GetConnectionRecord(routerConnID) 592 if err != nil { 593 if errors.Is(err, storage.ErrDataNotFound) { 594 return nil, ErrConnectionNotFound 595 } 596 597 return nil, fmt.Errorf("fetch connection record from store : %w", err) 598 } 599 600 return conn, nil 601 } 602 603 func (s *Service) getBatchCh(msgID string) chan Batch { 604 s.batchMapLock.RLock() 605 defer s.batchMapLock.RUnlock() 606 607 return s.batchMap[msgID] 608 } 609 610 func (s *Service) setBatchCh(msgID string, batchCh chan Batch) { 611 s.batchMapLock.Lock() 612 defer s.batchMapLock.Unlock() 613 614 if batchCh == nil { 615 delete(s.batchMap, msgID) 616 } else { 617 s.batchMap[msgID] = batchCh 618 } 619 } 620 621 func (s *Service) getStatusCh(msgID string) chan Status { 622 s.statusMapLock.RLock() 623 defer s.statusMapLock.RUnlock() 624 625 return s.statusMap[msgID] 626 } 627 628 func (s *Service) setStatusCh(msgID string, statusCh chan Status) { 629 s.statusMapLock.Lock() 630 defer s.statusMapLock.Unlock() 631 632 if statusCh == nil { 633 delete(s.statusMap, msgID) 634 } else { 635 s.statusMap[msgID] = statusCh 636 } 637 } 638 639 func (s *Service) handle(msg *Message) error { 640 unpackMsg, err := s.packager.UnpackMessage(msg.Message) 641 if err != nil { 642 return fmt.Errorf("failed to unpack msg: %w", err) 643 } 644 645 trans := &decorator.Transport{} 646 err = json.Unmarshal(unpackMsg.Message, trans) 647 648 if err != nil { 649 return fmt.Errorf("unmarshal transport decorator : %w", err) 650 } 651 652 messageHandler := s.msgHandler 653 654 err = messageHandler(unpackMsg) 655 if err != nil { 656 return fmt.Errorf("incoming msg processing failed: %w", err) 657 } 658 659 return nil 660 }