github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/dbtesting/broadcaststore_suite.go (about)

     1  package dbtesting
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/google/fleetspeak/fleetspeak/src/common"
     9  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    10  	"github.com/google/fleetspeak/fleetspeak/src/server/ids"
    11  	"github.com/google/fleetspeak/fleetspeak/src/server/sertesting"
    12  	"google.golang.org/protobuf/proto"
    13  	tspb "google.golang.org/protobuf/types/known/timestamppb"
    14  
    15  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    16  	spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server"
    17  	anypb "google.golang.org/protobuf/types/known/anypb"
    18  )
    19  
// idSet pairs a broadcast ID with the ID of an allocation that was created
// for that broadcast, so that the allocation can later be referenced (e.g.
// when saving a broadcast message or cleaning the allocation up).
//
// NOTE: values are built with positional composite literals, so the field
// order below must not change.
type idSet struct {
	// bID identifies the broadcast the allocation belongs to.
	bID ids.BroadcastID
	// aID identifies the allocation itself.
	aID ids.AllocationID
}
    24  
    25  func broadcastStoreTest(t *testing.T, ds db.Store) {
    26  	fakeTime := sertesting.FakeNow(10000)
    27  	defer fakeTime.Revert()
    28  
    29  	fin1 := sertesting.SetClientRetryTime(func() time.Time { return db.Now().Add(time.Minute) })
    30  	defer fin1()
    31  	fin2 := sertesting.SetServerRetryTime(func(_ uint32) time.Time { return db.Now().Add(time.Minute) })
    32  	defer fin2()
    33  
    34  	ctx := context.Background()
    35  
    36  	var bid []ids.BroadcastID
    37  
    38  	for _, s := range []string{"0000000000000000", "0000000000000001", "0000000000000002", "0000000000000003", "0000000000000004"} {
    39  		b, err := ids.StringToBroadcastID(s)
    40  		if err != nil {
    41  			t.Fatalf("BroadcastID(%v) failed: %v", s, err)
    42  		}
    43  		bid = append(bid, b)
    44  	}
    45  
    46  	future := tspb.New(time.Unix(200000, 0))
    47  
    48  	b0 := &spb.Broadcast{
    49  		BroadcastId: bid[0].Bytes(),
    50  		Source:      &fspb.Address{ServiceName: "testService"},
    51  		MessageType: "message type 1",
    52  		Data: &anypb.Any{
    53  			TypeUrl: "message proto name 1",
    54  			Value:   []byte("message data 1"),
    55  		},
    56  	}
    57  
    58  	for i, tc := range []struct {
    59  		br  *spb.Broadcast
    60  		lim uint64
    61  	}{
    62  		{
    63  			br:  b0,
    64  			lim: 8,
    65  		},
    66  		{
    67  			br: &spb.Broadcast{
    68  				BroadcastId:    bid[1].Bytes(),
    69  				Source:         &fspb.Address{ServiceName: "testService"},
    70  				ExpirationTime: future},
    71  			lim: 8,
    72  		},
    73  		{
    74  			br: &spb.Broadcast{
    75  				BroadcastId: bid[2].Bytes(),
    76  				Source:      &fspb.Address{ServiceName: "testService"}},
    77  			lim: 0, // (inactive)
    78  		},
    79  		{
    80  			br: &spb.Broadcast{
    81  				BroadcastId:    bid[3].Bytes(),
    82  				Source:         &fspb.Address{ServiceName: "testService"},
    83  				ExpirationTime: db.NowProto(), // just expired (inactive)
    84  			},
    85  			lim: 100,
    86  		},
    87  		{
    88  			br: &spb.Broadcast{
    89  				BroadcastId: bid[4].Bytes(),
    90  				Source:      &fspb.Address{ServiceName: "testService"},
    91  			},
    92  			lim: db.BroadcastUnlimited,
    93  		},
    94  	} {
    95  		if err := ds.CreateBroadcast(ctx, tc.br, tc.lim); err != nil {
    96  			t.Fatalf("%v: Unable to CreateBroadcast(%v): %v", i, tc.br, err)
    97  		}
    98  	}
    99  
   100  	bs, err := ds.ListActiveBroadcasts(ctx)
   101  	if err != nil {
   102  		t.Fatal(err)
   103  	}
   104  	if len(bs) != 3 {
   105  		t.Errorf("Expected 2 active broadcasts, got: %v", bs)
   106  	}
   107  
   108  	// Advance the fake time past the expiration of the second broadcast.
   109  	fakeTime.SetSeconds(200001)
   110  	bs, err = ds.ListActiveBroadcasts(ctx)
   111  	if err != nil {
   112  		t.Fatal(err)
   113  	}
   114  	if len(bs) != 2 {
   115  		t.Errorf("Expected 2 active broadcasts, got: %v", bs)
   116  	} else {
   117  		if !(proto.Equal(bs[0].Broadcast, b0) || proto.Equal(bs[1].Broadcast, b0)) {
   118  			t.Errorf("ListActiveBroadcast=%v but want %v", bs[0], b0)
   119  		}
   120  	}
   121  
   122  	// Check that allocations are allocated expected number messages to send.
   123  	var allocs []idSet
   124  	for i, tc := range []struct {
   125  		id   ids.BroadcastID
   126  		frac float32
   127  		want uint64
   128  	}{
   129  		{id: bid[0], frac: 0.5, want: 4},
   130  		{id: bid[0], frac: 0.5, want: 2},
   131  		{id: bid[0], frac: 0.1, want: 1},
   132  		{id: bid[0], frac: 2.0, want: 1},
   133  		{id: bid[0], frac: 2.0, want: 0},
   134  		{id: bid[4], frac: 0.1, want: db.BroadcastUnlimited},
   135  	} {
   136  		a, err := ds.CreateAllocation(ctx, tc.id, tc.frac, db.Now().Add(5*time.Minute))
   137  		if err != nil {
   138  			t.Fatal(err)
   139  		}
   140  		if tc.want == 0 {
   141  			if a != nil {
   142  				t.Errorf("%v: Allocation(%v): wanted nil but got: %v", i, tc.id, a)
   143  				break
   144  			}
   145  			continue
   146  		}
   147  		if a == nil || a.Limit != tc.want {
   148  			t.Errorf("%v: Allocation(%v): wanted limit of %v but got: %v", i, tc.id, tc.want, a)
   149  		}
   150  		if a != nil {
   151  			allocs = append(allocs, idSet{tc.id, a.ID})
   152  		}
   153  	}
   154  
   155  	var clientID, _ = common.BytesToClientID([]byte{0, 0, 0, 0, 0, 0, 0, 1})
   156  	if err := ds.AddClient(ctx, clientID, &db.ClientData{
   157  		Key: []byte("a client key"),
   158  	}); err != nil {
   159  		t.Fatal(err)
   160  	}
   161  
   162  	for _, ids := range []idSet{allocs[0], allocs[4]} {
   163  		mid, err := common.RandomMessageID()
   164  		if err != nil {
   165  			t.Fatal(err)
   166  		}
   167  		if err := ds.SaveBroadcastMessage(ctx, &fspb.Message{
   168  			MessageId: mid.Bytes(),
   169  			Destination: &fspb.Address{
   170  				ClientId:    clientID.Bytes(),
   171  				ServiceName: "testService",
   172  			},
   173  			Source: &fspb.Address{
   174  				ServiceName: "testService",
   175  			},
   176  			CreationTime: db.NowProto(),
   177  		}, ids.bID, clientID, ids.aID); err != nil {
   178  			t.Fatal(err)
   179  		}
   180  	}
   181  
   182  	// Clean them all up.
   183  	for _, ids := range allocs {
   184  		if err := ds.CleanupAllocation(ctx, ids.bID, ids.aID); err != nil {
   185  			t.Errorf("Unable to cleanup allocation %v: %v", ids, err)
   186  		}
   187  	}
   188  
   189  	// Fetch the active broadcasts again.
   190  	bs, err = ds.ListActiveBroadcasts(ctx)
   191  	if err != nil {
   192  		t.Fatal(err)
   193  	}
   194  	if len(bs) != 2 {
   195  		t.Errorf("Expected 2 active broadcasts, got: %v", bs)
   196  		t.FailNow()
   197  	}
   198  
   199  	if bs[0].Sent != 1 || bs[1].Sent != 1 {
   200  		t.Errorf("Expected broadcasts to show 1 sent message, got: %v", bs)
   201  	}
   202  
   203  	sb, err := ds.ListSentBroadcasts(ctx, clientID)
   204  	if err != nil {
   205  		t.Fatal(err)
   206  	}
   207  	if len(sb) != 2 {
   208  		t.Errorf("Expected 2 sent broadcast, got: %v", sb)
   209  		t.FailNow()
   210  	}
   211  	if sb[0] != bid[0] {
   212  		t.Errorf("Expected sent broadcast to be %v got: %v", bid[0], sb[0])
   213  	}
   214  }
   215  
   216  func broadcastStoreTestSuite(t *testing.T, env DbTestEnv) {
   217  	t.Run("BroadcastStoreTestSuite", func(t *testing.T) {
   218  		runTestSuite(t, env, map[string]func(*testing.T, db.Store){
   219  			"BroadcastStoreTest": broadcastStoreTest,
   220  		})
   221  	})
   222  }