github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/dbtesting/broadcaststore_suite.go (about) 1 package dbtesting 2 3 import ( 4 "context" 5 "testing" 6 "time" 7 8 "github.com/google/fleetspeak/fleetspeak/src/common" 9 "github.com/google/fleetspeak/fleetspeak/src/server/db" 10 "github.com/google/fleetspeak/fleetspeak/src/server/ids" 11 "github.com/google/fleetspeak/fleetspeak/src/server/sertesting" 12 "google.golang.org/protobuf/proto" 13 tspb "google.golang.org/protobuf/types/known/timestamppb" 14 15 fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak" 16 spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server" 17 anypb "google.golang.org/protobuf/types/known/anypb" 18 ) 19 20 type idSet struct { 21 bID ids.BroadcastID 22 aID ids.AllocationID 23 } 24 25 func broadcastStoreTest(t *testing.T, ds db.Store) { 26 fakeTime := sertesting.FakeNow(10000) 27 defer fakeTime.Revert() 28 29 fin1 := sertesting.SetClientRetryTime(func() time.Time { return db.Now().Add(time.Minute) }) 30 defer fin1() 31 fin2 := sertesting.SetServerRetryTime(func(_ uint32) time.Time { return db.Now().Add(time.Minute) }) 32 defer fin2() 33 34 ctx := context.Background() 35 36 var bid []ids.BroadcastID 37 38 for _, s := range []string{"0000000000000000", "0000000000000001", "0000000000000002", "0000000000000003", "0000000000000004"} { 39 b, err := ids.StringToBroadcastID(s) 40 if err != nil { 41 t.Fatalf("BroadcastID(%v) failed: %v", s, err) 42 } 43 bid = append(bid, b) 44 } 45 46 future := tspb.New(time.Unix(200000, 0)) 47 48 b0 := &spb.Broadcast{ 49 BroadcastId: bid[0].Bytes(), 50 Source: &fspb.Address{ServiceName: "testService"}, 51 MessageType: "message type 1", 52 Data: &anypb.Any{ 53 TypeUrl: "message proto name 1", 54 Value: []byte("message data 1"), 55 }, 56 } 57 58 for i, tc := range []struct { 59 br *spb.Broadcast 60 lim uint64 61 }{ 62 { 63 br: b0, 64 lim: 8, 65 }, 66 { 67 br: &spb.Broadcast{ 68 BroadcastId: bid[1].Bytes(), 69 Source: &fspb.Address{ServiceName: "testService"}, 70 ExpirationTime: future}, 71 lim: 8, 72 }, 73 { 74 br: &spb.Broadcast{ 75 BroadcastId: bid[2].Bytes(), 76 Source: &fspb.Address{ServiceName: "testService"}}, 77 lim: 0, // (inactive) 78 }, 79 { 80 br: &spb.Broadcast{ 81 BroadcastId: bid[3].Bytes(), 82 Source: &fspb.Address{ServiceName: "testService"}, 83 ExpirationTime: db.NowProto(), // just expired (inactive) 84 }, 85 lim: 100, 86 }, 87 { 88 br: &spb.Broadcast{ 89 BroadcastId: bid[4].Bytes(), 90 Source: &fspb.Address{ServiceName: "testService"}, 91 }, 92 lim: db.BroadcastUnlimited, 93 }, 94 } { 95 if err := ds.CreateBroadcast(ctx, tc.br, tc.lim); err != nil { 96 t.Fatalf("%v: Unable to CreateBroadcast(%v): %v", i, tc.br, err) 97 } 98 } 99 100 bs, err := ds.ListActiveBroadcasts(ctx) 101 if err != nil { 102 t.Fatal(err) 103 } 104 if len(bs) != 3 { 105 t.Errorf("Expected 2 active broadcasts, got: %v", bs) 106 } 107 108 // Advance the fake time past the expiration of the second broadcast. 109 fakeTime.SetSeconds(200001) 110 bs, err = ds.ListActiveBroadcasts(ctx) 111 if err != nil { 112 t.Fatal(err) 113 } 114 if len(bs) != 2 { 115 t.Errorf("Expected 2 active broadcasts, got: %v", bs) 116 } else { 117 if !(proto.Equal(bs[0].Broadcast, b0) || proto.Equal(bs[1].Broadcast, b0)) { 118 t.Errorf("ListActiveBroadcast=%v but want %v", bs[0], b0) 119 } 120 } 121 122 // Check that allocations are allocated expected number messages to send. 123 var allocs []idSet 124 for i, tc := range []struct { 125 id ids.BroadcastID 126 frac float32 127 want uint64 128 }{ 129 {id: bid[0], frac: 0.5, want: 4}, 130 {id: bid[0], frac: 0.5, want: 2}, 131 {id: bid[0], frac: 0.1, want: 1}, 132 {id: bid[0], frac: 2.0, want: 1}, 133 {id: bid[0], frac: 2.0, want: 0}, 134 {id: bid[4], frac: 0.1, want: db.BroadcastUnlimited}, 135 } { 136 a, err := ds.CreateAllocation(ctx, tc.id, tc.frac, db.Now().Add(5*time.Minute)) 137 if err != nil { 138 t.Fatal(err) 139 } 140 if tc.want == 0 { 141 if a != nil { 142 t.Errorf("%v: Allocation(%v): wanted nil but got: %v", i, tc.id, a) 143 break 144 } 145 continue 146 } 147 if a == nil || a.Limit != tc.want { 148 t.Errorf("%v: Allocation(%v): wanted limit of %v but got: %v", i, tc.id, tc.want, a) 149 } 150 if a != nil { 151 allocs = append(allocs, idSet{tc.id, a.ID}) 152 } 153 } 154 155 var clientID, _ = common.BytesToClientID([]byte{0, 0, 0, 0, 0, 0, 0, 1}) 156 if err := ds.AddClient(ctx, clientID, &db.ClientData{ 157 Key: []byte("a client key"), 158 }); err != nil { 159 t.Fatal(err) 160 } 161 162 for _, ids := range []idSet{allocs[0], allocs[4]} { 163 mid, err := common.RandomMessageID() 164 if err != nil { 165 t.Fatal(err) 166 } 167 if err := ds.SaveBroadcastMessage(ctx, &fspb.Message{ 168 MessageId: mid.Bytes(), 169 Destination: &fspb.Address{ 170 ClientId: clientID.Bytes(), 171 ServiceName: "testService", 172 }, 173 Source: &fspb.Address{ 174 ServiceName: "testService", 175 }, 176 CreationTime: db.NowProto(), 177 }, ids.bID, clientID, ids.aID); err != nil { 178 t.Fatal(err) 179 } 180 } 181 182 // Clean them all up. 183 for _, ids := range allocs { 184 if err := ds.CleanupAllocation(ctx, ids.bID, ids.aID); err != nil { 185 t.Errorf("Unable to cleanup allocation %v: %v", ids, err) 186 } 187 } 188 189 // Fetch the active broadcasts again. 190 bs, err = ds.ListActiveBroadcasts(ctx) 191 if err != nil { 192 t.Fatal(err) 193 } 194 if len(bs) != 2 { 195 t.Errorf("Expected 2 active broadcasts, got: %v", bs) 196 t.FailNow() 197 } 198 199 if bs[0].Sent != 1 || bs[1].Sent != 1 { 200 t.Errorf("Expected broadcasts to show 1 sent message, got: %v", bs) 201 } 202 203 sb, err := ds.ListSentBroadcasts(ctx, clientID) 204 if err != nil { 205 t.Fatal(err) 206 } 207 if len(sb) != 2 { 208 t.Errorf("Expected 2 sent broadcast, got: %v", sb) 209 t.FailNow() 210 } 211 if sb[0] != bid[0] { 212 t.Errorf("Expected sent broadcast to be %v got: %v", bid[0], sb[0]) 213 } 214 } 215 216 func broadcastStoreTestSuite(t *testing.T, env DbTestEnv) { 217 t.Run("BroadcastStoreTestSuite", func(t *testing.T) { 218 runTestSuite(t, env, map[string]func(*testing.T, db.Store){ 219 "BroadcastStoreTest": broadcastStoreTest, 220 }) 221 }) 222 }