github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/dbtesting/messagestore_suite.go (about)

     1  package dbtesting
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	log "github.com/golang/glog"
    10  	"github.com/google/fleetspeak/fleetspeak/src/common"
    11  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    12  	"github.com/google/fleetspeak/fleetspeak/src/server/sertesting"
    13  	"google.golang.org/protobuf/proto"
    14  
    15  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    16  	anypb "google.golang.org/protobuf/types/known/anypb"
    17  	tspb "google.golang.org/protobuf/types/known/timestamppb"
    18  )
    19  
    20  type idPair struct {
    21  	cid common.ClientID
    22  	mid common.MessageID
    23  }
    24  
    25  func storeGetMessagesTest(t *testing.T, ms db.Store) {
    26  	fakeTime := sertesting.FakeNow(100000)
    27  	defer fakeTime.Revert()
    28  
    29  	ctx := context.Background()
    30  
    31  	if err := ms.AddClient(ctx, clientID, &db.ClientData{Key: []byte("test key")}); err != nil {
    32  		t.Fatalf("AddClient [%v] failed: %v", clientID, err)
    33  	}
    34  
    35  	msgs := []*fspb.Message{
    36  		// A typical message to a client.
    37  		{
    38  			MessageId: []byte("01234567890123456789012345678902"),
    39  			Source: &fspb.Address{
    40  				ServiceName: "TestServiceName",
    41  			},
    42  			Destination: &fspb.Address{
    43  				ClientId:    clientID.Bytes(),
    44  				ServiceName: "TestServiceName",
    45  			},
    46  			MessageType:  "Test message type 2",
    47  			CreationTime: &tspb.Timestamp{Seconds: 42},
    48  			Data: &anypb.Any{
    49  				TypeUrl: "test data proto urn 2",
    50  				Value:   []byte("Test data proto 2")},
    51  		},
    52  		// Message with annotations.
    53  		{
    54  			MessageId: []byte("11234567890123456789012345678903"),
    55  			Source: &fspb.Address{
    56  				ServiceName: "TestServiceName",
    57  			},
    58  			Destination: &fspb.Address{
    59  				ClientId:    clientID.Bytes(),
    60  				ServiceName: "TestServiceName",
    61  			},
    62  			MessageType:  "Test message type 2",
    63  			CreationTime: &tspb.Timestamp{Seconds: 42},
    64  			Data: &anypb.Any{
    65  				TypeUrl: "test data proto urn 2",
    66  				Value:   []byte("Test data proto 2"),
    67  			},
    68  			Annotations: &fspb.Annotations{
    69  				Entries: []*fspb.Annotations_Entry{
    70  					{Key: "session_id", Value: "123"},
    71  					{Key: "request_id", Value: "1"},
    72  				},
    73  			},
    74  		},
    75  	}
    76  	// duplicate calls to StoreMessages shouldn't fail.
    77  	if err := ms.StoreMessages(ctx, msgs, ""); err != nil {
    78  		t.Fatal(err)
    79  	}
    80  	if err := ms.StoreMessages(ctx, msgs, ""); err != nil {
    81  		t.Error(err)
    82  	}
    83  
    84  	var idPairs []idPair
    85  	var messageIDs []common.MessageID
    86  	idMap := make(map[common.MessageID]*fspb.Message)
    87  	for _, m := range msgs {
    88  		mid, err := common.BytesToMessageID(m.MessageId)
    89  		if err != nil {
    90  			t.Fatal(err)
    91  		}
    92  		cid, err := common.BytesToClientID(m.Destination.ClientId)
    93  		if err != nil {
    94  			t.Fatal(err)
    95  		}
    96  		messageIDs = append(messageIDs, mid)
    97  		idPairs = append(idPairs, idPair{cid, mid})
    98  		idMap[mid] = m
    99  	}
   100  	msgsRead, err := ms.GetMessages(ctx, messageIDs, true)
   101  	if err != nil {
   102  		t.Fatal(err)
   103  	}
   104  	if len(msgsRead) != len(msgs) {
   105  		t.Fatalf("Expected to read %v messages, got %v", len(messageIDs), msgsRead)
   106  	}
   107  
   108  	for _, m := range msgsRead {
   109  		id, err := common.BytesToMessageID(m.MessageId)
   110  		if err != nil {
   111  			t.Fatal(err)
   112  		}
   113  		if !proto.Equal(idMap[id], m) {
   114  			t.Errorf("Got %v but want %v when reading message id %v", m, idMap[id], id)
   115  		}
   116  	}
   117  
   118  	stat, err := ms.GetMessageResult(ctx, messageIDs[0])
   119  	if err != nil {
   120  		t.Errorf("unexpected error while retrieving message status: %v", err)
   121  	} else {
   122  		if stat != nil {
   123  			t.Errorf("GetMessageResult of unprocessed message: want [nil] got [%v]", stat)
   124  		}
   125  	}
   126  
   127  	fakeTime.SetSeconds(84)
   128  
   129  	for _, i := range idPairs {
   130  		if err := ms.SetMessageResult(ctx, i.cid, i.mid, &fspb.MessageResult{ProcessedTime: db.NowProto()}); err != nil {
   131  			t.Errorf("Unable to mark message %v as processed: %v", i, err)
   132  		}
   133  	}
   134  	msgsRead, err = ms.GetMessages(ctx, messageIDs, true)
   135  	if err != nil {
   136  		t.Fatal(err)
   137  	}
   138  	if len(msgsRead) != len(msgs) {
   139  		t.Fatalf("Expected to read %v messages, got %v", len(messageIDs), msgsRead)
   140  	}
   141  
   142  	for _, m := range msgsRead {
   143  		id, err := common.BytesToMessageID(m.MessageId)
   144  		if err != nil {
   145  			t.Fatal(err)
   146  		}
   147  		want := idMap[id]
   148  		want.Data = nil
   149  		want.Result = &fspb.MessageResult{
   150  			ProcessedTime: &tspb.Timestamp{Seconds: 84},
   151  		}
   152  		if !proto.Equal(want, m) {
   153  			t.Errorf("Got %v but want %v when reading message id %v", m, want, id)
   154  		}
   155  	}
   156  
   157  	stat, err = ms.GetMessageResult(ctx, messageIDs[0])
   158  	if err != nil {
   159  		t.Errorf("unexpected error while retrieving message result: %v", err)
   160  	} else {
   161  		want := &fspb.MessageResult{
   162  			ProcessedTime: &tspb.Timestamp{Seconds: 84},
   163  			Failed:        false,
   164  			FailedReason:  "",
   165  		}
   166  		if !proto.Equal(want, stat) {
   167  			t.Errorf("GetMessageResult error: want [%v] got [%v]", want, stat)
   168  		}
   169  	}
   170  }
   171  
   172  func clientMessagesForProcessingTest(t *testing.T, ms db.Store) {
   173  	fakeTime := sertesting.FakeNow(100000)
   174  	defer fakeTime.Revert()
   175  
   176  	ctx := context.Background()
   177  
   178  	if err := ms.AddClient(ctx, clientID, &db.ClientData{Key: []byte("test key")}); err != nil {
   179  		t.Fatalf("AddClient [%v] failed: %v", clientID, err)
   180  	}
   181  
   182  	mid2 := common.MakeMessageID(
   183  		&fspb.Address{
   184  			ClientId:    clientID.Bytes(),
   185  			ServiceName: "TestServiceName",
   186  		}, []byte("omid 2"))
   187  
   188  	stored := fspb.Message{
   189  		MessageId: mid2.Bytes(),
   190  		Source: &fspb.Address{
   191  			ServiceName: "TestServiceName",
   192  		},
   193  		SourceMessageId: []byte("omid 2"),
   194  		Destination: &fspb.Address{
   195  			ClientId:    clientID.Bytes(),
   196  			ServiceName: "TestServiceName",
   197  		},
   198  		CreationTime:   db.NowProto(),
   199  		ValidationInfo: &fspb.ValidationInfo{Tags: map[string]string{"result": "Valid"}},
   200  	}
   201  	err := ms.StoreMessages(ctx, []*fspb.Message{&stored}, "")
   202  	if err != nil {
   203  		t.Errorf("StoreMessages returned error: %v", err)
   204  	}
   205  
   206  	fakeTime.SetSeconds(300000)
   207  
   208  	m, err := ms.ClientMessagesForProcessing(ctx, clientID, 10, nil)
   209  	if err != nil {
   210  		t.Fatalf("ClientMessagesForProcessing(%v) returned error: %v", clientID, err)
   211  	}
   212  	log.Infof("Retrieved: %v", m)
   213  	if len(m) != 1 {
   214  		t.Errorf("ClientMessageForProcessing(%v) didn't return one message: %v", clientID, m)
   215  	}
   216  	if !proto.Equal(m[0], &stored) {
   217  		t.Errorf("ClientMessageForProcessing(%v) unexpected result, want: %v, got: %v", clientID, &stored, m[0])
   218  	}
   219  }
   220  
   221  func clientMessagesForProcessingLimitTest(t *testing.T, ms db.Store) {
   222  	ctx := context.Background()
   223  
   224  	if err := ms.AddClient(ctx, clientID, &db.ClientData{Key: []byte("test key")}); err != nil {
   225  		t.Fatalf("AddClient [%v] failed: %v", clientID, err)
   226  	}
   227  
   228  	// Create a backlog for 2 different services.
   229  	var toStore []*fspb.Message
   230  	for i := range 100 {
   231  		mid1 := common.MakeMessageID(
   232  			&fspb.Address{
   233  				ClientId:    clientID.Bytes(),
   234  				ServiceName: "TestService1",
   235  			}, []byte(fmt.Sprintf("omit: %d", i)))
   236  		toStore = append(toStore, &fspb.Message{
   237  			MessageId: mid1.Bytes(),
   238  			Source: &fspb.Address{
   239  				ServiceName: "TestService1",
   240  			},
   241  			Destination: &fspb.Address{
   242  				ClientId:    clientID.Bytes(),
   243  				ServiceName: "TestService1",
   244  			},
   245  			CreationTime: db.NowProto(),
   246  		})
   247  		mid2 := common.MakeMessageID(
   248  			&fspb.Address{
   249  				ClientId:    clientID.Bytes(),
   250  				ServiceName: "TestService2",
   251  			}, []byte(fmt.Sprintf("omit: %d", i)))
   252  		toStore = append(toStore, &fspb.Message{
   253  			MessageId: mid2.Bytes(),
   254  			Source: &fspb.Address{
   255  				ServiceName: "TestService2",
   256  			},
   257  			Destination: &fspb.Address{
   258  				ClientId:    clientID.Bytes(),
   259  				ServiceName: "TestService2",
   260  			},
   261  			CreationTime: db.NowProto(),
   262  		})
   263  	}
   264  
   265  	if err := ms.StoreMessages(ctx, toStore, ""); err != nil {
   266  		t.Errorf("StoreMessages returned error: %v", err)
   267  		return
   268  	}
   269  	for _, s := range []string{"TestService1", "TestService2"} {
   270  		m, err := ms.ClientMessagesForProcessing(ctx, clientID, 10, map[string]uint64{s: 5})
   271  		if err != nil {
   272  			t.Errorf("ClientMessagesForProcessing(10, %s=5) returned unexpected error: %v", s, err)
   273  			continue
   274  		}
   275  		if len(m) != 5 {
   276  			t.Errorf("ClientMessagesForProcessing(10, %s=5) returned %d messages, but expected 5.", s, len(m))
   277  		}
   278  		for _, v := range m {
   279  			if v.Destination.ServiceName != s {
   280  				t.Errorf("ClientMessagesForProcessing(10, %s=5) returned message with ServiceName=%s, but expected %s.", s, v.Destination.ServiceName, s)
   281  			}
   282  		}
   283  	}
   284  	m, err := ms.ClientMessagesForProcessing(ctx, clientID, 10, nil)
   285  	if err != nil {
   286  		t.Errorf("ClientMessagesForProcessing(10, nil) returned unexpected error: %v", err)
   287  		return
   288  	}
   289  	if len(m) != 10 {
   290  		t.Errorf("ClientMessagesForProcessing(10, nil) returned %d messages, but expected 5.", len(m))
   291  	}
   292  
   293  	// Get all messages remaining for processing, with limit.
   294  
   295  	for range 20 {
   296  		_, err := ms.ClientMessagesForProcessing(ctx, clientID, 10, nil)
   297  		if err != nil {
   298  			t.Fatalf("ClientMessagesForProcessing(10, nil) returned unexpected error: %v", err)
   299  		}
   300  	}
   301  
   302  	// There should be no messages left for processing.
   303  
   304  	m, err = ms.ClientMessagesForProcessing(ctx, clientID, 10, nil)
   305  	if err != nil {
   306  		t.Fatalf("ClientMessagesForProcessing(10, nil) returned unexpected error: %v", err)
   307  	}
   308  	if len(m) != 0 {
   309  		t.Fatalf("Exepcted 0 messages, got %v.", len(m))
   310  	}
   311  }
   312  
   313  func checkMessageResults(t *testing.T, ms db.Store, statuses map[common.MessageID]*fspb.MessageResult) {
   314  	for id, want := range statuses {
   315  		got, err := ms.GetMessageResult(context.Background(), id)
   316  		if err != nil {
   317  			t.Errorf("GetMessageResult(%v) returned error: %v", id, err)
   318  		} else {
   319  			if !proto.Equal(got, want) {
   320  				t.Errorf("GetMessageResult(%v)=[%v], but want [%v]", id, got, want)
   321  			}
   322  		}
   323  	}
   324  }
   325  
   326  func storeMessagesTest(t *testing.T, ms db.Store) {
   327  	fakeTime := sertesting.FakeNow(43)
   328  	defer fakeTime.Revert()
   329  
   330  	ctx := context.Background()
   331  
   332  	if err := ms.AddClient(ctx, clientID, &db.ClientData{Key: []byte("test key")}); err != nil {
   333  		t.Fatalf("AddClient [%v] failed: %v", clientID, err)
   334  	}
   335  	contact, err := ms.RecordClientContact(ctx, db.ContactData{
   336  		ClientID:      clientID,
   337  		NonceSent:     42,
   338  		NonceReceived: 43,
   339  		Addr:          "127.0.0.1"})
   340  	if err != nil {
   341  		t.Fatalf("RecordClientContact failed: %v", err)
   342  	}
   343  
   344  	// Create one message in each obvious state - new, processed, errored:
   345  	newID, _ := common.BytesToMessageID([]byte("01234567890123456789012345678906"))
   346  	processedID, _ := common.BytesToMessageID([]byte("01234567890123456789012345678907"))
   347  	erroredID, _ := common.BytesToMessageID([]byte("01234567890123456789012345678908"))
   348  
   349  	msgs := []*fspb.Message{
   350  		{
   351  			MessageId: newID.Bytes(),
   352  			Source: &fspb.Address{
   353  				ClientId:    clientID.Bytes(),
   354  				ServiceName: "TestServiceName",
   355  			},
   356  			Destination: &fspb.Address{
   357  				ServiceName: "TestServiceName",
   358  			},
   359  			MessageType:  "Test message type",
   360  			CreationTime: &tspb.Timestamp{Seconds: 42},
   361  			Data: &anypb.Any{
   362  				TypeUrl: "test data proto urn 2",
   363  				Value:   []byte("Test data proto 2")},
   364  		},
   365  		{
   366  			MessageId: processedID.Bytes(),
   367  			Source: &fspb.Address{
   368  				ClientId:    clientID.Bytes(),
   369  				ServiceName: "TestServiceName",
   370  			},
   371  			Destination: &fspb.Address{
   372  				ServiceName: "TestServiceName",
   373  			},
   374  			MessageType:  "Test message type",
   375  			CreationTime: &tspb.Timestamp{Seconds: 42},
   376  			Result: &fspb.MessageResult{
   377  				ProcessedTime: &tspb.Timestamp{Seconds: 42, Nanos: 20000},
   378  			},
   379  		},
   380  		// New message, will become errored.
   381  		{
   382  			MessageId: erroredID.Bytes(),
   383  			Source: &fspb.Address{
   384  				ClientId:    clientID.Bytes(),
   385  				ServiceName: "TestServiceName",
   386  			},
   387  			Destination: &fspb.Address{
   388  				ServiceName: "TestServiceName",
   389  			},
   390  			MessageType:  "Test message type",
   391  			CreationTime: &tspb.Timestamp{Seconds: 42},
   392  			Result: &fspb.MessageResult{
   393  				ProcessedTime: &tspb.Timestamp{Seconds: 42},
   394  				Failed:        true,
   395  				FailedReason:  "broken test message",
   396  			},
   397  		},
   398  	}
   399  	if err := ms.StoreMessages(ctx, msgs, contact); err != nil {
   400  		t.Fatal(err)
   401  	}
   402  	checkMessageResults(t, ms,
   403  		map[common.MessageID]*fspb.MessageResult{
   404  			newID: nil,
   405  			processedID: {
   406  				ProcessedTime: &tspb.Timestamp{Seconds: 42, Nanos: 20000}},
   407  			erroredID: {
   408  				ProcessedTime: &tspb.Timestamp{Seconds: 42},
   409  				Failed:        true,
   410  				FailedReason:  "broken test message",
   411  			},
   412  		})
   413  
   414  	// StoreMessages again, modeling that they were all resent, and that this time
   415  	// all processing completed.
   416  	for _, m := range msgs {
   417  		m.CreationTime = &tspb.Timestamp{Seconds: 52}
   418  		m.Result = &fspb.MessageResult{ProcessedTime: &tspb.Timestamp{Seconds: 52}}
   419  		m.Data = nil
   420  	}
   421  	if err := ms.StoreMessages(ctx, msgs, contact); err != nil {
   422  		t.Fatal(err)
   423  	}
   424  
   425  	checkMessageResults(t, ms,
   426  		map[common.MessageID]*fspb.MessageResult{
   427  			newID: {
   428  				ProcessedTime: &tspb.Timestamp{Seconds: 52}},
   429  			processedID: {
   430  				ProcessedTime: &tspb.Timestamp{Seconds: 52}},
   431  			erroredID: {
   432  				ProcessedTime: &tspb.Timestamp{Seconds: 52}},
   433  		})
   434  }
   435  
   436  func pendingMessagesTest(t *testing.T, ms db.Store) {
   437  	fakeTime := sertesting.FakeNow(43)
   438  	defer fakeTime.Revert()
   439  
   440  	ctx := context.Background()
   441  
   442  	if err := ms.AddClient(ctx, clientID, &db.ClientData{Key: []byte("test key")}); err != nil {
   443  		t.Fatalf("AddClient [%v] failed: %v", clientID, err)
   444  	}
   445  	contact, err := ms.RecordClientContact(ctx, db.ContactData{
   446  		ClientID:      clientID,
   447  		NonceSent:     42,
   448  		NonceReceived: 43,
   449  		Addr:          "127.0.0.1"})
   450  	if err != nil {
   451  		t.Fatalf("RecordClientContact failed: %v", err)
   452  	}
   453  
   454  	newID0, _ := common.BytesToMessageID([]byte("91234567890123456789012345678900"))
   455  	newID1, _ := common.BytesToMessageID([]byte("91234567890123456789012345678901"))
   456  	newID2, _ := common.BytesToMessageID([]byte("91234567890123456789012345678902"))
   457  	newID3, _ := common.BytesToMessageID([]byte("91234567890123456789012345678903"))
   458  	newID4, _ := common.BytesToMessageID([]byte("91234567890123456789012345678904"))
   459  
   460  	ids := []common.MessageID{
   461  		newID0,
   462  		newID1,
   463  		newID2,
   464  		newID3,
   465  		newID4,
   466  	}
   467  
   468  	msgs := []*fspb.Message{
   469  		{
   470  			MessageId: newID0.Bytes(),
   471  			Source: &fspb.Address{
   472  				ServiceName: "TestSource",
   473  			},
   474  			Destination: &fspb.Address{
   475  				ClientId:    clientID.Bytes(),
   476  				ServiceName: "TestServiceName",
   477  			},
   478  			MessageType:  "Test message type",
   479  			CreationTime: &tspb.Timestamp{Seconds: 42},
   480  			Data: &anypb.Any{
   481  				TypeUrl: "test data proto urn 0",
   482  				Value:   []byte("Test data proto 0")},
   483  		},
   484  		{
   485  			MessageId: newID1.Bytes(),
   486  			Source: &fspb.Address{
   487  				ServiceName: "TestSource",
   488  			},
   489  			Destination: &fspb.Address{
   490  				ClientId:    clientID.Bytes(),
   491  				ServiceName: "TestServiceName",
   492  			},
   493  			MessageType:  "Test message type",
   494  			CreationTime: &tspb.Timestamp{Seconds: 1},
   495  			Data: &anypb.Any{
   496  				TypeUrl: "test data proto urn 1",
   497  				Value:   []byte("Test data proto 1")},
   498  		},
   499  		{
   500  			MessageId: newID2.Bytes(),
   501  			Source: &fspb.Address{
   502  				ServiceName: "TestSource",
   503  			},
   504  			Destination: &fspb.Address{
   505  				ClientId:    clientID.Bytes(),
   506  				ServiceName: "TestServiceName",
   507  			},
   508  			MessageType:  "Test message type",
   509  			CreationTime: &tspb.Timestamp{Seconds: 2},
   510  			Data: &anypb.Any{
   511  				TypeUrl: "test data proto urn 2",
   512  				Value:   []byte("Test data proto 2")},
   513  		},
   514  		{
   515  			MessageId: newID3.Bytes(),
   516  			Source: &fspb.Address{
   517  				ServiceName: "TestSource",
   518  			},
   519  			Destination: &fspb.Address{
   520  				ClientId:    clientID.Bytes(),
   521  				ServiceName: "TestServiceName",
   522  			},
   523  			MessageType:  "Test message type",
   524  			CreationTime: &tspb.Timestamp{Seconds: 3},
   525  			Data: &anypb.Any{
   526  				TypeUrl: "test data proto urn 3",
   527  				Value:   []byte("Test data proto 3")},
   528  		},
   529  		{
   530  			MessageId: newID4.Bytes(),
   531  			Source: &fspb.Address{
   532  				ServiceName: "TestSource",
   533  			},
   534  			Destination: &fspb.Address{
   535  				ClientId:    clientID.Bytes(),
   536  				ServiceName: "TestServiceName",
   537  			},
   538  			MessageType:  "Test message type",
   539  			CreationTime: &tspb.Timestamp{Seconds: 4},
   540  			Data: &anypb.Any{
   541  				TypeUrl: "test data proto urn 4",
   542  				Value:   []byte("Test data proto 4")},
   543  		},
   544  	}
   545  	if err := ms.StoreMessages(ctx, msgs, contact); err != nil {
   546  		t.Fatal(err)
   547  	}
   548  
   549  	mc, err := ms.GetMessages(ctx, ids, false)
   550  	if err != nil {
   551  		t.Fatal(err)
   552  	}
   553  	if len(mc) != len(ids) {
   554  		t.Fatalf("Written message should be present in the store, not %v", len(mc))
   555  	}
   556  
   557  	t.Run("GetPendingMessageCount", func(t *testing.T) {
   558  		count, err := ms.GetPendingMessageCount(ctx, []common.ClientID{clientID})
   559  		if err != nil {
   560  			t.Fatal(err)
   561  		}
   562  		if count != uint64(len(msgs)) {
   563  			t.Fatalf("Bad pending messages count. Expected: %v. Got: %v.", len(msgs), count)
   564  		}
   565  	})
   566  
   567  	t.Run("GetPendingMessages/wantData=true", func(t *testing.T) {
   568  		pendingMsgs, err := ms.GetPendingMessages(ctx, []common.ClientID{clientID}, 0, 0, true)
   569  		if err != nil {
   570  			t.Fatal(err)
   571  		}
   572  		if len(pendingMsgs) != len(msgs) {
   573  			t.Fatalf("Expected %v pending messages, got %v", len(msgs), len(pendingMsgs))
   574  		}
   575  		for i := range msgs {
   576  			if !proto.Equal(msgs[i], pendingMsgs[i]) {
   577  				t.Fatalf("Expected pending message: [%v]. Got [%v].", msgs[i], pendingMsgs[i])
   578  			}
   579  		}
   580  	})
   581  
   582  	t.Run("GetPendingMessages/wantData=true/offset/limit", func(t *testing.T) {
   583  		pendingMsgs, err := ms.GetPendingMessages(ctx, []common.ClientID{clientID}, 1, 2, true)
   584  		if err != nil {
   585  			t.Fatal(err)
   586  		}
   587  		if len(pendingMsgs) != 2 {
   588  			t.Fatalf("Expected %v pending messages, got %v", 2, len(pendingMsgs))
   589  		}
   590  		for i := range pendingMsgs {
   591  			if !proto.Equal(msgs[1+i], pendingMsgs[i]) {
   592  				t.Fatalf("Expected pending message: [%v]. Got [%v].", msgs[1+i], pendingMsgs[i])
   593  			}
   594  		}
   595  	})
   596  
   597  	t.Run("GetPendingMessages/wantData=true/limit", func(t *testing.T) {
   598  		pendingMsgs, err := ms.GetPendingMessages(ctx, []common.ClientID{clientID}, 0, 1, true)
   599  		if err != nil {
   600  			t.Fatal(err)
   601  		}
   602  		if len(pendingMsgs) != 1 {
   603  			t.Fatalf("Expected %v pending messages, got %v", 1, len(pendingMsgs))
   604  		}
   605  		for i := range pendingMsgs {
   606  			if !proto.Equal(msgs[i], pendingMsgs[i]) {
   607  				t.Fatalf("Expected pending message: [%v]. Got [%v].", msgs[i], pendingMsgs[i])
   608  			}
   609  		}
   610  	})
   611  
   612  	t.Run("GetPendingMessages/wantData=true/offset", func(t *testing.T) {
   613  		_, err := ms.GetPendingMessages(ctx, []common.ClientID{clientID}, 1, 0, true)
   614  		if err == nil {
   615  			t.Fatal("Expected to get error, but got none.")
   616  		}
   617  	})
   618  
   619  	t.Run("GetPendingMessages/wantData=false", func(t *testing.T) {
   620  		pendingMsgs, err := ms.GetPendingMessages(ctx, []common.ClientID{clientID}, 0, 0, false)
   621  		if err != nil {
   622  			t.Fatal(err)
   623  		}
   624  		if len(pendingMsgs) != len(msgs) {
   625  			t.Fatalf("Expected %v pending message, got %v", len(msgs), len(pendingMsgs))
   626  		}
   627  		for i := range msgs {
   628  			expectedMsg := proto.Clone(msgs[i]).(*fspb.Message)
   629  			expectedMsg.Data = nil
   630  			if !proto.Equal(expectedMsg, pendingMsgs[i]) {
   631  				t.Fatalf("Expected pending message: [%v]. Got [%v].", expectedMsg, pendingMsgs[i])
   632  			}
   633  		}
   634  	})
   635  
   636  	t.Run("DeletePendingMessages", func(t *testing.T) {
   637  		if err := ms.DeletePendingMessages(ctx, []common.ClientID{clientID}); err != nil {
   638  			t.Fatal(err)
   639  		}
   640  
   641  		mc, err = ms.ClientMessagesForProcessing(ctx, clientID, 1, nil)
   642  		if err != nil {
   643  			t.Fatal(err)
   644  		}
   645  		if len(mc) != 0 {
   646  			t.Fatalf("No messages for processing were expected, found: %v", mc)
   647  		}
   648  
   649  		for _, id := range ids {
   650  			mr, err := ms.GetMessageResult(ctx, id)
   651  			if err != nil {
   652  				t.Fatal(err)
   653  			}
   654  			if mr == nil {
   655  				t.Fatal("Message result must be in the store after pending message is deleted", mr)
   656  			}
   657  			if !mr.Failed {
   658  				t.Errorf("Expected the message to have failed=true, got: %v", mr.Failed)
   659  			}
   660  			if mr.FailedReason != "Removed by admin action." {
   661  				t.Errorf("Expected the message to have failure reason 'Removed by admin action', got: %v", mr.FailedReason)
   662  			}
   663  		}
   664  	})
   665  }
   666  
   667  type fakeMessageProcessor struct {
   668  	c chan *fspb.Message
   669  }
   670  
   671  func (p *fakeMessageProcessor) ProcessMessages(msgs []*fspb.Message) {
   672  	ctx, c := context.WithTimeout(context.Background(), 5*time.Second)
   673  	defer c()
   674  	for _, m := range msgs {
   675  		select {
   676  		case p.c <- m:
   677  		case <-ctx.Done():
   678  		}
   679  	}
   680  }
   681  
   682  func registerMessageProcessorTest(t *testing.T, ms db.Store) {
   683  	fakeTime := sertesting.FakeNow(100000)
   684  	defer fakeTime.Revert()
   685  
   686  	ctx := context.Background()
   687  	if err := ms.AddClient(ctx, clientID, &db.ClientData{Key: []byte("test key")}); err != nil {
   688  		t.Fatalf("AddClient [%v] failed: %v", clientID, err)
   689  	}
   690  
   691  	p := fakeMessageProcessor{
   692  		c: make(chan *fspb.Message, 1),
   693  	}
   694  	ms.RegisterMessageProcessor(&p)
   695  	defer ms.StopMessageProcessor()
   696  
   697  	msg := &fspb.Message{
   698  		MessageId: []byte("01234567890123456789012345678903"),
   699  		Source: &fspb.Address{
   700  			ClientId:    clientID.Bytes(),
   701  			ServiceName: "TestServiceName",
   702  		},
   703  		SourceMessageId: []byte("01234567"),
   704  		Destination: &fspb.Address{
   705  			ServiceName: "TestServiceName",
   706  		},
   707  		MessageType:  "Test message type 1",
   708  		CreationTime: &tspb.Timestamp{Seconds: 42},
   709  		Data: &anypb.Any{
   710  			TypeUrl: "test data proto urn 1",
   711  			Value:   []byte("Test data proto 1"),
   712  		},
   713  	}
   714  	if err := ms.StoreMessages(ctx, []*fspb.Message{msg}, ""); err != nil {
   715  		t.Fatal(err)
   716  	}
   717  
   718  	// If we advance the clock 120 seconds (> than that added by
   719  	// testRetryPolicy), the message should be processed. See the implementation
   720  	// of ServerRetryTime (retry.go) for details.
   721  	fakeTime.AddSeconds(121)
   722  
   723  	select {
   724  	case <-time.After(20 * time.Second):
   725  		t.Errorf("Did not receive notification of message to process after 20 seconds")
   726  	case m := <-p.c:
   727  		if !proto.Equal(m, msg) {
   728  			t.Errorf("ProcessMessage called with wrong message, got: [%v] want: [%v]", m, msg)
   729  		}
   730  		mid, err := common.BytesToMessageID(msg.MessageId)
   731  		if err != nil {
   732  			t.Fatal(err)
   733  		}
   734  		if err := ms.SetMessageResult(ctx, common.ClientID{}, mid, &fspb.MessageResult{ProcessedTime: db.NowProto()}); err != nil {
   735  			t.Errorf("Unable to mark message as processed: %v", err)
   736  		}
   737  	}
   738  }
   739  
   740  func messageStoreTestSuite(t *testing.T, env DbTestEnv) {
   741  	t.Run("MessageStoreTestSuite", func(t *testing.T) {
   742  		runTestSuite(t, env, map[string]func(*testing.T, db.Store){
   743  			"StoreGetMessagesTest":                 storeGetMessagesTest,
   744  			"StoreMessagesTest":                    storeMessagesTest,
   745  			"PendingMessagesTest":                  pendingMessagesTest,
   746  			"ClientMessagesForProcessingTest":      clientMessagesForProcessingTest,
   747  			"ClientMessagesForProcessingLimitTest": clientMessagesForProcessingLimitTest,
   748  			"RegisterMessageProcessor":             registerMessageProcessorTest,
   749  		})
   750  	})
   751  }