github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/sqlite/messagestore.go (about) 1 // Copyright 2017 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sqlite 16 17 import ( 18 "context" 19 "database/sql" 20 "encoding/hex" 21 "errors" 22 "fmt" 23 "strconv" 24 "strings" 25 "time" 26 27 log "github.com/golang/glog" 28 29 "github.com/google/fleetspeak/fleetspeak/src/common" 30 "github.com/google/fleetspeak/fleetspeak/src/server/db" 31 32 fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak" 33 "google.golang.org/protobuf/proto" 34 anypb "google.golang.org/protobuf/types/known/anypb" 35 tspb "google.golang.org/protobuf/types/known/timestamppb" 36 ) 37 38 // dbMessage matches the schema of the messages table, optionally joined to the 39 // pending_messages table. 40 type dbMessage struct { 41 messageID string 42 sourceClientID string 43 sourceServiceName string 44 sourceMessageID string 45 destinationClientID string 46 destinationServiceName string 47 messageType string 48 creationTimeSeconds int64 49 creationTimeNanos int32 50 processedTimeSeconds sql.NullInt64 51 processedTimeNanos sql.NullInt64 52 validationInfo []byte 53 failed sql.NullBool 54 failedReason sql.NullString 55 retryCount uint32 56 dataTypeURL sql.NullString 57 dataValue []byte 58 annotations []byte 59 } 60 61 func toMicro(t time.Time) int64 { 62 return t.UnixNano() / 1000 63 } 64 65 func (d *Datastore) SetMessageResult(ctx context.Context, dest common.ClientID, id common.MessageID, res *fspb.MessageResult) error { 66 d.l.Lock() 67 defer d.l.Unlock() 68 return d.runInTx(func(tx *sql.Tx) error { return d.trySetMessageResult(ctx, tx, id, res) }) 69 } 70 71 func (d *Datastore) trySetMessageResult(ctx context.Context, tx *sql.Tx, id common.MessageID, res *fspb.MessageResult) error { 72 dbm := dbMessage{ 73 messageID: id.String(), 74 processedTimeSeconds: sql.NullInt64{Valid: true, Int64: res.ProcessedTime.Seconds}, 75 processedTimeNanos: sql.NullInt64{Valid: true, Int64: int64(res.ProcessedTime.Nanos)}, 76 } 77 if res.Failed { 78 dbm.failed = sql.NullBool{Valid: true, Bool: true} 79 dbm.failedReason = sql.NullString{Valid: true, String: res.FailedReason} 80 } 81 _, err := tx.ExecContext(ctx, "UPDATE messages SET failed=?, failed_reason=?, processed_time_seconds=?, processed_time_nanos=? WHERE message_id=?", 82 dbm.failed, dbm.failedReason, dbm.processedTimeSeconds, dbm.processedTimeNanos, dbm.messageID) 83 if err != nil { 84 return err 85 } 86 _, err = tx.ExecContext(ctx, "DELETE FROM pending_messages WHERE message_id=?", dbm.messageID) 87 return err 88 } 89 90 func toClientIDString(b []byte) string { 91 if len(b) == 0 { 92 return "" 93 } 94 id, err := common.BytesToClientID(b) 95 if err != nil { 96 log.Fatalf("Could't parse ClientID(%v): %v", b, err) 97 } 98 return id.String() 99 } 100 101 func fromClientIDString(s string) (b []byte) { 102 if s == "" { 103 return nil 104 } 105 cid, err := common.StringToClientID(s) 106 if err != nil { 107 log.Fatalf("Couldn't parse ClientID(%v): %v", s, err) 108 return nil 109 } 110 return cid.Bytes() 111 } 112 113 func fromNULLString(s sql.NullString) string { 114 if !s.Valid { 115 return "" 116 } 117 return s.String 118 } 119 120 func fromMessageProto(m *fspb.Message) (*dbMessage, error) { 121 id, err := common.BytesToMessageID(m.MessageId) 122 if err != nil { 123 return nil, err 124 } 125 dbm := &dbMessage{ 126 messageID: id.String(), 127 messageType: m.MessageType, 128 } 129 if m.Source != nil { 130 dbm.sourceClientID = toClientIDString(m.Source.ClientId) 131 dbm.sourceServiceName = m.Source.ServiceName 132 } 133 if m.Destination != nil { 134 dbm.destinationClientID = toClientIDString(m.Destination.ClientId) 135 dbm.destinationServiceName = m.Destination.ServiceName 136 } 137 if len(m.SourceMessageId) != 0 { 138 dbm.sourceMessageID = hex.EncodeToString(m.SourceMessageId) 139 } 140 if m.CreationTime != nil { 141 dbm.creationTimeSeconds = m.CreationTime.Seconds 142 dbm.creationTimeNanos = m.CreationTime.Nanos 143 } 144 if m.Result != nil { 145 r := m.Result 146 if r.ProcessedTime != nil { 147 dbm.processedTimeSeconds = sql.NullInt64{Int64: r.ProcessedTime.Seconds, Valid: true} 148 dbm.processedTimeNanos = sql.NullInt64{Int64: int64(r.ProcessedTime.Nanos), Valid: true} 149 } 150 if r.Failed { 151 dbm.failed = sql.NullBool{Bool: true, Valid: true} 152 dbm.failedReason = sql.NullString{String: r.FailedReason, Valid: true} 153 } 154 } 155 if m.Data != nil { 156 dbm.dataTypeURL = sql.NullString{String: m.Data.TypeUrl, Valid: true} 157 dbm.dataValue = m.Data.Value 158 } 159 if m.ValidationInfo != nil { 160 b, err := proto.Marshal(m.ValidationInfo) 161 if err != nil { 162 return nil, err 163 } 164 dbm.validationInfo = b 165 } 166 if m.Annotations != nil { 167 b, err := proto.Marshal(m.Annotations) 168 if err != nil { 169 return nil, err 170 } 171 dbm.annotations = b 172 } 173 return dbm, nil 174 } 175 176 func toMessageResultProto(m *dbMessage) *fspb.MessageResult { 177 if !m.processedTimeSeconds.Valid { 178 return nil 179 } 180 181 ret := &fspb.MessageResult{ 182 ProcessedTime: &tspb.Timestamp{ 183 Seconds: m.processedTimeSeconds.Int64, 184 Nanos: int32(m.processedTimeNanos.Int64)}, 185 Failed: m.failed.Valid && m.failed.Bool, 186 } 187 188 if m.failedReason.Valid { 189 ret.FailedReason = m.failedReason.String 190 } 191 return ret 192 } 193 194 func toMessageProto(m *dbMessage) (*fspb.Message, error) { 195 mid, err := common.StringToMessageID(m.messageID) 196 if err != nil { 197 return nil, err 198 } 199 bsmid, err := hex.DecodeString(m.sourceMessageID) 200 if err != nil { 201 return nil, err 202 } 203 pm := &fspb.Message{ 204 MessageId: mid.Bytes(), 205 Source: &fspb.Address{ 206 ClientId: fromClientIDString(m.sourceClientID), 207 ServiceName: m.sourceServiceName, 208 }, 209 SourceMessageId: bsmid, 210 Destination: &fspb.Address{ 211 ClientId: fromClientIDString(m.destinationClientID), 212 ServiceName: m.destinationServiceName, 213 }, 214 MessageType: m.messageType, 215 CreationTime: &tspb.Timestamp{ 216 Seconds: m.creationTimeSeconds, 217 Nanos: m.creationTimeNanos, 218 }, 219 Result: toMessageResultProto(m), 220 } 221 if m.dataTypeURL.Valid { 222 pm.Data = &anypb.Any{ 223 TypeUrl: m.dataTypeURL.String, 224 Value: m.dataValue, 225 } 226 } 227 if len(m.validationInfo) > 0 { 228 v := &fspb.ValidationInfo{} 229 if err := proto.Unmarshal(m.validationInfo, v); err != nil { 230 return nil, err 231 } 232 pm.ValidationInfo = v 233 } 234 if len(m.annotations) > 0 { 235 a := &fspb.Annotations{} 236 if err := proto.Unmarshal(m.annotations, a); err != nil { 237 return nil, err 238 } 239 pm.Annotations = a 240 } 241 return pm, nil 242 } 243 244 func (d *Datastore) StoreMessages(ctx context.Context, msgs []*fspb.Message, contact db.ContactID) error { 245 d.l.Lock() 246 defer d.l.Unlock() 247 248 ids := make([]string, 0, len(msgs)) 249 250 return d.runInTx(func(tx *sql.Tx) error { 251 for _, m := range msgs { 252 dbm, err := fromMessageProto(m) 253 if err != nil { 254 return err 255 } 256 // If it is already processed, we don't want to save m.Data. 257 if m.Result != nil { 258 dbm.dataTypeURL = sql.NullString{Valid: false} 259 dbm.dataValue = nil 260 } 261 ids = append(ids, dbm.messageID) 262 if m.Result != nil && !m.Result.Failed { 263 if err := d.tryStoreMessage(ctx, tx, dbm, false); err != nil { 264 return err 265 } 266 if m.Result != nil { 267 mid, _ := common.BytesToMessageID(m.MessageId) 268 if err := d.trySetMessageResult(ctx, tx, mid, m.Result); err != nil { 269 return err 270 } 271 } 272 continue 273 } 274 var processedTime sql.NullInt64 275 var failed sql.NullBool 276 e := tx.QueryRowContext(ctx, "SELECT processed_time_seconds, failed FROM messages where message_id=?", dbm.messageID).Scan(&processedTime, &failed) 277 switch { 278 case e == sql.ErrNoRows: 279 // Common case. Message not yet present, store as normal. 280 if err := d.tryStoreMessage(ctx, tx, dbm, false); err != nil { 281 return err 282 } 283 case e != nil: 284 return e 285 case processedTime.Valid && (!failed.Valid || !failed.Bool): 286 // Message previously successfully processed, ignore this reprocessing. 287 case m.Result != nil && (!processedTime.Valid || !m.Result.Failed): 288 mid, err := common.BytesToMessageID(m.MessageId) 289 if err != nil { 290 return err 291 } 292 // Message not previously successfully processed, but this try succeeded. Mark as processed. 293 if err := d.trySetMessageResult(ctx, tx, mid, m.Result); err != nil { 294 return err 295 } 296 default: 297 // The message is already present, but unprocessed/failed, and this 298 // processing didn't succeed or is ongoing. Nothing to do. 299 } 300 } 301 302 if contact == "" { 303 return nil 304 } 305 306 c, err := strconv.ParseUint(string(contact), 16, 64) 307 if err != nil { 308 e := fmt.Errorf("unable to parse ContactID [%v]: %v", contact, err) 309 log.Error(e) 310 return e 311 } 312 for _, id := range ids { 313 if _, err := tx.ExecContext(ctx, "INSERT OR IGNORE INTO client_contact_messages(client_contact_id, message_id) VALUES (?, ?)", c, id); err != nil { 314 return err 315 } 316 } 317 return nil 318 }) 319 } 320 321 func (d *Datastore) tryStoreMessage(ctx context.Context, tx *sql.Tx, dbm *dbMessage, isBroadcast bool) error { 322 if dbm.creationTimeSeconds == 0 { 323 return errors.New("message CreationTime must be set") 324 } 325 res, err := tx.ExecContext(ctx, "INSERT OR IGNORE INTO messages("+ 326 "message_id, "+ 327 "source_client_id, "+ 328 "source_service_name, "+ 329 "source_message_id, "+ 330 "destination_client_id, "+ 331 "destination_service_name, "+ 332 "message_type, "+ 333 "creation_time_seconds, "+ 334 "creation_time_nanos, "+ 335 "processed_time_seconds, "+ 336 "processed_time_nanos, "+ 337 "failed,"+ 338 "failed_reason,"+ 339 "validation_info,"+ 340 "annotations) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", 341 dbm.messageID, 342 dbm.sourceClientID, 343 dbm.sourceServiceName, 344 dbm.sourceMessageID, 345 dbm.destinationClientID, 346 dbm.destinationServiceName, 347 dbm.messageType, 348 dbm.creationTimeSeconds, 349 dbm.creationTimeNanos, 350 dbm.processedTimeSeconds, 351 dbm.processedTimeNanos, 352 dbm.failed, 353 dbm.failedReason, 354 dbm.validationInfo, 355 dbm.annotations) 356 if err != nil { 357 return err 358 } 359 cnt, err := res.RowsAffected() 360 if err != nil { 361 return err 362 } 363 inserted := cnt == 1 364 if inserted && !dbm.processedTimeSeconds.Valid { 365 var due int64 366 if dbm.destinationClientID == "" { 367 due = toMicro(db.ServerRetryTime(0)) 368 } else { 369 // If this is being created in response to a broadcast, then we about to 370 // hand it to the client and should wait before providing in through 371 // ClientMessagesForProcessing. Otherwise, we should give it to the client 372 // on next contact. 373 if isBroadcast { 374 due = toMicro(db.ClientRetryTime()) 375 } else { 376 due = toMicro(db.Now()) 377 } 378 } 379 _, err = tx.ExecContext(ctx, "INSERT INTO pending_messages("+ 380 "message_id, "+ 381 "retry_count, "+ 382 "scheduled_time, "+ 383 "data_type_url, "+ 384 "data_value) VALUES(?, ?, ?, ?, ?)", 385 dbm.messageID, 386 0, 387 due, 388 dbm.dataTypeURL, 389 dbm.dataValue) 390 if err != nil { 391 return err 392 } 393 } 394 return nil 395 } 396 397 func genPlaceholders(num int) string { 398 es := make([]string, num) 399 for i := range es { 400 es[i] = "?" 401 } 402 return strings.Join(es, ", ") 403 } 404 405 func (d *Datastore) getPendingMessageRawIds(ctx context.Context, tx *sql.Tx, ids []common.ClientID, offset uint64, limit uint64) ([]string, error) { 406 squery := fmt.Sprintf("SELECT "+ 407 "m.message_id AS message_id "+ 408 "FROM messages AS m, pending_messages AS pm "+ 409 "WHERE m.destination_client_id IN (%s) AND m.message_id=pm.message_id "+ 410 "ORDER BY message_id ", 411 genPlaceholders((len(ids)))) 412 413 if offset != 0 && limit == 0 { 414 return nil, fmt.Errorf("if offset is provided, a limit must be provided as well") 415 } 416 417 if limit != 0 { 418 squery += " LIMIT ?" 419 } 420 421 if offset != 0 { 422 squery += " OFFSET ?" 423 } 424 425 args := make([]any, len(ids), len(ids)+2) 426 for i, v := range ids { 427 args[i] = v.String() 428 } 429 430 if limit != 0 { 431 args = append(args, limit) 432 } 433 434 if offset != 0 { 435 args = append(args, offset) 436 } 437 438 idsToProc := make([]string, 0) 439 440 rs, err := tx.QueryContext(ctx, squery, args...) 441 if err != nil { 442 return nil, fmt.Errorf("Failed to fetch the list of messages to delete: %v", err) 443 } 444 defer rs.Close() 445 for rs.Next() { 446 var id string 447 if err := rs.Scan(&id); err != nil { 448 return nil, err 449 } 450 idsToProc = append(idsToProc, id) 451 } 452 453 return idsToProc, nil 454 } 455 456 func (d *Datastore) GetPendingMessageCount(ctx context.Context, ids []common.ClientID) (uint64, error) { 457 var result uint64 458 459 err := d.runInTx(func(tx *sql.Tx) error { 460 squery := fmt.Sprintf("SELECT "+ 461 "COUNT(*) "+ 462 "FROM messages AS m, pending_messages AS pm "+ 463 "WHERE m.destination_client_id IN (%s) AND m.message_id=pm.message_id ", 464 genPlaceholders((len(ids)))) 465 466 args := make([]any, len(ids)) 467 for i, v := range ids { 468 args[i] = v.String() 469 } 470 rs, err := tx.QueryContext(ctx, squery, args...) 471 if err != nil { 472 return fmt.Errorf("Failed to fetch the pending message count: %v", err) 473 } 474 defer rs.Close() 475 if !rs.Next() { 476 return fmt.Errorf("Got empty result") 477 } 478 err = rs.Scan(&result) 479 if err != nil { 480 return fmt.Errorf("Failed to scan result: %v", err) 481 } 482 return nil 483 }) 484 485 return result, err 486 } 487 488 func (d *Datastore) GetPendingMessages(ctx context.Context, ids []common.ClientID, offset uint64, limit uint64, wantData bool) ([]*fspb.Message, error) { 489 var res []*fspb.Message 490 err := d.runInTx(func(tx *sql.Tx) error { 491 messageIdsRaw, err := d.getPendingMessageRawIds(ctx, tx, ids, offset, limit) 492 if err != nil { 493 return err 494 } 495 var messageIds []common.MessageID 496 for _, idRaw := range messageIdsRaw { 497 messageID, err := common.StringToMessageID(idRaw) 498 if err != nil { 499 return err 500 } 501 messageIds = append(messageIds, messageID) 502 } 503 res, err = d.getMessages(ctx, tx, messageIds, wantData) 504 return err 505 }) 506 return res, err 507 } 508 509 func (d *Datastore) DeletePendingMessages(ctx context.Context, ids []common.ClientID) error { 510 return d.runInTx(func(tx *sql.Tx) error { 511 messageIds, err := d.getPendingMessageRawIds(ctx, tx, ids, 0, 0) 512 if err != nil { 513 return err 514 } 515 516 idsToProc := make([]any, len(messageIds)) 517 for i, id := range messageIds { 518 idsToProc[i] = id 519 } 520 521 // If there are no messages to be deleted, just bail out. 522 if len(idsToProc) == 0 { 523 return nil 524 } 525 526 now := db.NowProto() 527 ptimeSecs := sql.NullInt64{Valid: true, Int64: now.Seconds} 528 ptimeNanoSecs := sql.NullInt64{Valid: true, Int64: int64(now.Nanos)} 529 failed := sql.NullBool{Valid: true, Bool: true} 530 failedReason := sql.NullString{Valid: true, String: "Removed by admin action."} 531 532 ps := genPlaceholders(len(idsToProc)) 533 uquery := fmt.Sprintf("UPDATE messages SET failed=?, failed_reason=?, processed_time_seconds=?, processed_time_nanos=? WHERE message_id IN (%s)", ps) 534 _, err = tx.ExecContext(ctx, uquery, append([]any{failed, failedReason, ptimeSecs, ptimeNanoSecs}, idsToProc...)...) 535 if err != nil { 536 return err 537 } 538 539 dquery := fmt.Sprintf("DELETE FROM pending_messages WHERE message_id IN (%s)", ps) 540 _, err = tx.ExecContext(ctx, dquery, idsToProc...) 541 542 return err 543 }) 544 } 545 546 func (d *Datastore) getMessages(ctx context.Context, tx *sql.Tx, ids []common.MessageID, wantData bool) ([]*fspb.Message, error) { 547 d.l.Lock() 548 defer d.l.Unlock() 549 res := make([]*fspb.Message, 0, len(ids)) 550 551 stmt1, err := tx.Prepare("SELECT " + 552 "message_id, " + 553 "source_client_id, " + 554 "source_service_name, " + 555 "source_message_id, " + 556 "destination_client_id, " + 557 "destination_service_name, " + 558 "message_type, " + 559 "creation_time_seconds, " + 560 "creation_time_nanos, " + 561 "processed_time_seconds, " + 562 "processed_time_nanos, " + 563 "validation_info, " + 564 "annotations " + 565 "FROM messages WHERE message_id=?") 566 var stmt2 *sql.Stmt 567 if wantData { 568 stmt2, err = tx.Prepare("SELECT data_type_url, data_value FROM pending_messages WHERE message_id=?") 569 if err != nil { 570 return nil, err 571 } 572 } 573 if err != nil { 574 return nil, err 575 } 576 for _, id := range ids { 577 row := stmt1.QueryRowContext(ctx, id.String()) 578 var dbm dbMessage 579 err := row.Scan( 580 &dbm.messageID, 581 &dbm.sourceClientID, 582 &dbm.sourceServiceName, 583 &dbm.sourceMessageID, 584 &dbm.destinationClientID, 585 &dbm.destinationServiceName, 586 &dbm.messageType, 587 &dbm.creationTimeSeconds, 588 &dbm.creationTimeNanos, 589 &dbm.processedTimeSeconds, 590 &dbm.processedTimeNanos, 591 &dbm.validationInfo, 592 &dbm.annotations) 593 if err != nil { 594 return nil, err 595 } 596 if wantData { 597 row := stmt2.QueryRowContext(ctx, id.String()) 598 err := row.Scan(&dbm.dataTypeURL, &dbm.dataValue) 599 if err != nil && err != sql.ErrNoRows { 600 return nil, err 601 } 602 } 603 m, err := toMessageProto(&dbm) 604 if err != nil { 605 return nil, err 606 } 607 res = append(res, m) 608 } 609 610 return res, nil 611 } 612 613 func (d *Datastore) GetMessages(ctx context.Context, ids []common.MessageID, wantData bool) ([]*fspb.Message, error) { 614 var res []*fspb.Message 615 err := d.runInTx(func(tx *sql.Tx) error { 616 var err error 617 res, err = d.getMessages(ctx, tx, ids, wantData) 618 return err 619 }) 620 return res, err 621 } 622 623 func (d *Datastore) GetMessageResult(ctx context.Context, id common.MessageID) (*fspb.MessageResult, error) { 624 d.l.Lock() 625 defer d.l.Unlock() 626 627 var ret *fspb.MessageResult 628 629 err := d.runInTx(func(tx *sql.Tx) error { 630 row := tx.QueryRowContext(ctx, "SELECT "+ 631 "creation_time_seconds, "+ 632 "creation_time_nanos, "+ 633 "processed_time_seconds, "+ 634 "processed_time_nanos, "+ 635 "failed, "+ 636 "failed_reason "+ 637 "FROM messages WHERE message_id=?", id.String()) 638 639 var dbm dbMessage 640 if err := row.Scan( 641 &dbm.creationTimeSeconds, 642 &dbm.creationTimeNanos, 643 &dbm.processedTimeSeconds, 644 &dbm.processedTimeNanos, 645 &dbm.failed, 646 &dbm.failedReason, 647 ); err == sql.ErrNoRows { 648 return nil 649 } else if err != nil { 650 return err 651 } 652 653 ret = toMessageResultProto(&dbm) 654 return nil 655 }) 656 657 return ret, err 658 } 659 660 // ClientMessagesForProcessing implements db.MessageStore. 661 func (d *Datastore) ClientMessagesForProcessing(ctx context.Context, id common.ClientID, lim uint64, serviceLimits map[string]uint64) ([]*fspb.Message, error) { 662 if id == (common.ClientID{}) { 663 return nil, errors.New("a client is required") 664 } 665 return d.internalMessagesForProcessing(ctx, id, lim, serviceLimits) 666 } 667 668 func (d *Datastore) internalMessagesForProcessing(ctx context.Context, id common.ClientID, lim uint64, serviceLimits map[string]uint64) ([]*fspb.Message, error) { 669 d.l.Lock() 670 defer d.l.Unlock() 671 672 read := make(map[string]uint64) 673 674 var res []*fspb.Message 675 676 if err := d.runInTx(func(tx *sql.Tx) error { 677 // As an internal addition to the MessageStore interface, this 678 // also gets server messages when id=ClientID{}. 679 rs, err := tx.QueryContext(ctx, "SELECT "+ 680 "m.message_id, "+ 681 "m.source_client_id, "+ 682 "m.source_service_name, "+ 683 "m.source_message_id, "+ 684 "m.destination_client_id, "+ 685 "m.destination_service_name, "+ 686 "m.message_type, "+ 687 "m.creation_time_seconds, "+ 688 "m.creation_time_nanos,"+ 689 "m.validation_info,"+ 690 "m.annotations,"+ 691 "pm.retry_count, "+ 692 "pm.data_type_url, "+ 693 "pm.data_value "+ 694 "FROM messages AS m, pending_messages AS pm "+ 695 "WHERE m.destination_client_id = ? AND m.message_id=pm.message_id AND pm.scheduled_time < ? ", 696 toClientIDString(id.Bytes()), toMicro(db.Now())) 697 if err != nil { 698 return err 699 } 700 defer rs.Close() 701 for rs.Next() { 702 var dbm dbMessage 703 if err = rs.Scan( 704 &dbm.messageID, 705 &dbm.sourceClientID, 706 &dbm.sourceServiceName, 707 &dbm.sourceMessageID, 708 &dbm.destinationClientID, 709 &dbm.destinationServiceName, 710 &dbm.messageType, 711 &dbm.creationTimeSeconds, 712 &dbm.creationTimeNanos, 713 &dbm.validationInfo, 714 &dbm.annotations, 715 &dbm.retryCount, 716 &dbm.dataTypeURL, 717 &dbm.dataValue, 718 ); err != nil { 719 return err 720 } 721 if serviceLimits != nil { 722 if read[dbm.destinationServiceName] >= serviceLimits[dbm.destinationServiceName] { 723 continue 724 } else { 725 read[dbm.destinationServiceName]++ 726 } 727 } 728 nc := dbm.retryCount + 1 729 var due int64 730 if dbm.destinationClientID == "" { 731 due = toMicro(db.ServerRetryTime(nc)) 732 } else { 733 due = toMicro(db.ClientRetryTime()) 734 } 735 if _, err = tx.ExecContext(ctx, "UPDATE pending_messages SET retry_count=?, scheduled_time=? WHERE message_id=?", nc, due, dbm.messageID); err != nil { 736 return err 737 } 738 m, err := toMessageProto(&dbm) 739 if err != nil { 740 return err 741 } 742 res = append(res, m) 743 if len(res) >= int(lim) { 744 return nil 745 } 746 } 747 return rs.Err() 748 }); err != nil { 749 return nil, err 750 } 751 return res, nil 752 } 753 754 type messageLooper struct { 755 d *Datastore 756 757 mp db.MessageProcessor 758 processingTicker *time.Ticker 759 stopCalled chan struct{} 760 loopDone chan struct{} 761 } 762 763 func (d *Datastore) RegisterMessageProcessor(mp db.MessageProcessor) { 764 if d.looper != nil { 765 log.Warning("Attempt to register a second MessageProcessor.") 766 d.looper.stop() 767 } 768 d.looper = &messageLooper{ 769 d: d, 770 mp: mp, 771 processingTicker: time.NewTicker(300 * time.Millisecond), 772 stopCalled: make(chan struct{}), 773 loopDone: make(chan struct{}), 774 } 775 go d.looper.messageProcessingLoop() 776 } 777 778 func (d *Datastore) StopMessageProcessor() { 779 if d.looper != nil { 780 d.looper.stop() 781 } 782 d.looper = nil 783 } 784 785 // messageProcessingLoop reads messages that should be processed on the server 786 // from the datastore and delivers them to the registered MessageProcessor. 787 func (l *messageLooper) messageProcessingLoop() { 788 defer close(l.loopDone) 789 for { 790 select { 791 case <-l.stopCalled: 792 return 793 case <-l.processingTicker.C: 794 l.processMessages() 795 } 796 } 797 } 798 799 func (l *messageLooper) stop() { 800 l.processingTicker.Stop() 801 close(l.stopCalled) 802 <-l.loopDone 803 } 804 805 func (l *messageLooper) processMessages() { 806 for { 807 msgs, err := l.d.internalMessagesForProcessing(context.Background(), common.ClientID{}, 5, nil) 808 if err != nil { 809 if err.Error() == "attempt to write a readonly database" { 810 log.Errorf("Failed to read server messages for processing; probably the database was removed: %v", err) 811 return 812 } 813 814 log.Errorf("Failed to read server messages for processing: %v", err) 815 continue 816 } 817 l.mp.ProcessMessages(msgs) 818 if len(msgs) == 0 { 819 return 820 } 821 } 822 }