github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/mysql/messagestore.go (about) 1 // Copyright 2017 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysql 16 17 import ( 18 "context" 19 "database/sql" 20 "encoding/hex" 21 "errors" 22 "fmt" 23 "math/rand" 24 "strconv" 25 "strings" 26 "time" 27 28 log "github.com/golang/glog" 29 tspb "google.golang.org/protobuf/types/known/timestamppb" 30 31 "github.com/google/fleetspeak/fleetspeak/src/common" 32 "github.com/google/fleetspeak/fleetspeak/src/server/db" 33 34 fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak" 35 "google.golang.org/protobuf/proto" 36 anypb "google.golang.org/protobuf/types/known/anypb" 37 ) 38 39 // dbMessage matches the schema of the messages table, optionally joined to the 40 // pending_messages table. 41 type dbMessage struct { 42 messageID []byte 43 sourceClientID []byte 44 sourceServiceName string 45 sourceMessageID []byte 46 destinationClientID []byte 47 destinationServiceName string 48 messageType string 49 creationTimeSeconds int64 50 creationTimeNanos int32 51 processedTimeSeconds sql.NullInt64 52 processedTimeNanos sql.NullInt64 53 validationInfo []byte 54 failed sql.NullBool 55 failedReason sql.NullString 56 retryCount uint32 57 dataTypeURL sql.NullString 58 dataValue []byte 59 annotations []byte 60 } 61 62 func toMicro(t time.Time) int64 { 63 return t.UnixNano() / 1000 64 } 65 66 func (d *Datastore) SetMessageResult(ctx context.Context, dest common.ClientID, id common.MessageID, res *fspb.MessageResult) error { 67 return d.runInTx(ctx, false, func(tx *sql.Tx) error { return d.trySetMessageResult(ctx, tx, dest.IsNil(), id, res) }) 68 } 69 70 func (d *Datastore) trySetMessageResult(ctx context.Context, tx *sql.Tx, forServer bool, id common.MessageID, res *fspb.MessageResult) error { 71 dbm := dbMessage{ 72 messageID: id.Bytes(), 73 processedTimeSeconds: sql.NullInt64{Valid: true, Int64: res.ProcessedTime.Seconds}, 74 processedTimeNanos: sql.NullInt64{Valid: true, Int64: int64(res.ProcessedTime.Nanos)}, 75 } 76 if res.Failed { 77 dbm.failed = sql.NullBool{Valid: true, Bool: true} 78 dbm.failedReason = sql.NullString{Valid: true, String: res.FailedReason} 79 } 80 _, err := tx.ExecContext(ctx, "UPDATE messages SET failed=?, failed_reason=?, processed_time_seconds=?, processed_time_nanos=? WHERE message_id=?", 81 dbm.failed, dbm.failedReason, dbm.processedTimeSeconds, dbm.processedTimeNanos, dbm.messageID) 82 if err != nil { 83 return err 84 } 85 var fs int 86 if forServer { 87 fs = 1 88 } 89 _, err = tx.ExecContext(ctx, "DELETE FROM pending_messages WHERE for_server=? AND message_id=?", fs, dbm.messageID) 90 return err 91 } 92 93 func fromNULLString(s sql.NullString) string { 94 if !s.Valid { 95 return "" 96 } 97 return s.String 98 } 99 100 func fromMessageProto(m *fspb.Message) (*dbMessage, error) { 101 id, err := common.BytesToMessageID(m.MessageId) 102 if err != nil { 103 return nil, err 104 } 105 dbm := &dbMessage{ 106 messageID: id.Bytes(), 107 messageType: m.MessageType, 108 } 109 if m.Source != nil { 110 dbm.sourceClientID = m.Source.ClientId 111 dbm.sourceServiceName = m.Source.ServiceName 112 } 113 if m.Destination != nil { 114 dbm.destinationClientID = m.Destination.ClientId 115 dbm.destinationServiceName = m.Destination.ServiceName 116 } 117 if len(m.SourceMessageId) != 0 { 118 dbm.sourceMessageID = m.SourceMessageId 119 } 120 if m.CreationTime != nil { 121 dbm.creationTimeSeconds = m.CreationTime.Seconds 122 dbm.creationTimeNanos = m.CreationTime.Nanos 123 } 124 if m.Result != nil { 125 r := m.Result 126 if r.ProcessedTime != nil { 127 dbm.processedTimeSeconds = sql.NullInt64{Int64: r.ProcessedTime.Seconds, Valid: true} 128 dbm.processedTimeNanos = sql.NullInt64{Int64: int64(r.ProcessedTime.Nanos), Valid: true} 129 } 130 if r.Failed { 131 dbm.failed = sql.NullBool{Bool: true, Valid: true} 132 dbm.failedReason = sql.NullString{String: r.FailedReason, Valid: true} 133 } 134 } 135 if m.Data != nil { 136 dbm.dataTypeURL = sql.NullString{String: m.Data.TypeUrl, Valid: true} 137 dbm.dataValue = m.Data.Value 138 } 139 if m.ValidationInfo != nil { 140 b, err := proto.Marshal(m.ValidationInfo) 141 if err != nil { 142 return nil, err 143 } 144 dbm.validationInfo = b 145 } 146 if m.Annotations != nil { 147 b, err := proto.Marshal(m.Annotations) 148 if err != nil { 149 return nil, err 150 } 151 dbm.annotations = b 152 } 153 return dbm, nil 154 } 155 156 func toMessageResultProto(m *dbMessage) *fspb.MessageResult { 157 if !m.processedTimeSeconds.Valid { 158 return nil 159 } 160 161 ret := &fspb.MessageResult{ 162 ProcessedTime: &tspb.Timestamp{ 163 Seconds: m.processedTimeSeconds.Int64, 164 Nanos: int32(m.processedTimeNanos.Int64)}, 165 Failed: m.failed.Valid && m.failed.Bool, 166 } 167 168 if m.failedReason.Valid { 169 ret.FailedReason = m.failedReason.String 170 } 171 return ret 172 } 173 174 func toMessageProto(m *dbMessage) (*fspb.Message, error) { 175 mid, err := common.BytesToMessageID(m.messageID) 176 if err != nil { 177 return nil, err 178 } 179 pm := &fspb.Message{ 180 MessageId: mid.Bytes(), 181 Source: &fspb.Address{ 182 ClientId: m.sourceClientID, 183 ServiceName: m.sourceServiceName, 184 }, 185 SourceMessageId: m.sourceMessageID, 186 Destination: &fspb.Address{ 187 ClientId: m.destinationClientID, 188 ServiceName: m.destinationServiceName, 189 }, 190 MessageType: m.messageType, 191 CreationTime: &tspb.Timestamp{ 192 Seconds: m.creationTimeSeconds, 193 Nanos: m.creationTimeNanos, 194 }, 195 Result: toMessageResultProto(m), 196 } 197 if m.dataTypeURL.Valid { 198 pm.Data = &anypb.Any{ 199 TypeUrl: m.dataTypeURL.String, 200 Value: m.dataValue, 201 } 202 } 203 if len(m.validationInfo) > 0 { 204 v := &fspb.ValidationInfo{} 205 if err := proto.Unmarshal(m.validationInfo, v); err != nil { 206 return nil, err 207 } 208 pm.ValidationInfo = v 209 } 210 if len(m.annotations) > 0 { 211 a := &fspb.Annotations{} 212 if err := proto.Unmarshal(m.annotations, a); err != nil { 213 return nil, err 214 } 215 pm.Annotations = a 216 } 217 return pm, nil 218 } 219 220 func genPlaceholders(num int) string { 221 es := make([]string, num) 222 for i := range es { 223 es[i] = "?" 224 } 225 return strings.Join(es, ", ") 226 } 227 228 func (d *Datastore) getPendingMessageRawIds(ctx context.Context, tx *sql.Tx, ids []common.ClientID, offset uint64, limit uint64, forUpdate bool) ([][]byte, error) { 229 squery := fmt.Sprintf("SELECT "+ 230 "m.message_id AS message_id "+ 231 "FROM messages AS m, pending_messages AS pm "+ 232 "WHERE m.destination_client_id IN (%s) AND m.message_id=pm.message_id "+ 233 "ORDER BY message_id", 234 genPlaceholders((len(ids)))) 235 236 if offset != 0 && limit == 0 { 237 return nil, fmt.Errorf("if offset is provided, a limit must be provided as well") 238 } 239 240 if limit != 0 { 241 squery += " LIMIT ?" 242 } 243 244 if offset != 0 { 245 squery += " OFFSET ?" 246 } 247 248 if forUpdate { 249 squery += " FOR UPDATE" 250 } 251 252 args := make([]any, len(ids), len(ids)+2) 253 for i, v := range ids { 254 args[i] = v.Bytes() 255 } 256 257 if limit != 0 { 258 args = append(args, limit) 259 } 260 261 if offset != 0 { 262 args = append(args, offset) 263 } 264 265 idsToProc := make([][]byte, 0) 266 267 rs, err := tx.QueryContext(ctx, squery, args...) 268 if err != nil { 269 return nil, fmt.Errorf("Failed to fetch the list of pending messages: %v", err) 270 } 271 defer rs.Close() 272 for rs.Next() { 273 var id []byte 274 if err := rs.Scan(&id); err != nil { 275 return nil, err 276 } 277 idsToProc = append(idsToProc, id) 278 } 279 280 return idsToProc, nil 281 } 282 283 func (d *Datastore) GetPendingMessageCount(ctx context.Context, ids []common.ClientID) (uint64, error) { 284 var result uint64 285 286 err := d.runInTx(ctx, true, func(tx *sql.Tx) error { 287 squery := fmt.Sprintf("SELECT "+ 288 "COUNT(*) "+ 289 "FROM messages AS m, pending_messages AS pm "+ 290 "WHERE m.destination_client_id IN (%s) AND m.message_id=pm.message_id ", 291 genPlaceholders((len(ids)))) 292 293 args := make([]any, len(ids)) 294 for i, v := range ids { 295 args[i] = v.Bytes() 296 } 297 rs, err := tx.QueryContext(ctx, squery, args...) 298 if err != nil { 299 return fmt.Errorf("Failed to fetch the pending message count: %v", err) 300 } 301 defer rs.Close() 302 if !rs.Next() { 303 return fmt.Errorf("Got empty result") 304 } 305 err = rs.Scan(&result) 306 if err != nil { 307 return fmt.Errorf("Failed to scan result: %v", err) 308 } 309 return nil 310 }) 311 312 return result, err 313 } 314 315 func (d *Datastore) GetPendingMessages(ctx context.Context, ids []common.ClientID, offset uint64, count uint64, wantData bool) ([]*fspb.Message, error) { 316 var res []*fspb.Message 317 err := d.runInTx(ctx, true, func(tx *sql.Tx) error { 318 messageIdsRaw, err := d.getPendingMessageRawIds(ctx, tx, ids, offset, count, false) 319 if err != nil { 320 return err 321 } 322 var messageIds []common.MessageID 323 for _, idRaw := range messageIdsRaw { 324 messageID, err := common.BytesToMessageID(idRaw) 325 if err != nil { 326 return err 327 } 328 messageIds = append(messageIds, messageID) 329 } 330 res, err = d.getMessages(ctx, tx, messageIds, wantData) 331 return err 332 }) 333 return res, err 334 } 335 336 func (d *Datastore) DeletePendingMessages(ctx context.Context, ids []common.ClientID) error { 337 return d.runInTx(ctx, false, func(tx *sql.Tx) error { 338 messageIds, err := d.getPendingMessageRawIds(ctx, tx, ids, 0, 0, true) 339 if err != nil { 340 return err 341 } 342 343 idsToProc := make([]any, len(messageIds)) 344 for i, id := range messageIds { 345 idsToProc[i] = id 346 } 347 348 // If there are no messages to be deleted, just bail out. 349 if len(idsToProc) == 0 { 350 return nil 351 } 352 353 now := db.NowProto() 354 ptimeSecs := sql.NullInt64{Valid: true, Int64: now.Seconds} 355 ptimeNanoSecs := sql.NullInt64{Valid: true, Int64: int64(now.Nanos)} 356 failed := sql.NullBool{Valid: true, Bool: true} 357 failedReason := sql.NullString{Valid: true, String: "Removed by admin action."} 358 359 ps := genPlaceholders(len(idsToProc)) 360 uquery := fmt.Sprintf("UPDATE messages SET failed=?, failed_reason=?, processed_time_seconds=?, processed_time_nanos=? WHERE message_id IN (%s)", ps) 361 _, err = tx.ExecContext(ctx, uquery, append([]any{failed, failedReason, ptimeSecs, ptimeNanoSecs}, idsToProc...)...) 362 if err != nil { 363 return err 364 } 365 366 dquery := fmt.Sprintf("DELETE FROM pending_messages WHERE message_id IN (%s)", ps) 367 _, err = tx.ExecContext(ctx, dquery, idsToProc...) 368 369 return err 370 }) 371 } 372 373 func (d *Datastore) StoreMessages(ctx context.Context, msgs []*fspb.Message, contact db.ContactID) error { 374 return d.runInTx(ctx, false, func(tx *sql.Tx) error { 375 ids := make([][]byte, 0, len(msgs)) 376 for _, m := range msgs { 377 dbm, err := fromMessageProto(m) 378 if err != nil { 379 return err 380 } 381 // If it is already processed, we don't want to save m.Data. 382 if m.Result != nil { 383 dbm.dataTypeURL = sql.NullString{Valid: false} 384 dbm.dataValue = nil 385 } 386 ids = append(ids, dbm.messageID) 387 if m.Result != nil && !m.Result.Failed { 388 if err := d.tryStoreMessage(ctx, tx, dbm, false); err != nil { 389 return err 390 } 391 if m.Result != nil { 392 mid, _ := common.BytesToMessageID(m.MessageId) 393 if err := d.trySetMessageResult(ctx, tx, m.Destination.ClientId == nil, mid, m.Result); err != nil { 394 return err 395 } 396 } 397 continue 398 } 399 var processedTime sql.NullInt64 400 var failed sql.NullBool 401 e := tx.QueryRowContext(ctx, "SELECT processed_time_seconds, failed FROM messages where message_id=?", dbm.messageID).Scan(&processedTime, &failed) 402 switch { 403 case e == sql.ErrNoRows: 404 // Common case. Message not yet present, store as normal. 405 if err := d.tryStoreMessage(ctx, tx, dbm, false); err != nil { 406 return err 407 } 408 case e != nil: 409 return e 410 case processedTime.Valid && (!failed.Valid || !failed.Bool): 411 // Message previously successfully processed, ignore this reprocessing. 412 case m.Result != nil && (!processedTime.Valid || !m.Result.Failed): 413 mid, err := common.BytesToMessageID(m.MessageId) 414 if err != nil { 415 return err 416 } 417 // Message not previously successfully processed, but this try succeeded. Mark as processed. 418 if err := d.trySetMessageResult(ctx, tx, m.Destination.ClientId == nil, mid, m.Result); err != nil { 419 return err 420 } 421 default: 422 // The message is already present, but unprocessed/failed, and this 423 // processing didn't succeed or is ongoing. Nothing to do. 424 } 425 } 426 427 if contact == "" { 428 return nil 429 } 430 431 c, err := strconv.ParseUint(string(contact), 16, 64) 432 if err != nil { 433 e := fmt.Errorf("unable to parse ContactID [%v]: %v", contact, err) 434 log.Error(e) 435 return e 436 } 437 for _, id := range ids { 438 if _, err := tx.ExecContext(ctx, "INSERT IGNORE INTO client_contact_messages(client_contact_id, message_id) VALUES (?, ?)", c, id); err != nil { 439 return err 440 } 441 } 442 return nil 443 }) 444 } 445 446 func (d *Datastore) tryStoreMessage(ctx context.Context, tx *sql.Tx, dbm *dbMessage, isBroadcast bool) error { 447 if dbm.creationTimeSeconds == 0 { 448 return errors.New("message CreationTime must be set") 449 } 450 res, err := tx.ExecContext(ctx, "INSERT IGNORE INTO messages("+ 451 "message_id, "+ 452 "source_client_id, "+ 453 "source_service_name, "+ 454 "source_message_id, "+ 455 "destination_client_id, "+ 456 "destination_service_name, "+ 457 "message_type, "+ 458 "creation_time_seconds, "+ 459 "creation_time_nanos, "+ 460 "processed_time_seconds, "+ 461 "processed_time_nanos, "+ 462 "failed,"+ 463 "failed_reason,"+ 464 "validation_info,"+ 465 "annotations) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", 466 dbm.messageID, 467 dbm.sourceClientID, 468 dbm.sourceServiceName, 469 dbm.sourceMessageID, 470 dbm.destinationClientID, 471 dbm.destinationServiceName, 472 dbm.messageType, 473 dbm.creationTimeSeconds, 474 dbm.creationTimeNanos, 475 dbm.processedTimeSeconds, 476 dbm.processedTimeNanos, 477 dbm.failed, 478 dbm.failedReason, 479 dbm.validationInfo, 480 dbm.annotations) 481 if err != nil { 482 return err 483 } 484 cnt, err := res.RowsAffected() 485 if err != nil { 486 return err 487 } 488 inserted := cnt == 1 489 if inserted && !dbm.processedTimeSeconds.Valid { 490 var due int64 491 if dbm.destinationClientID == nil { 492 due = toMicro(db.ServerRetryTime(0)) 493 } else { 494 // If this is being created in response to a broadcast, then we about to 495 // hand it to the client and should wait before providing in through 496 // ClientMessagesForProcessing. Otherwise, we should give it to the client 497 // on next contact. 498 if isBroadcast { 499 due = toMicro(db.ClientRetryTime()) 500 } else { 501 due = toMicro(db.Now()) 502 } 503 } 504 fs := 0 505 if dbm.destinationClientID == nil { 506 fs = 1 507 } 508 _, err = tx.ExecContext(ctx, "INSERT INTO pending_messages("+ 509 "for_server, "+ 510 "message_id, "+ 511 "retry_count, "+ 512 "scheduled_time, "+ 513 "data_type_url, "+ 514 "data_value) VALUES(?, ?, ?, ?, ?, ?)", 515 fs, 516 dbm.messageID, 517 0, 518 due, 519 dbm.dataTypeURL, 520 dbm.dataValue) 521 if err != nil { 522 return err 523 } 524 } 525 return nil 526 } 527 528 func (d *Datastore) getMessages(ctx context.Context, tx *sql.Tx, ids []common.MessageID, wantData bool) ([]*fspb.Message, error) { 529 res := make([]*fspb.Message, 0, len(ids)) 530 stmt1, err := tx.Prepare("SELECT " + 531 "message_id, " + 532 "source_client_id, " + 533 "source_service_name, " + 534 "source_message_id, " + 535 "destination_client_id, " + 536 "destination_service_name, " + 537 "message_type, " + 538 "creation_time_seconds, " + 539 "creation_time_nanos, " + 540 "processed_time_seconds, " + 541 "processed_time_nanos, " + 542 "validation_info, " + 543 "annotations " + 544 "FROM messages WHERE message_id=?") 545 var stmt2 *sql.Stmt 546 if wantData { 547 stmt2, err = tx.Prepare("SELECT data_type_url, data_value FROM pending_messages WHERE message_id=?") 548 if err != nil { 549 return nil, err 550 } 551 } 552 if err != nil { 553 return nil, err 554 } 555 for _, id := range ids { 556 row := stmt1.QueryRowContext(ctx, id.Bytes()) 557 var dbm dbMessage 558 err := row.Scan( 559 &dbm.messageID, 560 &dbm.sourceClientID, 561 &dbm.sourceServiceName, 562 &dbm.sourceMessageID, 563 &dbm.destinationClientID, 564 &dbm.destinationServiceName, 565 &dbm.messageType, 566 &dbm.creationTimeSeconds, 567 &dbm.creationTimeNanos, 568 &dbm.processedTimeSeconds, 569 &dbm.processedTimeNanos, 570 &dbm.validationInfo, 571 &dbm.annotations) 572 if err != nil { 573 return nil, err 574 } 575 if wantData { 576 row := stmt2.QueryRowContext(ctx, id.Bytes()) 577 err := row.Scan(&dbm.dataTypeURL, &dbm.dataValue) 578 if err != nil && err != sql.ErrNoRows { 579 return nil, err 580 } 581 } 582 m, err := toMessageProto(&dbm) 583 if err != nil { 584 return nil, err 585 } 586 res = append(res, m) 587 } 588 return res, nil 589 } 590 591 func (d *Datastore) GetMessages(ctx context.Context, ids []common.MessageID, wantData bool) ([]*fspb.Message, error) { 592 var res []*fspb.Message 593 err := d.runInTx(ctx, true, func(tx *sql.Tx) error { 594 var err error 595 res, err = d.getMessages(ctx, tx, ids, wantData) 596 return err 597 }) 598 return res, err 599 } 600 601 func (d *Datastore) GetMessageResult(ctx context.Context, id common.MessageID) (*fspb.MessageResult, error) { 602 var ret *fspb.MessageResult 603 604 err := d.runInTx(ctx, true, func(tx *sql.Tx) error { 605 row := tx.QueryRowContext(ctx, "SELECT "+ 606 "creation_time_seconds, "+ 607 "creation_time_nanos, "+ 608 "processed_time_seconds, "+ 609 "processed_time_nanos, "+ 610 "failed, "+ 611 "failed_reason "+ 612 "FROM messages WHERE message_id=?", id.Bytes()) 613 614 var dbm dbMessage 615 if err := row.Scan( 616 &dbm.creationTimeSeconds, 617 &dbm.creationTimeNanos, 618 &dbm.processedTimeSeconds, 619 &dbm.processedTimeNanos, 620 &dbm.failed, 621 &dbm.failedReason, 622 ); err == sql.ErrNoRows { 623 return nil 624 } else if err != nil { 625 return err 626 } 627 628 ret = toMessageResultProto(&dbm) 629 return nil 630 }) 631 632 return ret, err 633 } 634 635 func (d *Datastore) ClientMessagesForProcessing(ctx context.Context, id common.ClientID, lim uint64, serviceLimits map[string]uint64) ([]*fspb.Message, error) { 636 if id == (common.ClientID{}) { 637 return nil, errors.New("a client is required") 638 } 639 return d.internalClientMessagesForProcessing(ctx, id, lim, serviceLimits) 640 } 641 642 type pendingUpdate struct { 643 id []byte 644 nc uint32 645 due int64 646 } 647 648 func (d *Datastore) internalClientMessagesForProcessing(ctx context.Context, id common.ClientID, lim uint64, serviceLimits map[string]uint64) ([]*fspb.Message, error) { 649 650 var read map[string]uint64 651 if serviceLimits != nil { 652 read = make(map[string]uint64) 653 } 654 res := make([]*fspb.Message, 0, lim) 655 656 if err := d.runInTx(ctx, false, func(tx *sql.Tx) error { 657 var updates []*pendingUpdate 658 659 // As an internal addition to the MessageStore interface, this 660 // also gets server messages when id=ClientID{}. 661 rs, err := tx.QueryContext(ctx, "SELECT "+ 662 "m.message_id, "+ 663 "m.source_client_id, "+ 664 "m.source_service_name, "+ 665 "m.source_message_id, "+ 666 "m.destination_client_id, "+ 667 "m.destination_service_name, "+ 668 "m.message_type, "+ 669 "m.creation_time_seconds, "+ 670 "m.creation_time_nanos,"+ 671 "m.validation_info,"+ 672 "m.annotations,"+ 673 "pm.retry_count, "+ 674 "pm.data_type_url, "+ 675 "pm.data_value "+ 676 "FROM messages AS m, pending_messages AS pm "+ 677 "WHERE m.destination_client_id = ? AND m.message_id=pm.message_id AND pm.scheduled_time < ? "+ 678 "FOR UPDATE", 679 id.Bytes(), toMicro(db.Now())) 680 if err != nil { 681 return err 682 } 683 defer rs.Close() 684 for rs.Next() { 685 var dbm dbMessage 686 if err := rs.Scan( 687 &dbm.messageID, 688 &dbm.sourceClientID, 689 &dbm.sourceServiceName, 690 &dbm.sourceMessageID, 691 &dbm.destinationClientID, 692 &dbm.destinationServiceName, 693 &dbm.messageType, 694 &dbm.creationTimeSeconds, 695 &dbm.creationTimeNanos, 696 &dbm.validationInfo, 697 &dbm.annotations, 698 &dbm.retryCount, 699 &dbm.dataTypeURL, 700 &dbm.dataValue, 701 ); err != nil { 702 return err 703 } 704 if serviceLimits != nil { 705 if read[dbm.destinationServiceName] >= serviceLimits[dbm.destinationServiceName] { 706 continue 707 } else { 708 read[dbm.destinationServiceName]++ 709 } 710 } 711 712 if len(res) >= int(lim) { 713 continue 714 } 715 716 updates = append(updates, &pendingUpdate{ 717 id: dbm.messageID, 718 nc: dbm.retryCount + 1, 719 due: toMicro(db.ClientRetryTime())}) 720 m, err := toMessageProto(&dbm) 721 if err != nil { 722 return err 723 } 724 res = append(res, m) 725 } 726 if err := rs.Err(); err != nil { 727 return err 728 } 729 for _, u := range updates { 730 if _, err := tx.ExecContext(ctx, "UPDATE pending_messages SET retry_count=?, scheduled_time=? WHERE for_server=0 AND message_id=?", u.nc, u.due, u.id); err != nil { 731 return err 732 } 733 } 734 return nil 735 }); err != nil { 736 return nil, err 737 } 738 return res, nil 739 } 740 741 func (d *Datastore) internalServerMessagesForProcessing(ctx context.Context, lim int) ([]*fspb.Message, error) { 742 743 var dbms []*dbMessage 744 // First a full transaction to read and update the pending_messages table. 745 if err := d.runInTx(ctx, false, func(tx *sql.Tx) error { 746 747 // Use two random bytes to decide where to start reading (and locking). This 748 // reduces the odds of trying to grab the same records as another server 749 // when there is a large backlog. 750 b := [2]byte{} 751 rand.Read(b[:]) 752 start := hex.EncodeToString(b[:]) 753 754 readPending := func(q string, l int) error { 755 rs, err := tx.QueryContext(ctx, q, start, toMicro(db.Now()), l) 756 if err != nil { 757 return err 758 } 759 defer rs.Close() 760 for rs.Next() { 761 var dbm dbMessage 762 if err := rs.Scan( 763 &dbm.messageID, 764 &dbm.retryCount, 765 &dbm.dataTypeURL, 766 &dbm.dataValue, 767 ); err != nil { 768 return err 769 } 770 dbms = append(dbms, &dbm) 771 } 772 return rs.Err() 773 } 774 775 // First query reads from "start" to keyspace end. 776 if err := readPending( 777 "SELECT "+ 778 "message_id, "+ 779 "retry_count, "+ 780 "data_type_url, "+ 781 "data_value "+ 782 "FROM pending_messages "+ 783 "WHERE for_server=1 "+ 784 "AND message_id > ? "+ 785 "AND scheduled_time < ? "+ 786 "ORDER BY message_id "+ 787 "LIMIT ? "+ 788 "FOR UPDATE", lim); err != nil { 789 return err 790 } 791 // If needed, second reads from keyspace begin to "start". 792 if len(dbms) < lim { 793 if err := readPending( 794 "SELECT "+ 795 "message_id, "+ 796 "retry_count, "+ 797 "data_type_url, "+ 798 "data_value "+ 799 "FROM pending_messages "+ 800 "WHERE for_server=1 "+ 801 "AND message_id < ? "+ 802 "AND scheduled_time < ? "+ 803 "ORDER BY message_id "+ 804 "LIMIT ? "+ 805 "FOR UPDATE", lim-len(dbms)); err != nil { 806 return err 807 } 808 } 809 stmt, err := tx.PrepareContext(ctx, "UPDATE pending_messages SET retry_count=?, scheduled_time=? WHERE for_server=1 AND message_id=?") 810 if err != nil { 811 return err 812 } 813 for _, dbm := range dbms { 814 if _, err := stmt.ExecContext(ctx, dbm.retryCount+1, toMicro(db.ServerRetryTime(dbm.retryCount+1)), dbm.messageID); err != nil { 815 return err 816 } 817 } 818 return nil 819 }); err != nil { 820 return nil, err 821 } 822 if dbms == nil { 823 return nil, nil 824 } 825 // Now we read the rest of the message metadata. In a perfect world, we'd 826 // rollback the transaction above if anything here fails, however we do this 827 // outside of the transaction in order to release any locks held by the 828 // transaction asap. 829 res := make([]*fspb.Message, 0, len(dbms)) 830 stmt, err := d.db.PrepareContext(ctx, "SELECT "+ 831 "source_client_id, "+ 832 "source_service_name, "+ 833 "source_message_id, "+ 834 "destination_client_id, "+ 835 "destination_service_name, "+ 836 "message_type, "+ 837 "creation_time_seconds, "+ 838 "creation_time_nanos, "+ 839 "validation_info, "+ 840 "annotations "+ 841 "FROM messages "+ 842 "WHERE message_id = ?") 843 if err != nil { 844 return nil, err 845 } 846 defer stmt.Close() 847 for _, dbm := range dbms { 848 r := stmt.QueryRowContext(ctx, dbm.messageID) 849 if err := r.Scan( 850 &dbm.sourceClientID, 851 &dbm.sourceServiceName, 852 &dbm.sourceMessageID, 853 &dbm.destinationClientID, 854 &dbm.destinationServiceName, 855 &dbm.messageType, 856 &dbm.creationTimeSeconds, 857 &dbm.creationTimeNanos, 858 &dbm.validationInfo, 859 &dbm.annotations, 860 ); err != nil { 861 return nil, err 862 } 863 m, err := toMessageProto(dbm) 864 if err != nil { 865 return nil, err 866 } 867 res = append(res, m) 868 } 869 return res, nil 870 } 871 872 type messageLooper struct { 873 d *Datastore 874 875 mp db.MessageProcessor 876 processingTicker *time.Ticker 877 stopCalled chan struct{} 878 loopDone chan struct{} 879 } 880 881 func (d *Datastore) RegisterMessageProcessor(mp db.MessageProcessor) { 882 if d.looper != nil { 883 log.Warning("Attempt to register a second MessageProcessor.") 884 d.looper.stop() 885 } 886 d.looper = &messageLooper{ 887 d: d, 888 mp: mp, 889 processingTicker: time.NewTicker(300 * time.Millisecond), 890 stopCalled: make(chan struct{}), 891 loopDone: make(chan struct{}), 892 } 893 go d.looper.messageProcessingLoop() 894 } 895 896 func (d *Datastore) StopMessageProcessor() { 897 if d.looper != nil { 898 d.looper.stop() 899 } 900 d.looper = nil 901 } 902 903 // messageProcessingLoop reads messages that should be processed on the server 904 // from the datastore and delivers them to the registered MessageProcessor. 905 func (l *messageLooper) messageProcessingLoop() { 906 defer close(l.loopDone) 907 for { 908 select { 909 case <-l.stopCalled: 910 return 911 case <-l.processingTicker.C: 912 l.processMessages() 913 } 914 } 915 } 916 917 func (l *messageLooper) stop() { 918 l.processingTicker.Stop() 919 close(l.stopCalled) 920 <-l.loopDone 921 } 922 923 func (l *messageLooper) processMessages() { 924 for { 925 msgs, err := l.d.internalServerMessagesForProcessing(context.Background(), 10) 926 if err != nil { 927 log.Errorf("Failed to read server messages for processing: %v", err) 928 continue 929 } 930 l.mp.ProcessMessages(msgs) 931 if len(msgs) == 0 { 932 return 933 } 934 } 935 }