github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/mysql/messagestore.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysql
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"encoding/hex"
    21  	"errors"
    22  	"fmt"
    23  	"math/rand"
    24  	"strconv"
    25  	"strings"
    26  	"time"
    27  
    28  	log "github.com/golang/glog"
    29  	tspb "google.golang.org/protobuf/types/known/timestamppb"
    30  
    31  	"github.com/google/fleetspeak/fleetspeak/src/common"
    32  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    33  
    34  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    35  	"google.golang.org/protobuf/proto"
    36  	anypb "google.golang.org/protobuf/types/known/anypb"
    37  )
    38  
    39  // dbMessage matches the schema of the messages table, optionally joined to the
    40  // pending_messages table.
    41  type dbMessage struct {
    42  	messageID              []byte
    43  	sourceClientID         []byte
    44  	sourceServiceName      string
    45  	sourceMessageID        []byte
    46  	destinationClientID    []byte
    47  	destinationServiceName string
    48  	messageType            string
    49  	creationTimeSeconds    int64
    50  	creationTimeNanos      int32
    51  	processedTimeSeconds   sql.NullInt64
    52  	processedTimeNanos     sql.NullInt64
    53  	validationInfo         []byte
    54  	failed                 sql.NullBool
    55  	failedReason           sql.NullString
    56  	retryCount             uint32
    57  	dataTypeURL            sql.NullString
    58  	dataValue              []byte
    59  	annotations            []byte
    60  }
    61  
    62  func toMicro(t time.Time) int64 {
    63  	return t.UnixNano() / 1000
    64  }
    65  
    66  func (d *Datastore) SetMessageResult(ctx context.Context, dest common.ClientID, id common.MessageID, res *fspb.MessageResult) error {
    67  	return d.runInTx(ctx, false, func(tx *sql.Tx) error { return d.trySetMessageResult(ctx, tx, dest.IsNil(), id, res) })
    68  }
    69  
    70  func (d *Datastore) trySetMessageResult(ctx context.Context, tx *sql.Tx, forServer bool, id common.MessageID, res *fspb.MessageResult) error {
    71  	dbm := dbMessage{
    72  		messageID:            id.Bytes(),
    73  		processedTimeSeconds: sql.NullInt64{Valid: true, Int64: res.ProcessedTime.Seconds},
    74  		processedTimeNanos:   sql.NullInt64{Valid: true, Int64: int64(res.ProcessedTime.Nanos)},
    75  	}
    76  	if res.Failed {
    77  		dbm.failed = sql.NullBool{Valid: true, Bool: true}
    78  		dbm.failedReason = sql.NullString{Valid: true, String: res.FailedReason}
    79  	}
    80  	_, err := tx.ExecContext(ctx, "UPDATE messages SET failed=?, failed_reason=?, processed_time_seconds=?, processed_time_nanos=? WHERE message_id=?",
    81  		dbm.failed, dbm.failedReason, dbm.processedTimeSeconds, dbm.processedTimeNanos, dbm.messageID)
    82  	if err != nil {
    83  		return err
    84  	}
    85  	var fs int
    86  	if forServer {
    87  		fs = 1
    88  	}
    89  	_, err = tx.ExecContext(ctx, "DELETE FROM pending_messages WHERE for_server=? AND message_id=?", fs, dbm.messageID)
    90  	return err
    91  }
    92  
    93  func fromNULLString(s sql.NullString) string {
    94  	if !s.Valid {
    95  		return ""
    96  	}
    97  	return s.String
    98  }
    99  
   100  func fromMessageProto(m *fspb.Message) (*dbMessage, error) {
   101  	id, err := common.BytesToMessageID(m.MessageId)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	dbm := &dbMessage{
   106  		messageID:   id.Bytes(),
   107  		messageType: m.MessageType,
   108  	}
   109  	if m.Source != nil {
   110  		dbm.sourceClientID = m.Source.ClientId
   111  		dbm.sourceServiceName = m.Source.ServiceName
   112  	}
   113  	if m.Destination != nil {
   114  		dbm.destinationClientID = m.Destination.ClientId
   115  		dbm.destinationServiceName = m.Destination.ServiceName
   116  	}
   117  	if len(m.SourceMessageId) != 0 {
   118  		dbm.sourceMessageID = m.SourceMessageId
   119  	}
   120  	if m.CreationTime != nil {
   121  		dbm.creationTimeSeconds = m.CreationTime.Seconds
   122  		dbm.creationTimeNanos = m.CreationTime.Nanos
   123  	}
   124  	if m.Result != nil {
   125  		r := m.Result
   126  		if r.ProcessedTime != nil {
   127  			dbm.processedTimeSeconds = sql.NullInt64{Int64: r.ProcessedTime.Seconds, Valid: true}
   128  			dbm.processedTimeNanos = sql.NullInt64{Int64: int64(r.ProcessedTime.Nanos), Valid: true}
   129  		}
   130  		if r.Failed {
   131  			dbm.failed = sql.NullBool{Bool: true, Valid: true}
   132  			dbm.failedReason = sql.NullString{String: r.FailedReason, Valid: true}
   133  		}
   134  	}
   135  	if m.Data != nil {
   136  		dbm.dataTypeURL = sql.NullString{String: m.Data.TypeUrl, Valid: true}
   137  		dbm.dataValue = m.Data.Value
   138  	}
   139  	if m.ValidationInfo != nil {
   140  		b, err := proto.Marshal(m.ValidationInfo)
   141  		if err != nil {
   142  			return nil, err
   143  		}
   144  		dbm.validationInfo = b
   145  	}
   146  	if m.Annotations != nil {
   147  		b, err := proto.Marshal(m.Annotations)
   148  		if err != nil {
   149  			return nil, err
   150  		}
   151  		dbm.annotations = b
   152  	}
   153  	return dbm, nil
   154  }
   155  
   156  func toMessageResultProto(m *dbMessage) *fspb.MessageResult {
   157  	if !m.processedTimeSeconds.Valid {
   158  		return nil
   159  	}
   160  
   161  	ret := &fspb.MessageResult{
   162  		ProcessedTime: &tspb.Timestamp{
   163  			Seconds: m.processedTimeSeconds.Int64,
   164  			Nanos:   int32(m.processedTimeNanos.Int64)},
   165  		Failed: m.failed.Valid && m.failed.Bool,
   166  	}
   167  
   168  	if m.failedReason.Valid {
   169  		ret.FailedReason = m.failedReason.String
   170  	}
   171  	return ret
   172  }
   173  
   174  func toMessageProto(m *dbMessage) (*fspb.Message, error) {
   175  	mid, err := common.BytesToMessageID(m.messageID)
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  	pm := &fspb.Message{
   180  		MessageId: mid.Bytes(),
   181  		Source: &fspb.Address{
   182  			ClientId:    m.sourceClientID,
   183  			ServiceName: m.sourceServiceName,
   184  		},
   185  		SourceMessageId: m.sourceMessageID,
   186  		Destination: &fspb.Address{
   187  			ClientId:    m.destinationClientID,
   188  			ServiceName: m.destinationServiceName,
   189  		},
   190  		MessageType: m.messageType,
   191  		CreationTime: &tspb.Timestamp{
   192  			Seconds: m.creationTimeSeconds,
   193  			Nanos:   m.creationTimeNanos,
   194  		},
   195  		Result: toMessageResultProto(m),
   196  	}
   197  	if m.dataTypeURL.Valid {
   198  		pm.Data = &anypb.Any{
   199  			TypeUrl: m.dataTypeURL.String,
   200  			Value:   m.dataValue,
   201  		}
   202  	}
   203  	if len(m.validationInfo) > 0 {
   204  		v := &fspb.ValidationInfo{}
   205  		if err := proto.Unmarshal(m.validationInfo, v); err != nil {
   206  			return nil, err
   207  		}
   208  		pm.ValidationInfo = v
   209  	}
   210  	if len(m.annotations) > 0 {
   211  		a := &fspb.Annotations{}
   212  		if err := proto.Unmarshal(m.annotations, a); err != nil {
   213  			return nil, err
   214  		}
   215  		pm.Annotations = a
   216  	}
   217  	return pm, nil
   218  }
   219  
   220  func genPlaceholders(num int) string {
   221  	es := make([]string, num)
   222  	for i := range es {
   223  		es[i] = "?"
   224  	}
   225  	return strings.Join(es, ", ")
   226  }
   227  
   228  func (d *Datastore) getPendingMessageRawIds(ctx context.Context, tx *sql.Tx, ids []common.ClientID, offset uint64, limit uint64, forUpdate bool) ([][]byte, error) {
   229  	squery := fmt.Sprintf("SELECT "+
   230  		"m.message_id AS message_id "+
   231  		"FROM messages AS m, pending_messages AS pm "+
   232  		"WHERE m.destination_client_id IN (%s) AND m.message_id=pm.message_id "+
   233  		"ORDER BY message_id",
   234  		genPlaceholders((len(ids))))
   235  
   236  	if offset != 0 && limit == 0 {
   237  		return nil, fmt.Errorf("if offset is provided, a limit must be provided as well")
   238  	}
   239  
   240  	if limit != 0 {
   241  		squery += " LIMIT ?"
   242  	}
   243  
   244  	if offset != 0 {
   245  		squery += " OFFSET ?"
   246  	}
   247  
   248  	if forUpdate {
   249  		squery += " FOR UPDATE"
   250  	}
   251  
   252  	args := make([]any, len(ids), len(ids)+2)
   253  	for i, v := range ids {
   254  		args[i] = v.Bytes()
   255  	}
   256  
   257  	if limit != 0 {
   258  		args = append(args, limit)
   259  	}
   260  
   261  	if offset != 0 {
   262  		args = append(args, offset)
   263  	}
   264  
   265  	idsToProc := make([][]byte, 0)
   266  
   267  	rs, err := tx.QueryContext(ctx, squery, args...)
   268  	if err != nil {
   269  		return nil, fmt.Errorf("Failed to fetch the list of pending messages: %v", err)
   270  	}
   271  	defer rs.Close()
   272  	for rs.Next() {
   273  		var id []byte
   274  		if err := rs.Scan(&id); err != nil {
   275  			return nil, err
   276  		}
   277  		idsToProc = append(idsToProc, id)
   278  	}
   279  
   280  	return idsToProc, nil
   281  }
   282  
   283  func (d *Datastore) GetPendingMessageCount(ctx context.Context, ids []common.ClientID) (uint64, error) {
   284  	var result uint64
   285  
   286  	err := d.runInTx(ctx, true, func(tx *sql.Tx) error {
   287  		squery := fmt.Sprintf("SELECT "+
   288  			"COUNT(*) "+
   289  			"FROM messages AS m, pending_messages AS pm "+
   290  			"WHERE m.destination_client_id IN (%s) AND m.message_id=pm.message_id ",
   291  			genPlaceholders((len(ids))))
   292  
   293  		args := make([]any, len(ids))
   294  		for i, v := range ids {
   295  			args[i] = v.Bytes()
   296  		}
   297  		rs, err := tx.QueryContext(ctx, squery, args...)
   298  		if err != nil {
   299  			return fmt.Errorf("Failed to fetch the pending message count: %v", err)
   300  		}
   301  		defer rs.Close()
   302  		if !rs.Next() {
   303  			return fmt.Errorf("Got empty result")
   304  		}
   305  		err = rs.Scan(&result)
   306  		if err != nil {
   307  			return fmt.Errorf("Failed to scan result: %v", err)
   308  		}
   309  		return nil
   310  	})
   311  
   312  	return result, err
   313  }
   314  
   315  func (d *Datastore) GetPendingMessages(ctx context.Context, ids []common.ClientID, offset uint64, count uint64, wantData bool) ([]*fspb.Message, error) {
   316  	var res []*fspb.Message
   317  	err := d.runInTx(ctx, true, func(tx *sql.Tx) error {
   318  		messageIdsRaw, err := d.getPendingMessageRawIds(ctx, tx, ids, offset, count, false)
   319  		if err != nil {
   320  			return err
   321  		}
   322  		var messageIds []common.MessageID
   323  		for _, idRaw := range messageIdsRaw {
   324  			messageID, err := common.BytesToMessageID(idRaw)
   325  			if err != nil {
   326  				return err
   327  			}
   328  			messageIds = append(messageIds, messageID)
   329  		}
   330  		res, err = d.getMessages(ctx, tx, messageIds, wantData)
   331  		return err
   332  	})
   333  	return res, err
   334  }
   335  
   336  func (d *Datastore) DeletePendingMessages(ctx context.Context, ids []common.ClientID) error {
   337  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   338  		messageIds, err := d.getPendingMessageRawIds(ctx, tx, ids, 0, 0, true)
   339  		if err != nil {
   340  			return err
   341  		}
   342  
   343  		idsToProc := make([]any, len(messageIds))
   344  		for i, id := range messageIds {
   345  			idsToProc[i] = id
   346  		}
   347  
   348  		// If there are no messages to be deleted, just bail out.
   349  		if len(idsToProc) == 0 {
   350  			return nil
   351  		}
   352  
   353  		now := db.NowProto()
   354  		ptimeSecs := sql.NullInt64{Valid: true, Int64: now.Seconds}
   355  		ptimeNanoSecs := sql.NullInt64{Valid: true, Int64: int64(now.Nanos)}
   356  		failed := sql.NullBool{Valid: true, Bool: true}
   357  		failedReason := sql.NullString{Valid: true, String: "Removed by admin action."}
   358  
   359  		ps := genPlaceholders(len(idsToProc))
   360  		uquery := fmt.Sprintf("UPDATE messages SET failed=?, failed_reason=?, processed_time_seconds=?, processed_time_nanos=? WHERE message_id IN (%s)", ps)
   361  		_, err = tx.ExecContext(ctx, uquery, append([]any{failed, failedReason, ptimeSecs, ptimeNanoSecs}, idsToProc...)...)
   362  		if err != nil {
   363  			return err
   364  		}
   365  
   366  		dquery := fmt.Sprintf("DELETE FROM pending_messages WHERE message_id IN (%s)", ps)
   367  		_, err = tx.ExecContext(ctx, dquery, idsToProc...)
   368  
   369  		return err
   370  	})
   371  }
   372  
   373  func (d *Datastore) StoreMessages(ctx context.Context, msgs []*fspb.Message, contact db.ContactID) error {
   374  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   375  		ids := make([][]byte, 0, len(msgs))
   376  		for _, m := range msgs {
   377  			dbm, err := fromMessageProto(m)
   378  			if err != nil {
   379  				return err
   380  			}
   381  			// If it is already processed, we don't want to save m.Data.
   382  			if m.Result != nil {
   383  				dbm.dataTypeURL = sql.NullString{Valid: false}
   384  				dbm.dataValue = nil
   385  			}
   386  			ids = append(ids, dbm.messageID)
   387  			if m.Result != nil && !m.Result.Failed {
   388  				if err := d.tryStoreMessage(ctx, tx, dbm, false); err != nil {
   389  					return err
   390  				}
   391  				if m.Result != nil {
   392  					mid, _ := common.BytesToMessageID(m.MessageId)
   393  					if err := d.trySetMessageResult(ctx, tx, m.Destination.ClientId == nil, mid, m.Result); err != nil {
   394  						return err
   395  					}
   396  				}
   397  				continue
   398  			}
   399  			var processedTime sql.NullInt64
   400  			var failed sql.NullBool
   401  			e := tx.QueryRowContext(ctx, "SELECT processed_time_seconds, failed FROM messages where message_id=?", dbm.messageID).Scan(&processedTime, &failed)
   402  			switch {
   403  			case e == sql.ErrNoRows:
   404  				// Common case. Message not yet present, store as normal.
   405  				if err := d.tryStoreMessage(ctx, tx, dbm, false); err != nil {
   406  					return err
   407  				}
   408  			case e != nil:
   409  				return e
   410  			case processedTime.Valid && (!failed.Valid || !failed.Bool):
   411  				// Message previously successfully processed, ignore this reprocessing.
   412  			case m.Result != nil && (!processedTime.Valid || !m.Result.Failed):
   413  				mid, err := common.BytesToMessageID(m.MessageId)
   414  				if err != nil {
   415  					return err
   416  				}
   417  				// Message not previously successfully processed, but this try succeeded. Mark as processed.
   418  				if err := d.trySetMessageResult(ctx, tx, m.Destination.ClientId == nil, mid, m.Result); err != nil {
   419  					return err
   420  				}
   421  			default:
   422  				// The message is already present, but unprocessed/failed, and this
   423  				// processing didn't succeed or is ongoing. Nothing to do.
   424  			}
   425  		}
   426  
   427  		if contact == "" {
   428  			return nil
   429  		}
   430  
   431  		c, err := strconv.ParseUint(string(contact), 16, 64)
   432  		if err != nil {
   433  			e := fmt.Errorf("unable to parse ContactID [%v]: %v", contact, err)
   434  			log.Error(e)
   435  			return e
   436  		}
   437  		for _, id := range ids {
   438  			if _, err := tx.ExecContext(ctx, "INSERT IGNORE INTO client_contact_messages(client_contact_id, message_id) VALUES (?, ?)", c, id); err != nil {
   439  				return err
   440  			}
   441  		}
   442  		return nil
   443  	})
   444  }
   445  
   446  func (d *Datastore) tryStoreMessage(ctx context.Context, tx *sql.Tx, dbm *dbMessage, isBroadcast bool) error {
   447  	if dbm.creationTimeSeconds == 0 {
   448  		return errors.New("message CreationTime must be set")
   449  	}
   450  	res, err := tx.ExecContext(ctx, "INSERT IGNORE INTO messages("+
   451  		"message_id, "+
   452  		"source_client_id, "+
   453  		"source_service_name, "+
   454  		"source_message_id, "+
   455  		"destination_client_id, "+
   456  		"destination_service_name, "+
   457  		"message_type, "+
   458  		"creation_time_seconds, "+
   459  		"creation_time_nanos, "+
   460  		"processed_time_seconds, "+
   461  		"processed_time_nanos, "+
   462  		"failed,"+
   463  		"failed_reason,"+
   464  		"validation_info,"+
   465  		"annotations) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
   466  		dbm.messageID,
   467  		dbm.sourceClientID,
   468  		dbm.sourceServiceName,
   469  		dbm.sourceMessageID,
   470  		dbm.destinationClientID,
   471  		dbm.destinationServiceName,
   472  		dbm.messageType,
   473  		dbm.creationTimeSeconds,
   474  		dbm.creationTimeNanos,
   475  		dbm.processedTimeSeconds,
   476  		dbm.processedTimeNanos,
   477  		dbm.failed,
   478  		dbm.failedReason,
   479  		dbm.validationInfo,
   480  		dbm.annotations)
   481  	if err != nil {
   482  		return err
   483  	}
   484  	cnt, err := res.RowsAffected()
   485  	if err != nil {
   486  		return err
   487  	}
   488  	inserted := cnt == 1
   489  	if inserted && !dbm.processedTimeSeconds.Valid {
   490  		var due int64
   491  		if dbm.destinationClientID == nil {
   492  			due = toMicro(db.ServerRetryTime(0))
   493  		} else {
   494  			// If this is being created in response to a broadcast, then we about to
   495  			// hand it to the client and should wait before providing in through
   496  			// ClientMessagesForProcessing. Otherwise, we should give it to the client
   497  			// on next contact.
   498  			if isBroadcast {
   499  				due = toMicro(db.ClientRetryTime())
   500  			} else {
   501  				due = toMicro(db.Now())
   502  			}
   503  		}
   504  		fs := 0
   505  		if dbm.destinationClientID == nil {
   506  			fs = 1
   507  		}
   508  		_, err = tx.ExecContext(ctx, "INSERT INTO pending_messages("+
   509  			"for_server, "+
   510  			"message_id, "+
   511  			"retry_count, "+
   512  			"scheduled_time, "+
   513  			"data_type_url, "+
   514  			"data_value) VALUES(?, ?, ?, ?, ?, ?)",
   515  			fs,
   516  			dbm.messageID,
   517  			0,
   518  			due,
   519  			dbm.dataTypeURL,
   520  			dbm.dataValue)
   521  		if err != nil {
   522  			return err
   523  		}
   524  	}
   525  	return nil
   526  }
   527  
   528  func (d *Datastore) getMessages(ctx context.Context, tx *sql.Tx, ids []common.MessageID, wantData bool) ([]*fspb.Message, error) {
   529  	res := make([]*fspb.Message, 0, len(ids))
   530  	stmt1, err := tx.Prepare("SELECT " +
   531  		"message_id, " +
   532  		"source_client_id, " +
   533  		"source_service_name, " +
   534  		"source_message_id, " +
   535  		"destination_client_id, " +
   536  		"destination_service_name, " +
   537  		"message_type, " +
   538  		"creation_time_seconds, " +
   539  		"creation_time_nanos, " +
   540  		"processed_time_seconds, " +
   541  		"processed_time_nanos, " +
   542  		"validation_info, " +
   543  		"annotations " +
   544  		"FROM messages WHERE message_id=?")
   545  	var stmt2 *sql.Stmt
   546  	if wantData {
   547  		stmt2, err = tx.Prepare("SELECT data_type_url, data_value FROM pending_messages WHERE message_id=?")
   548  		if err != nil {
   549  			return nil, err
   550  		}
   551  	}
   552  	if err != nil {
   553  		return nil, err
   554  	}
   555  	for _, id := range ids {
   556  		row := stmt1.QueryRowContext(ctx, id.Bytes())
   557  		var dbm dbMessage
   558  		err := row.Scan(
   559  			&dbm.messageID,
   560  			&dbm.sourceClientID,
   561  			&dbm.sourceServiceName,
   562  			&dbm.sourceMessageID,
   563  			&dbm.destinationClientID,
   564  			&dbm.destinationServiceName,
   565  			&dbm.messageType,
   566  			&dbm.creationTimeSeconds,
   567  			&dbm.creationTimeNanos,
   568  			&dbm.processedTimeSeconds,
   569  			&dbm.processedTimeNanos,
   570  			&dbm.validationInfo,
   571  			&dbm.annotations)
   572  		if err != nil {
   573  			return nil, err
   574  		}
   575  		if wantData {
   576  			row := stmt2.QueryRowContext(ctx, id.Bytes())
   577  			err := row.Scan(&dbm.dataTypeURL, &dbm.dataValue)
   578  			if err != nil && err != sql.ErrNoRows {
   579  				return nil, err
   580  			}
   581  		}
   582  		m, err := toMessageProto(&dbm)
   583  		if err != nil {
   584  			return nil, err
   585  		}
   586  		res = append(res, m)
   587  	}
   588  	return res, nil
   589  }
   590  
   591  func (d *Datastore) GetMessages(ctx context.Context, ids []common.MessageID, wantData bool) ([]*fspb.Message, error) {
   592  	var res []*fspb.Message
   593  	err := d.runInTx(ctx, true, func(tx *sql.Tx) error {
   594  		var err error
   595  		res, err = d.getMessages(ctx, tx, ids, wantData)
   596  		return err
   597  	})
   598  	return res, err
   599  }
   600  
   601  func (d *Datastore) GetMessageResult(ctx context.Context, id common.MessageID) (*fspb.MessageResult, error) {
   602  	var ret *fspb.MessageResult
   603  
   604  	err := d.runInTx(ctx, true, func(tx *sql.Tx) error {
   605  		row := tx.QueryRowContext(ctx, "SELECT "+
   606  			"creation_time_seconds, "+
   607  			"creation_time_nanos, "+
   608  			"processed_time_seconds, "+
   609  			"processed_time_nanos, "+
   610  			"failed, "+
   611  			"failed_reason "+
   612  			"FROM messages WHERE message_id=?", id.Bytes())
   613  
   614  		var dbm dbMessage
   615  		if err := row.Scan(
   616  			&dbm.creationTimeSeconds,
   617  			&dbm.creationTimeNanos,
   618  			&dbm.processedTimeSeconds,
   619  			&dbm.processedTimeNanos,
   620  			&dbm.failed,
   621  			&dbm.failedReason,
   622  		); err == sql.ErrNoRows {
   623  			return nil
   624  		} else if err != nil {
   625  			return err
   626  		}
   627  
   628  		ret = toMessageResultProto(&dbm)
   629  		return nil
   630  	})
   631  
   632  	return ret, err
   633  }
   634  
   635  func (d *Datastore) ClientMessagesForProcessing(ctx context.Context, id common.ClientID, lim uint64, serviceLimits map[string]uint64) ([]*fspb.Message, error) {
   636  	if id == (common.ClientID{}) {
   637  		return nil, errors.New("a client is required")
   638  	}
   639  	return d.internalClientMessagesForProcessing(ctx, id, lim, serviceLimits)
   640  }
   641  
   642  type pendingUpdate struct {
   643  	id  []byte
   644  	nc  uint32
   645  	due int64
   646  }
   647  
   648  func (d *Datastore) internalClientMessagesForProcessing(ctx context.Context, id common.ClientID, lim uint64, serviceLimits map[string]uint64) ([]*fspb.Message, error) {
   649  
   650  	var read map[string]uint64
   651  	if serviceLimits != nil {
   652  		read = make(map[string]uint64)
   653  	}
   654  	res := make([]*fspb.Message, 0, lim)
   655  
   656  	if err := d.runInTx(ctx, false, func(tx *sql.Tx) error {
   657  		var updates []*pendingUpdate
   658  
   659  		// As an internal addition to the MessageStore interface, this
   660  		// also gets server messages when id=ClientID{}.
   661  		rs, err := tx.QueryContext(ctx, "SELECT "+
   662  			"m.message_id, "+
   663  			"m.source_client_id, "+
   664  			"m.source_service_name, "+
   665  			"m.source_message_id, "+
   666  			"m.destination_client_id, "+
   667  			"m.destination_service_name, "+
   668  			"m.message_type, "+
   669  			"m.creation_time_seconds, "+
   670  			"m.creation_time_nanos,"+
   671  			"m.validation_info,"+
   672  			"m.annotations,"+
   673  			"pm.retry_count, "+
   674  			"pm.data_type_url, "+
   675  			"pm.data_value "+
   676  			"FROM messages AS m, pending_messages AS pm "+
   677  			"WHERE m.destination_client_id = ? AND m.message_id=pm.message_id AND pm.scheduled_time < ? "+
   678  			"FOR UPDATE",
   679  			id.Bytes(), toMicro(db.Now()))
   680  		if err != nil {
   681  			return err
   682  		}
   683  		defer rs.Close()
   684  		for rs.Next() {
   685  			var dbm dbMessage
   686  			if err := rs.Scan(
   687  				&dbm.messageID,
   688  				&dbm.sourceClientID,
   689  				&dbm.sourceServiceName,
   690  				&dbm.sourceMessageID,
   691  				&dbm.destinationClientID,
   692  				&dbm.destinationServiceName,
   693  				&dbm.messageType,
   694  				&dbm.creationTimeSeconds,
   695  				&dbm.creationTimeNanos,
   696  				&dbm.validationInfo,
   697  				&dbm.annotations,
   698  				&dbm.retryCount,
   699  				&dbm.dataTypeURL,
   700  				&dbm.dataValue,
   701  			); err != nil {
   702  				return err
   703  			}
   704  			if serviceLimits != nil {
   705  				if read[dbm.destinationServiceName] >= serviceLimits[dbm.destinationServiceName] {
   706  					continue
   707  				} else {
   708  					read[dbm.destinationServiceName]++
   709  				}
   710  			}
   711  
   712  			if len(res) >= int(lim) {
   713  				continue
   714  			}
   715  
   716  			updates = append(updates, &pendingUpdate{
   717  				id:  dbm.messageID,
   718  				nc:  dbm.retryCount + 1,
   719  				due: toMicro(db.ClientRetryTime())})
   720  			m, err := toMessageProto(&dbm)
   721  			if err != nil {
   722  				return err
   723  			}
   724  			res = append(res, m)
   725  		}
   726  		if err := rs.Err(); err != nil {
   727  			return err
   728  		}
   729  		for _, u := range updates {
   730  			if _, err := tx.ExecContext(ctx, "UPDATE pending_messages SET retry_count=?, scheduled_time=? WHERE for_server=0 AND message_id=?", u.nc, u.due, u.id); err != nil {
   731  				return err
   732  			}
   733  		}
   734  		return nil
   735  	}); err != nil {
   736  		return nil, err
   737  	}
   738  	return res, nil
   739  }
   740  
   741  func (d *Datastore) internalServerMessagesForProcessing(ctx context.Context, lim int) ([]*fspb.Message, error) {
   742  
   743  	var dbms []*dbMessage
   744  	// First a full transaction to read and update the pending_messages table.
   745  	if err := d.runInTx(ctx, false, func(tx *sql.Tx) error {
   746  
   747  		// Use two random bytes to decide where to start reading (and locking). This
   748  		// reduces the odds of trying to grab the same records as another server
   749  		// when there is a large backlog.
   750  		b := [2]byte{}
   751  		rand.Read(b[:])
   752  		start := hex.EncodeToString(b[:])
   753  
   754  		readPending := func(q string, l int) error {
   755  			rs, err := tx.QueryContext(ctx, q, start, toMicro(db.Now()), l)
   756  			if err != nil {
   757  				return err
   758  			}
   759  			defer rs.Close()
   760  			for rs.Next() {
   761  				var dbm dbMessage
   762  				if err := rs.Scan(
   763  					&dbm.messageID,
   764  					&dbm.retryCount,
   765  					&dbm.dataTypeURL,
   766  					&dbm.dataValue,
   767  				); err != nil {
   768  					return err
   769  				}
   770  				dbms = append(dbms, &dbm)
   771  			}
   772  			return rs.Err()
   773  		}
   774  
   775  		// First query reads from "start" to keyspace end.
   776  		if err := readPending(
   777  			"SELECT "+
   778  				"message_id, "+
   779  				"retry_count, "+
   780  				"data_type_url, "+
   781  				"data_value "+
   782  				"FROM pending_messages "+
   783  				"WHERE for_server=1 "+
   784  				"AND message_id > ? "+
   785  				"AND scheduled_time < ? "+
   786  				"ORDER BY message_id "+
   787  				"LIMIT ? "+
   788  				"FOR UPDATE", lim); err != nil {
   789  			return err
   790  		}
   791  		// If needed, second reads from keyspace begin to "start".
   792  		if len(dbms) < lim {
   793  			if err := readPending(
   794  				"SELECT "+
   795  					"message_id, "+
   796  					"retry_count, "+
   797  					"data_type_url, "+
   798  					"data_value "+
   799  					"FROM pending_messages "+
   800  					"WHERE for_server=1 "+
   801  					"AND message_id < ? "+
   802  					"AND scheduled_time < ? "+
   803  					"ORDER BY message_id "+
   804  					"LIMIT ? "+
   805  					"FOR UPDATE", lim-len(dbms)); err != nil {
   806  				return err
   807  			}
   808  		}
   809  		stmt, err := tx.PrepareContext(ctx, "UPDATE pending_messages SET retry_count=?, scheduled_time=? WHERE for_server=1 AND message_id=?")
   810  		if err != nil {
   811  			return err
   812  		}
   813  		for _, dbm := range dbms {
   814  			if _, err := stmt.ExecContext(ctx, dbm.retryCount+1, toMicro(db.ServerRetryTime(dbm.retryCount+1)), dbm.messageID); err != nil {
   815  				return err
   816  			}
   817  		}
   818  		return nil
   819  	}); err != nil {
   820  		return nil, err
   821  	}
   822  	if dbms == nil {
   823  		return nil, nil
   824  	}
   825  	// Now we read the rest of the message metadata. In a perfect world, we'd
   826  	// rollback the transaction above if anything here fails, however we do this
   827  	// outside of the transaction in order to release any locks held by the
   828  	// transaction asap.
   829  	res := make([]*fspb.Message, 0, len(dbms))
   830  	stmt, err := d.db.PrepareContext(ctx, "SELECT "+
   831  		"source_client_id, "+
   832  		"source_service_name, "+
   833  		"source_message_id, "+
   834  		"destination_client_id, "+
   835  		"destination_service_name, "+
   836  		"message_type, "+
   837  		"creation_time_seconds, "+
   838  		"creation_time_nanos, "+
   839  		"validation_info, "+
   840  		"annotations "+
   841  		"FROM messages "+
   842  		"WHERE message_id = ?")
   843  	if err != nil {
   844  		return nil, err
   845  	}
   846  	defer stmt.Close()
   847  	for _, dbm := range dbms {
   848  		r := stmt.QueryRowContext(ctx, dbm.messageID)
   849  		if err := r.Scan(
   850  			&dbm.sourceClientID,
   851  			&dbm.sourceServiceName,
   852  			&dbm.sourceMessageID,
   853  			&dbm.destinationClientID,
   854  			&dbm.destinationServiceName,
   855  			&dbm.messageType,
   856  			&dbm.creationTimeSeconds,
   857  			&dbm.creationTimeNanos,
   858  			&dbm.validationInfo,
   859  			&dbm.annotations,
   860  		); err != nil {
   861  			return nil, err
   862  		}
   863  		m, err := toMessageProto(dbm)
   864  		if err != nil {
   865  			return nil, err
   866  		}
   867  		res = append(res, m)
   868  	}
   869  	return res, nil
   870  }
   871  
   872  type messageLooper struct {
   873  	d *Datastore
   874  
   875  	mp               db.MessageProcessor
   876  	processingTicker *time.Ticker
   877  	stopCalled       chan struct{}
   878  	loopDone         chan struct{}
   879  }
   880  
   881  func (d *Datastore) RegisterMessageProcessor(mp db.MessageProcessor) {
   882  	if d.looper != nil {
   883  		log.Warning("Attempt to register a second MessageProcessor.")
   884  		d.looper.stop()
   885  	}
   886  	d.looper = &messageLooper{
   887  		d:                d,
   888  		mp:               mp,
   889  		processingTicker: time.NewTicker(300 * time.Millisecond),
   890  		stopCalled:       make(chan struct{}),
   891  		loopDone:         make(chan struct{}),
   892  	}
   893  	go d.looper.messageProcessingLoop()
   894  }
   895  
   896  func (d *Datastore) StopMessageProcessor() {
   897  	if d.looper != nil {
   898  		d.looper.stop()
   899  	}
   900  	d.looper = nil
   901  }
   902  
   903  // messageProcessingLoop reads messages that should be processed on the server
   904  // from the datastore and delivers them to the registered MessageProcessor.
   905  func (l *messageLooper) messageProcessingLoop() {
   906  	defer close(l.loopDone)
   907  	for {
   908  		select {
   909  		case <-l.stopCalled:
   910  			return
   911  		case <-l.processingTicker.C:
   912  			l.processMessages()
   913  		}
   914  	}
   915  }
   916  
   917  func (l *messageLooper) stop() {
   918  	l.processingTicker.Stop()
   919  	close(l.stopCalled)
   920  	<-l.loopDone
   921  }
   922  
   923  func (l *messageLooper) processMessages() {
   924  	for {
   925  		msgs, err := l.d.internalServerMessagesForProcessing(context.Background(), 10)
   926  		if err != nil {
   927  			log.Errorf("Failed to read server messages for processing: %v", err)
   928  			continue
   929  		}
   930  		l.mp.ProcessMessages(msgs)
   931  		if len(msgs) == 0 {
   932  			return
   933  		}
   934  	}
   935  }