github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/sqlite/messagestore.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqlite
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"encoding/hex"
    21  	"errors"
    22  	"fmt"
    23  	"strconv"
    24  	"strings"
    25  	"time"
    26  
    27  	log "github.com/golang/glog"
    28  
    29  	"github.com/google/fleetspeak/fleetspeak/src/common"
    30  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    31  
    32  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    33  	"google.golang.org/protobuf/proto"
    34  	anypb "google.golang.org/protobuf/types/known/anypb"
    35  	tspb "google.golang.org/protobuf/types/known/timestamppb"
    36  )
    37  
    38  // dbMessage matches the schema of the messages table, optionally joined to the
    39  // pending_messages table.
    40  type dbMessage struct {
    41  	messageID              string
    42  	sourceClientID         string
    43  	sourceServiceName      string
    44  	sourceMessageID        string
    45  	destinationClientID    string
    46  	destinationServiceName string
    47  	messageType            string
    48  	creationTimeSeconds    int64
    49  	creationTimeNanos      int32
    50  	processedTimeSeconds   sql.NullInt64
    51  	processedTimeNanos     sql.NullInt64
    52  	validationInfo         []byte
    53  	failed                 sql.NullBool
    54  	failedReason           sql.NullString
    55  	retryCount             uint32
    56  	dataTypeURL            sql.NullString
    57  	dataValue              []byte
    58  	annotations            []byte
    59  }
    60  
    61  func toMicro(t time.Time) int64 {
    62  	return t.UnixNano() / 1000
    63  }
    64  
    65  func (d *Datastore) SetMessageResult(ctx context.Context, dest common.ClientID, id common.MessageID, res *fspb.MessageResult) error {
    66  	d.l.Lock()
    67  	defer d.l.Unlock()
    68  	return d.runInTx(func(tx *sql.Tx) error { return d.trySetMessageResult(ctx, tx, id, res) })
    69  }
    70  
    71  func (d *Datastore) trySetMessageResult(ctx context.Context, tx *sql.Tx, id common.MessageID, res *fspb.MessageResult) error {
    72  	dbm := dbMessage{
    73  		messageID:            id.String(),
    74  		processedTimeSeconds: sql.NullInt64{Valid: true, Int64: res.ProcessedTime.Seconds},
    75  		processedTimeNanos:   sql.NullInt64{Valid: true, Int64: int64(res.ProcessedTime.Nanos)},
    76  	}
    77  	if res.Failed {
    78  		dbm.failed = sql.NullBool{Valid: true, Bool: true}
    79  		dbm.failedReason = sql.NullString{Valid: true, String: res.FailedReason}
    80  	}
    81  	_, err := tx.ExecContext(ctx, "UPDATE messages SET failed=?, failed_reason=?, processed_time_seconds=?, processed_time_nanos=? WHERE message_id=?",
    82  		dbm.failed, dbm.failedReason, dbm.processedTimeSeconds, dbm.processedTimeNanos, dbm.messageID)
    83  	if err != nil {
    84  		return err
    85  	}
    86  	_, err = tx.ExecContext(ctx, "DELETE FROM pending_messages WHERE message_id=?", dbm.messageID)
    87  	return err
    88  }
    89  
    90  func toClientIDString(b []byte) string {
    91  	if len(b) == 0 {
    92  		return ""
    93  	}
    94  	id, err := common.BytesToClientID(b)
    95  	if err != nil {
    96  		log.Fatalf("Could't parse ClientID(%v): %v", b, err)
    97  	}
    98  	return id.String()
    99  }
   100  
   101  func fromClientIDString(s string) (b []byte) {
   102  	if s == "" {
   103  		return nil
   104  	}
   105  	cid, err := common.StringToClientID(s)
   106  	if err != nil {
   107  		log.Fatalf("Couldn't parse ClientID(%v): %v", s, err)
   108  		return nil
   109  	}
   110  	return cid.Bytes()
   111  }
   112  
   113  func fromNULLString(s sql.NullString) string {
   114  	if !s.Valid {
   115  		return ""
   116  	}
   117  	return s.String
   118  }
   119  
   120  func fromMessageProto(m *fspb.Message) (*dbMessage, error) {
   121  	id, err := common.BytesToMessageID(m.MessageId)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  	dbm := &dbMessage{
   126  		messageID:   id.String(),
   127  		messageType: m.MessageType,
   128  	}
   129  	if m.Source != nil {
   130  		dbm.sourceClientID = toClientIDString(m.Source.ClientId)
   131  		dbm.sourceServiceName = m.Source.ServiceName
   132  	}
   133  	if m.Destination != nil {
   134  		dbm.destinationClientID = toClientIDString(m.Destination.ClientId)
   135  		dbm.destinationServiceName = m.Destination.ServiceName
   136  	}
   137  	if len(m.SourceMessageId) != 0 {
   138  		dbm.sourceMessageID = hex.EncodeToString(m.SourceMessageId)
   139  	}
   140  	if m.CreationTime != nil {
   141  		dbm.creationTimeSeconds = m.CreationTime.Seconds
   142  		dbm.creationTimeNanos = m.CreationTime.Nanos
   143  	}
   144  	if m.Result != nil {
   145  		r := m.Result
   146  		if r.ProcessedTime != nil {
   147  			dbm.processedTimeSeconds = sql.NullInt64{Int64: r.ProcessedTime.Seconds, Valid: true}
   148  			dbm.processedTimeNanos = sql.NullInt64{Int64: int64(r.ProcessedTime.Nanos), Valid: true}
   149  		}
   150  		if r.Failed {
   151  			dbm.failed = sql.NullBool{Bool: true, Valid: true}
   152  			dbm.failedReason = sql.NullString{String: r.FailedReason, Valid: true}
   153  		}
   154  	}
   155  	if m.Data != nil {
   156  		dbm.dataTypeURL = sql.NullString{String: m.Data.TypeUrl, Valid: true}
   157  		dbm.dataValue = m.Data.Value
   158  	}
   159  	if m.ValidationInfo != nil {
   160  		b, err := proto.Marshal(m.ValidationInfo)
   161  		if err != nil {
   162  			return nil, err
   163  		}
   164  		dbm.validationInfo = b
   165  	}
   166  	if m.Annotations != nil {
   167  		b, err := proto.Marshal(m.Annotations)
   168  		if err != nil {
   169  			return nil, err
   170  		}
   171  		dbm.annotations = b
   172  	}
   173  	return dbm, nil
   174  }
   175  
   176  func toMessageResultProto(m *dbMessage) *fspb.MessageResult {
   177  	if !m.processedTimeSeconds.Valid {
   178  		return nil
   179  	}
   180  
   181  	ret := &fspb.MessageResult{
   182  		ProcessedTime: &tspb.Timestamp{
   183  			Seconds: m.processedTimeSeconds.Int64,
   184  			Nanos:   int32(m.processedTimeNanos.Int64)},
   185  		Failed: m.failed.Valid && m.failed.Bool,
   186  	}
   187  
   188  	if m.failedReason.Valid {
   189  		ret.FailedReason = m.failedReason.String
   190  	}
   191  	return ret
   192  }
   193  
   194  func toMessageProto(m *dbMessage) (*fspb.Message, error) {
   195  	mid, err := common.StringToMessageID(m.messageID)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	bsmid, err := hex.DecodeString(m.sourceMessageID)
   200  	if err != nil {
   201  		return nil, err
   202  	}
   203  	pm := &fspb.Message{
   204  		MessageId: mid.Bytes(),
   205  		Source: &fspb.Address{
   206  			ClientId:    fromClientIDString(m.sourceClientID),
   207  			ServiceName: m.sourceServiceName,
   208  		},
   209  		SourceMessageId: bsmid,
   210  		Destination: &fspb.Address{
   211  			ClientId:    fromClientIDString(m.destinationClientID),
   212  			ServiceName: m.destinationServiceName,
   213  		},
   214  		MessageType: m.messageType,
   215  		CreationTime: &tspb.Timestamp{
   216  			Seconds: m.creationTimeSeconds,
   217  			Nanos:   m.creationTimeNanos,
   218  		},
   219  		Result: toMessageResultProto(m),
   220  	}
   221  	if m.dataTypeURL.Valid {
   222  		pm.Data = &anypb.Any{
   223  			TypeUrl: m.dataTypeURL.String,
   224  			Value:   m.dataValue,
   225  		}
   226  	}
   227  	if len(m.validationInfo) > 0 {
   228  		v := &fspb.ValidationInfo{}
   229  		if err := proto.Unmarshal(m.validationInfo, v); err != nil {
   230  			return nil, err
   231  		}
   232  		pm.ValidationInfo = v
   233  	}
   234  	if len(m.annotations) > 0 {
   235  		a := &fspb.Annotations{}
   236  		if err := proto.Unmarshal(m.annotations, a); err != nil {
   237  			return nil, err
   238  		}
   239  		pm.Annotations = a
   240  	}
   241  	return pm, nil
   242  }
   243  
   244  func (d *Datastore) StoreMessages(ctx context.Context, msgs []*fspb.Message, contact db.ContactID) error {
   245  	d.l.Lock()
   246  	defer d.l.Unlock()
   247  
   248  	ids := make([]string, 0, len(msgs))
   249  
   250  	return d.runInTx(func(tx *sql.Tx) error {
   251  		for _, m := range msgs {
   252  			dbm, err := fromMessageProto(m)
   253  			if err != nil {
   254  				return err
   255  			}
   256  			// If it is already processed, we don't want to save m.Data.
   257  			if m.Result != nil {
   258  				dbm.dataTypeURL = sql.NullString{Valid: false}
   259  				dbm.dataValue = nil
   260  			}
   261  			ids = append(ids, dbm.messageID)
   262  			if m.Result != nil && !m.Result.Failed {
   263  				if err := d.tryStoreMessage(ctx, tx, dbm, false); err != nil {
   264  					return err
   265  				}
   266  				if m.Result != nil {
   267  					mid, _ := common.BytesToMessageID(m.MessageId)
   268  					if err := d.trySetMessageResult(ctx, tx, mid, m.Result); err != nil {
   269  						return err
   270  					}
   271  				}
   272  				continue
   273  			}
   274  			var processedTime sql.NullInt64
   275  			var failed sql.NullBool
   276  			e := tx.QueryRowContext(ctx, "SELECT processed_time_seconds, failed FROM messages where message_id=?", dbm.messageID).Scan(&processedTime, &failed)
   277  			switch {
   278  			case e == sql.ErrNoRows:
   279  				// Common case. Message not yet present, store as normal.
   280  				if err := d.tryStoreMessage(ctx, tx, dbm, false); err != nil {
   281  					return err
   282  				}
   283  			case e != nil:
   284  				return e
   285  			case processedTime.Valid && (!failed.Valid || !failed.Bool):
   286  				// Message previously successfully processed, ignore this reprocessing.
   287  			case m.Result != nil && (!processedTime.Valid || !m.Result.Failed):
   288  				mid, err := common.BytesToMessageID(m.MessageId)
   289  				if err != nil {
   290  					return err
   291  				}
   292  				// Message not previously successfully processed, but this try succeeded. Mark as processed.
   293  				if err := d.trySetMessageResult(ctx, tx, mid, m.Result); err != nil {
   294  					return err
   295  				}
   296  			default:
   297  				// The message is already present, but unprocessed/failed, and this
   298  				// processing didn't succeed or is ongoing. Nothing to do.
   299  			}
   300  		}
   301  
   302  		if contact == "" {
   303  			return nil
   304  		}
   305  
   306  		c, err := strconv.ParseUint(string(contact), 16, 64)
   307  		if err != nil {
   308  			e := fmt.Errorf("unable to parse ContactID [%v]: %v", contact, err)
   309  			log.Error(e)
   310  			return e
   311  		}
   312  		for _, id := range ids {
   313  			if _, err := tx.ExecContext(ctx, "INSERT OR IGNORE INTO client_contact_messages(client_contact_id, message_id) VALUES (?, ?)", c, id); err != nil {
   314  				return err
   315  			}
   316  		}
   317  		return nil
   318  	})
   319  }
   320  
   321  func (d *Datastore) tryStoreMessage(ctx context.Context, tx *sql.Tx, dbm *dbMessage, isBroadcast bool) error {
   322  	if dbm.creationTimeSeconds == 0 {
   323  		return errors.New("message CreationTime must be set")
   324  	}
   325  	res, err := tx.ExecContext(ctx, "INSERT OR IGNORE INTO messages("+
   326  		"message_id, "+
   327  		"source_client_id, "+
   328  		"source_service_name, "+
   329  		"source_message_id, "+
   330  		"destination_client_id, "+
   331  		"destination_service_name, "+
   332  		"message_type, "+
   333  		"creation_time_seconds, "+
   334  		"creation_time_nanos, "+
   335  		"processed_time_seconds, "+
   336  		"processed_time_nanos, "+
   337  		"failed,"+
   338  		"failed_reason,"+
   339  		"validation_info,"+
   340  		"annotations) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
   341  		dbm.messageID,
   342  		dbm.sourceClientID,
   343  		dbm.sourceServiceName,
   344  		dbm.sourceMessageID,
   345  		dbm.destinationClientID,
   346  		dbm.destinationServiceName,
   347  		dbm.messageType,
   348  		dbm.creationTimeSeconds,
   349  		dbm.creationTimeNanos,
   350  		dbm.processedTimeSeconds,
   351  		dbm.processedTimeNanos,
   352  		dbm.failed,
   353  		dbm.failedReason,
   354  		dbm.validationInfo,
   355  		dbm.annotations)
   356  	if err != nil {
   357  		return err
   358  	}
   359  	cnt, err := res.RowsAffected()
   360  	if err != nil {
   361  		return err
   362  	}
   363  	inserted := cnt == 1
   364  	if inserted && !dbm.processedTimeSeconds.Valid {
   365  		var due int64
   366  		if dbm.destinationClientID == "" {
   367  			due = toMicro(db.ServerRetryTime(0))
   368  		} else {
   369  			// If this is being created in response to a broadcast, then we about to
   370  			// hand it to the client and should wait before providing in through
   371  			// ClientMessagesForProcessing. Otherwise, we should give it to the client
   372  			// on next contact.
   373  			if isBroadcast {
   374  				due = toMicro(db.ClientRetryTime())
   375  			} else {
   376  				due = toMicro(db.Now())
   377  			}
   378  		}
   379  		_, err = tx.ExecContext(ctx, "INSERT INTO pending_messages("+
   380  			"message_id, "+
   381  			"retry_count, "+
   382  			"scheduled_time, "+
   383  			"data_type_url, "+
   384  			"data_value) VALUES(?, ?, ?, ?, ?)",
   385  			dbm.messageID,
   386  			0,
   387  			due,
   388  			dbm.dataTypeURL,
   389  			dbm.dataValue)
   390  		if err != nil {
   391  			return err
   392  		}
   393  	}
   394  	return nil
   395  }
   396  
   397  func genPlaceholders(num int) string {
   398  	es := make([]string, num)
   399  	for i := range es {
   400  		es[i] = "?"
   401  	}
   402  	return strings.Join(es, ", ")
   403  }
   404  
   405  func (d *Datastore) getPendingMessageRawIds(ctx context.Context, tx *sql.Tx, ids []common.ClientID, offset uint64, limit uint64) ([]string, error) {
   406  	squery := fmt.Sprintf("SELECT "+
   407  		"m.message_id AS message_id "+
   408  		"FROM messages AS m, pending_messages AS pm "+
   409  		"WHERE m.destination_client_id IN (%s) AND m.message_id=pm.message_id "+
   410  		"ORDER BY message_id ",
   411  		genPlaceholders((len(ids))))
   412  
   413  	if offset != 0 && limit == 0 {
   414  		return nil, fmt.Errorf("if offset is provided, a limit must be provided as well")
   415  	}
   416  
   417  	if limit != 0 {
   418  		squery += " LIMIT ?"
   419  	}
   420  
   421  	if offset != 0 {
   422  		squery += " OFFSET ?"
   423  	}
   424  
   425  	args := make([]any, len(ids), len(ids)+2)
   426  	for i, v := range ids {
   427  		args[i] = v.String()
   428  	}
   429  
   430  	if limit != 0 {
   431  		args = append(args, limit)
   432  	}
   433  
   434  	if offset != 0 {
   435  		args = append(args, offset)
   436  	}
   437  
   438  	idsToProc := make([]string, 0)
   439  
   440  	rs, err := tx.QueryContext(ctx, squery, args...)
   441  	if err != nil {
   442  		return nil, fmt.Errorf("Failed to fetch the list of messages to delete: %v", err)
   443  	}
   444  	defer rs.Close()
   445  	for rs.Next() {
   446  		var id string
   447  		if err := rs.Scan(&id); err != nil {
   448  			return nil, err
   449  		}
   450  		idsToProc = append(idsToProc, id)
   451  	}
   452  
   453  	return idsToProc, nil
   454  }
   455  
   456  func (d *Datastore) GetPendingMessageCount(ctx context.Context, ids []common.ClientID) (uint64, error) {
   457  	var result uint64
   458  
   459  	err := d.runInTx(func(tx *sql.Tx) error {
   460  		squery := fmt.Sprintf("SELECT "+
   461  			"COUNT(*) "+
   462  			"FROM messages AS m, pending_messages AS pm "+
   463  			"WHERE m.destination_client_id IN (%s) AND m.message_id=pm.message_id ",
   464  			genPlaceholders((len(ids))))
   465  
   466  		args := make([]any, len(ids))
   467  		for i, v := range ids {
   468  			args[i] = v.String()
   469  		}
   470  		rs, err := tx.QueryContext(ctx, squery, args...)
   471  		if err != nil {
   472  			return fmt.Errorf("Failed to fetch the pending message count: %v", err)
   473  		}
   474  		defer rs.Close()
   475  		if !rs.Next() {
   476  			return fmt.Errorf("Got empty result")
   477  		}
   478  		err = rs.Scan(&result)
   479  		if err != nil {
   480  			return fmt.Errorf("Failed to scan result: %v", err)
   481  		}
   482  		return nil
   483  	})
   484  
   485  	return result, err
   486  }
   487  
   488  func (d *Datastore) GetPendingMessages(ctx context.Context, ids []common.ClientID, offset uint64, limit uint64, wantData bool) ([]*fspb.Message, error) {
   489  	var res []*fspb.Message
   490  	err := d.runInTx(func(tx *sql.Tx) error {
   491  		messageIdsRaw, err := d.getPendingMessageRawIds(ctx, tx, ids, offset, limit)
   492  		if err != nil {
   493  			return err
   494  		}
   495  		var messageIds []common.MessageID
   496  		for _, idRaw := range messageIdsRaw {
   497  			messageID, err := common.StringToMessageID(idRaw)
   498  			if err != nil {
   499  				return err
   500  			}
   501  			messageIds = append(messageIds, messageID)
   502  		}
   503  		res, err = d.getMessages(ctx, tx, messageIds, wantData)
   504  		return err
   505  	})
   506  	return res, err
   507  }
   508  
   509  func (d *Datastore) DeletePendingMessages(ctx context.Context, ids []common.ClientID) error {
   510  	return d.runInTx(func(tx *sql.Tx) error {
   511  		messageIds, err := d.getPendingMessageRawIds(ctx, tx, ids, 0, 0)
   512  		if err != nil {
   513  			return err
   514  		}
   515  
   516  		idsToProc := make([]any, len(messageIds))
   517  		for i, id := range messageIds {
   518  			idsToProc[i] = id
   519  		}
   520  
   521  		// If there are no messages to be deleted, just bail out.
   522  		if len(idsToProc) == 0 {
   523  			return nil
   524  		}
   525  
   526  		now := db.NowProto()
   527  		ptimeSecs := sql.NullInt64{Valid: true, Int64: now.Seconds}
   528  		ptimeNanoSecs := sql.NullInt64{Valid: true, Int64: int64(now.Nanos)}
   529  		failed := sql.NullBool{Valid: true, Bool: true}
   530  		failedReason := sql.NullString{Valid: true, String: "Removed by admin action."}
   531  
   532  		ps := genPlaceholders(len(idsToProc))
   533  		uquery := fmt.Sprintf("UPDATE messages SET failed=?, failed_reason=?, processed_time_seconds=?, processed_time_nanos=? WHERE message_id IN (%s)", ps)
   534  		_, err = tx.ExecContext(ctx, uquery, append([]any{failed, failedReason, ptimeSecs, ptimeNanoSecs}, idsToProc...)...)
   535  		if err != nil {
   536  			return err
   537  		}
   538  
   539  		dquery := fmt.Sprintf("DELETE FROM pending_messages WHERE message_id IN (%s)", ps)
   540  		_, err = tx.ExecContext(ctx, dquery, idsToProc...)
   541  
   542  		return err
   543  	})
   544  }
   545  
   546  func (d *Datastore) getMessages(ctx context.Context, tx *sql.Tx, ids []common.MessageID, wantData bool) ([]*fspb.Message, error) {
   547  	d.l.Lock()
   548  	defer d.l.Unlock()
   549  	res := make([]*fspb.Message, 0, len(ids))
   550  
   551  	stmt1, err := tx.Prepare("SELECT " +
   552  		"message_id, " +
   553  		"source_client_id, " +
   554  		"source_service_name, " +
   555  		"source_message_id, " +
   556  		"destination_client_id, " +
   557  		"destination_service_name, " +
   558  		"message_type, " +
   559  		"creation_time_seconds, " +
   560  		"creation_time_nanos, " +
   561  		"processed_time_seconds, " +
   562  		"processed_time_nanos, " +
   563  		"validation_info, " +
   564  		"annotations " +
   565  		"FROM messages WHERE message_id=?")
   566  	var stmt2 *sql.Stmt
   567  	if wantData {
   568  		stmt2, err = tx.Prepare("SELECT data_type_url, data_value FROM pending_messages WHERE message_id=?")
   569  		if err != nil {
   570  			return nil, err
   571  		}
   572  	}
   573  	if err != nil {
   574  		return nil, err
   575  	}
   576  	for _, id := range ids {
   577  		row := stmt1.QueryRowContext(ctx, id.String())
   578  		var dbm dbMessage
   579  		err := row.Scan(
   580  			&dbm.messageID,
   581  			&dbm.sourceClientID,
   582  			&dbm.sourceServiceName,
   583  			&dbm.sourceMessageID,
   584  			&dbm.destinationClientID,
   585  			&dbm.destinationServiceName,
   586  			&dbm.messageType,
   587  			&dbm.creationTimeSeconds,
   588  			&dbm.creationTimeNanos,
   589  			&dbm.processedTimeSeconds,
   590  			&dbm.processedTimeNanos,
   591  			&dbm.validationInfo,
   592  			&dbm.annotations)
   593  		if err != nil {
   594  			return nil, err
   595  		}
   596  		if wantData {
   597  			row := stmt2.QueryRowContext(ctx, id.String())
   598  			err := row.Scan(&dbm.dataTypeURL, &dbm.dataValue)
   599  			if err != nil && err != sql.ErrNoRows {
   600  				return nil, err
   601  			}
   602  		}
   603  		m, err := toMessageProto(&dbm)
   604  		if err != nil {
   605  			return nil, err
   606  		}
   607  		res = append(res, m)
   608  	}
   609  
   610  	return res, nil
   611  }
   612  
   613  func (d *Datastore) GetMessages(ctx context.Context, ids []common.MessageID, wantData bool) ([]*fspb.Message, error) {
   614  	var res []*fspb.Message
   615  	err := d.runInTx(func(tx *sql.Tx) error {
   616  		var err error
   617  		res, err = d.getMessages(ctx, tx, ids, wantData)
   618  		return err
   619  	})
   620  	return res, err
   621  }
   622  
   623  func (d *Datastore) GetMessageResult(ctx context.Context, id common.MessageID) (*fspb.MessageResult, error) {
   624  	d.l.Lock()
   625  	defer d.l.Unlock()
   626  
   627  	var ret *fspb.MessageResult
   628  
   629  	err := d.runInTx(func(tx *sql.Tx) error {
   630  		row := tx.QueryRowContext(ctx, "SELECT "+
   631  			"creation_time_seconds, "+
   632  			"creation_time_nanos, "+
   633  			"processed_time_seconds, "+
   634  			"processed_time_nanos, "+
   635  			"failed, "+
   636  			"failed_reason "+
   637  			"FROM messages WHERE message_id=?", id.String())
   638  
   639  		var dbm dbMessage
   640  		if err := row.Scan(
   641  			&dbm.creationTimeSeconds,
   642  			&dbm.creationTimeNanos,
   643  			&dbm.processedTimeSeconds,
   644  			&dbm.processedTimeNanos,
   645  			&dbm.failed,
   646  			&dbm.failedReason,
   647  		); err == sql.ErrNoRows {
   648  			return nil
   649  		} else if err != nil {
   650  			return err
   651  		}
   652  
   653  		ret = toMessageResultProto(&dbm)
   654  		return nil
   655  	})
   656  
   657  	return ret, err
   658  }
   659  
   660  // ClientMessagesForProcessing implements db.MessageStore.
   661  func (d *Datastore) ClientMessagesForProcessing(ctx context.Context, id common.ClientID, lim uint64, serviceLimits map[string]uint64) ([]*fspb.Message, error) {
   662  	if id == (common.ClientID{}) {
   663  		return nil, errors.New("a client is required")
   664  	}
   665  	return d.internalMessagesForProcessing(ctx, id, lim, serviceLimits)
   666  }
   667  
   668  func (d *Datastore) internalMessagesForProcessing(ctx context.Context, id common.ClientID, lim uint64, serviceLimits map[string]uint64) ([]*fspb.Message, error) {
   669  	d.l.Lock()
   670  	defer d.l.Unlock()
   671  
   672  	read := make(map[string]uint64)
   673  
   674  	var res []*fspb.Message
   675  
   676  	if err := d.runInTx(func(tx *sql.Tx) error {
   677  		// As an internal addition to the MessageStore interface, this
   678  		// also gets server messages when id=ClientID{}.
   679  		rs, err := tx.QueryContext(ctx, "SELECT "+
   680  			"m.message_id, "+
   681  			"m.source_client_id, "+
   682  			"m.source_service_name, "+
   683  			"m.source_message_id, "+
   684  			"m.destination_client_id, "+
   685  			"m.destination_service_name, "+
   686  			"m.message_type, "+
   687  			"m.creation_time_seconds, "+
   688  			"m.creation_time_nanos,"+
   689  			"m.validation_info,"+
   690  			"m.annotations,"+
   691  			"pm.retry_count, "+
   692  			"pm.data_type_url, "+
   693  			"pm.data_value "+
   694  			"FROM messages AS m, pending_messages AS pm "+
   695  			"WHERE m.destination_client_id = ? AND m.message_id=pm.message_id AND pm.scheduled_time < ? ",
   696  			toClientIDString(id.Bytes()), toMicro(db.Now()))
   697  		if err != nil {
   698  			return err
   699  		}
   700  		defer rs.Close()
   701  		for rs.Next() {
   702  			var dbm dbMessage
   703  			if err = rs.Scan(
   704  				&dbm.messageID,
   705  				&dbm.sourceClientID,
   706  				&dbm.sourceServiceName,
   707  				&dbm.sourceMessageID,
   708  				&dbm.destinationClientID,
   709  				&dbm.destinationServiceName,
   710  				&dbm.messageType,
   711  				&dbm.creationTimeSeconds,
   712  				&dbm.creationTimeNanos,
   713  				&dbm.validationInfo,
   714  				&dbm.annotations,
   715  				&dbm.retryCount,
   716  				&dbm.dataTypeURL,
   717  				&dbm.dataValue,
   718  			); err != nil {
   719  				return err
   720  			}
   721  			if serviceLimits != nil {
   722  				if read[dbm.destinationServiceName] >= serviceLimits[dbm.destinationServiceName] {
   723  					continue
   724  				} else {
   725  					read[dbm.destinationServiceName]++
   726  				}
   727  			}
   728  			nc := dbm.retryCount + 1
   729  			var due int64
   730  			if dbm.destinationClientID == "" {
   731  				due = toMicro(db.ServerRetryTime(nc))
   732  			} else {
   733  				due = toMicro(db.ClientRetryTime())
   734  			}
   735  			if _, err = tx.ExecContext(ctx, "UPDATE pending_messages SET retry_count=?, scheduled_time=? WHERE message_id=?", nc, due, dbm.messageID); err != nil {
   736  				return err
   737  			}
   738  			m, err := toMessageProto(&dbm)
   739  			if err != nil {
   740  				return err
   741  			}
   742  			res = append(res, m)
   743  			if len(res) >= int(lim) {
   744  				return nil
   745  			}
   746  		}
   747  		return rs.Err()
   748  	}); err != nil {
   749  		return nil, err
   750  	}
   751  	return res, nil
   752  }
   753  
   754  type messageLooper struct {
   755  	d *Datastore
   756  
   757  	mp               db.MessageProcessor
   758  	processingTicker *time.Ticker
   759  	stopCalled       chan struct{}
   760  	loopDone         chan struct{}
   761  }
   762  
   763  func (d *Datastore) RegisterMessageProcessor(mp db.MessageProcessor) {
   764  	if d.looper != nil {
   765  		log.Warning("Attempt to register a second MessageProcessor.")
   766  		d.looper.stop()
   767  	}
   768  	d.looper = &messageLooper{
   769  		d:                d,
   770  		mp:               mp,
   771  		processingTicker: time.NewTicker(300 * time.Millisecond),
   772  		stopCalled:       make(chan struct{}),
   773  		loopDone:         make(chan struct{}),
   774  	}
   775  	go d.looper.messageProcessingLoop()
   776  }
   777  
   778  func (d *Datastore) StopMessageProcessor() {
   779  	if d.looper != nil {
   780  		d.looper.stop()
   781  	}
   782  	d.looper = nil
   783  }
   784  
   785  // messageProcessingLoop reads messages that should be processed on the server
   786  // from the datastore and delivers them to the registered MessageProcessor.
   787  func (l *messageLooper) messageProcessingLoop() {
   788  	defer close(l.loopDone)
   789  	for {
   790  		select {
   791  		case <-l.stopCalled:
   792  			return
   793  		case <-l.processingTicker.C:
   794  			l.processMessages()
   795  		}
   796  	}
   797  }
   798  
   799  func (l *messageLooper) stop() {
   800  	l.processingTicker.Stop()
   801  	close(l.stopCalled)
   802  	<-l.loopDone
   803  }
   804  
   805  func (l *messageLooper) processMessages() {
   806  	for {
   807  		msgs, err := l.d.internalMessagesForProcessing(context.Background(), common.ClientID{}, 5, nil)
   808  		if err != nil {
   809  			if err.Error() == "attempt to write a readonly database" {
   810  				log.Errorf("Failed to read server messages for processing; probably the database was removed: %v", err)
   811  				return
   812  			}
   813  
   814  			log.Errorf("Failed to read server messages for processing: %v", err)
   815  			continue
   816  		}
   817  		l.mp.ProcessMessages(msgs)
   818  		if len(msgs) == 0 {
   819  			return
   820  		}
   821  	}
   822  }