github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/dbtesting/clientstore_suite.go (about) 1 package dbtesting 2 3 import ( 4 "bytes" 5 "context" 6 "reflect" 7 "sort" 8 "testing" 9 "time" 10 11 "github.com/google/fleetspeak/fleetspeak/src/common" 12 "github.com/google/fleetspeak/fleetspeak/src/server/db" 13 "github.com/google/fleetspeak/fleetspeak/src/server/sertesting" 14 "google.golang.org/protobuf/proto" 15 16 fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak" 17 mpb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak_monitoring" 18 spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server" 19 anypb "google.golang.org/protobuf/types/known/anypb" 20 tspb "google.golang.org/protobuf/types/known/timestamppb" 21 ) 22 23 type labelSorter struct { 24 l []*fspb.Label 25 } 26 27 func (l labelSorter) Sort() { 28 sort.Sort(l) 29 } 30 31 func (l labelSorter) Len() int { 32 return len(l.l) 33 } 34 35 func (l labelSorter) Less(i, j int) bool { 36 switch { 37 case l.l[i].ServiceName < l.l[j].ServiceName: 38 return true 39 case l.l[i].ServiceName > l.l[j].ServiceName: 40 return false 41 } 42 return l.l[i].Label < l.l[j].Label 43 } 44 45 func (l labelSorter) Swap(i, j int) { 46 t := l.l[j] 47 l.l[j] = l.l[i] 48 l.l[i] = t 49 } 50 51 func clientDataEqual(a, b *db.ClientData) bool { 52 switch { 53 case a == nil && b == nil: 54 return true 55 case a != nil && b == nil: 56 return false 57 case a == nil && b != nil: 58 return false 59 case a != nil && b != nil: 60 if !bytes.Equal(a.Key, b.Key) { 61 return false 62 } 63 if len(a.Labels) != len(b.Labels) { 64 return false 65 } 66 labelSorter{a.Labels}.Sort() 67 labelSorter{b.Labels}.Sort() 68 69 l := len(a.Labels) 70 if l != len(b.Labels) { 71 return false 72 } 73 for i := range l { 74 if !proto.Equal(a.Labels[i], b.Labels[i]) { 75 return false 76 } 77 } 78 79 return true 80 } 81 return false 82 } 83 84 func clientStoreTest(t *testing.T, ds db.Store) { 85 fakeTime := sertesting.FakeNow(84) 86 defer fakeTime.Revert() 87 88 fin1 := sertesting.SetClientRetryTime(func() time.Time { return db.Now().Add(time.Minute) }) 89 defer fin1() 90 fin2 := sertesting.SetServerRetryTime(func(_ uint32) time.Time { return db.Now().Add(time.Minute) }) 91 defer fin2() 92 93 ctx := context.Background() 94 key := []byte("A binary client key \x00\xff\x01\xfe") 95 96 for _, tc := range []struct { 97 desc string 98 op func() error 99 cd *db.ClientData 100 wantNotFound bool 101 }{ 102 { 103 desc: "missing client", 104 op: func() error { return nil }, 105 cd: nil, 106 wantNotFound: true, 107 }, 108 { 109 desc: "client added", 110 op: func() error { 111 return ds.AddClient(ctx, clientID, &db.ClientData{ 112 Key: key, 113 Labels: []*fspb.Label{ 114 {ServiceName: "system", Label: "Windows"}, 115 {ServiceName: "system", Label: "client-version-0.01"}}}) 116 }, 117 cd: &db.ClientData{ 118 Key: key, 119 Labels: []*fspb.Label{ 120 {ServiceName: "system", Label: "Windows"}, 121 {ServiceName: "system", Label: "client-version-0.01"}}, 122 }, 123 }, 124 { 125 desc: "label added", 126 op: func() error { 127 return ds.AddClientLabel(ctx, clientID, &fspb.Label{ServiceName: "system", Label: "new label"}) 128 }, 129 cd: &db.ClientData{ 130 Key: key, 131 Labels: []*fspb.Label{ 132 {ServiceName: "system", Label: "new label"}, 133 {ServiceName: "system", Label: "Windows"}, 134 {ServiceName: "system", Label: "client-version-0.01"}}, 135 }, 136 }, 137 { 138 desc: "label removed", 139 op: func() error { 140 return ds.RemoveClientLabel(ctx, clientID, &fspb.Label{ServiceName: "system", Label: "client-version-0.01"}) 141 }, 142 cd: &db.ClientData{ 143 Key: key, 144 Labels: []*fspb.Label{ 145 {ServiceName: "system", Label: "new label"}, 146 {ServiceName: "system", Label: "Windows"}}, 147 }, 148 }, 149 } { 150 err := tc.op() 151 if err != nil { 152 t.Errorf("%s: got unexpected error performing op: %v", tc.desc, err) 153 } 154 cd, err := ds.GetClientData(ctx, clientID) 155 if tc.wantNotFound { 156 if !ds.IsNotFound(err) { 157 t.Errorf("%s: got %v but want not found error after performing op", tc.desc, err) 158 } 159 } else { 160 if !clientDataEqual(cd, tc.cd) { 161 t.Errorf("%s: got %v want %v after performing op", tc.desc, cd, tc.cd) 162 } 163 } 164 } 165 longAddr := "[ABCD:ABCD:ABCD:ABCD:ABCD:ABCD:192.168.123.123]:65535" 166 contactID, err := ds.RecordClientContact(ctx, db.ContactData{ 167 ClientID: clientID, 168 NonceSent: 42, 169 NonceReceived: 54, 170 Addr: longAddr, 171 ClientClock: &tspb.Timestamp{Seconds: 21}, 172 StreamingTo: "fs.servers.somewhere.com", 173 }) 174 if err != nil { 175 t.Errorf("unexpected error for RecordClientContact: %v", err) 176 } 177 178 if err := ds.StoreMessages(ctx, []*fspb.Message{ 179 { 180 MessageId: []byte("01234567890123456789012345678901"), 181 Source: &fspb.Address{ 182 ClientId: clientID.Bytes(), 183 ServiceName: "TestServiceName", 184 }, 185 SourceMessageId: []byte("01234567"), 186 Destination: &fspb.Address{ 187 ServiceName: "TestServiceName", 188 }, 189 MessageType: "Test message type 1", 190 CreationTime: &tspb.Timestamp{Seconds: 42}, 191 Data: &anypb.Any{ 192 TypeUrl: "test data proto urn 1", 193 Value: []byte("Test data proto 1"), 194 }}, 195 }, ""); err != nil { 196 t.Errorf("unexpected error for StoreMessage: %v", err) 197 } 198 199 mid, err := common.BytesToMessageID([]byte("01234567890123456789012345678901")) 200 if err != nil { 201 t.Fatalf("unexpected error for BytesToMessageID: %v", err) 202 } 203 204 if err := ds.LinkMessagesToContact(ctx, contactID, []common.MessageID{mid}); err != nil { 205 t.Errorf("unexpected error linking message to contact: %v", err) 206 } 207 208 clients, err := ds.ListClients(ctx, nil) 209 if err != nil { 210 t.Errorf("unexpected error while listing client ids: %v", err) 211 return 212 } 213 if len(clients) != 1 { 214 t.Errorf("expected ListClients to return 1 entry, got %v", len(clients)) 215 return 216 } 217 got := clients[0] 218 // Some datastores might not respect db.Now for LastContactTime. If it seems 219 // to be a current timestamp (2017-2030) assume it is fine, and adjust to the 220 // expected value. 221 adjustDbTimestamp := func(timestamp *tspb.Timestamp) { 222 if timestamp.Seconds > 1483228800 && timestamp.Seconds < 1893456000 { 223 *timestamp = tspb.Timestamp{Seconds: 84} 224 } 225 } 226 adjustDbTimestamp(got.LastContactTime) 227 want := &spb.Client{ 228 ClientId: clientID.Bytes(), 229 Labels: []*fspb.Label{ 230 { 231 ServiceName: "system", 232 Label: "Windows", 233 }, 234 { 235 ServiceName: "system", 236 Label: "new label", 237 }, 238 }, 239 LastContactTime: &tspb.Timestamp{Seconds: 84}, 240 LastContactStreamingTo: "fs.servers.somewhere.com", 241 LastContactAddress: longAddr, 242 LastClock: &tspb.Timestamp{Seconds: 21}, 243 } 244 245 labelSorter{got.Labels}.Sort() 246 labelSorter{want.Labels}.Sort() 247 248 if !proto.Equal(want, got) { 249 t.Errorf("ListClients error: want [%v] got [%v]", want, got) 250 } 251 252 checkClientContacts := func(t *testing.T, contacts []*spb.ClientContact) { 253 if len(contacts) != 1 { 254 t.Errorf("ListClientContacts returned %d results, expected 1.", len(contacts)) 255 } else { 256 if contacts[0].SentNonce != 42 || contacts[0].ReceivedNonce != 54 { 257 t.Errorf("ListClientContact[0] should return nonces (42, 54), got (%d, %d)", 258 contacts[0].SentNonce, contacts[0].ReceivedNonce) 259 } 260 if contacts[0].ObservedAddress != longAddr { 261 t.Errorf("ListClientContact[0] should return address %s, got %s", 262 longAddr, contacts[0].ObservedAddress) 263 } 264 } 265 } 266 267 t.Run("ListClientContacts", func(t *testing.T) { 268 contacts, err := ds.ListClientContacts(ctx, clientID) 269 if err != nil { 270 t.Errorf("ListClientContacts returned error: %v", err) 271 } 272 checkClientContacts(t, contacts) 273 }) 274 275 t.Run("StreamClientContacts", func(t *testing.T) { 276 var contacts []*spb.ClientContact 277 callback := func(contact *spb.ClientContact) error { 278 contacts = append(contacts, contact) 279 return nil 280 } 281 err := ds.StreamClientContacts(ctx, clientID, callback) 282 if err != nil { 283 t.Errorf("StreamClientContacts returned error: %v", err) 284 } 285 checkClientContacts(t, contacts) 286 }) 287 288 if err := ds.BlacklistClient(ctx, clientID); err != nil { 289 t.Errorf("Error blacklisting client: %v", err) 290 } 291 g, err := ds.GetClientData(ctx, clientID) 292 if err != nil { 293 t.Errorf("Error getting client data after blacklisting: %v", err) 294 } 295 w := &db.ClientData{ 296 Key: key, 297 Labels: []*fspb.Label{ 298 {ServiceName: "system", Label: "Windows"}, 299 {ServiceName: "system", Label: "new label"}}, 300 Blacklisted: true, 301 } 302 if !clientDataEqual(g, w) { 303 t.Errorf("Got %+v want %+v after blacklisting client.", g, w) 304 } 305 } 306 307 func addClientsTest(t *testing.T, ds db.Store) { 308 ctx := context.Background() 309 key := []byte("A binary client key \x00\xff\x01\xfe") 310 311 t.Run("add non-blacklisted client", func(t *testing.T) { 312 if err := ds.AddClient(ctx, clientID, &db.ClientData{ 313 Key: key, 314 }); err != nil { 315 t.Errorf("Can't add client.") 316 } 317 318 got, err := ds.GetClientData(ctx, clientID) 319 if err != nil { 320 t.Errorf("Can't get client data.") 321 } 322 323 expected := &db.ClientData{ 324 Key: key, 325 } 326 if !reflect.DeepEqual(expected, got) { 327 t.Errorf("Expected %v, got %v", expected, got) 328 } 329 }) 330 331 t.Run("add blacklisted client", func(t *testing.T) { 332 if err := ds.AddClient(ctx, clientID2, &db.ClientData{ 333 Key: key, 334 Blacklisted: true, 335 }); err != nil { 336 t.Errorf("Can't add client.") 337 } 338 339 got, err := ds.GetClientData(ctx, clientID2) 340 if err != nil { 341 t.Errorf("Can't get client data.") 342 } 343 344 expected := &db.ClientData{ 345 Key: key, 346 Blacklisted: true, 347 } 348 if !reflect.DeepEqual(expected, got) { 349 t.Errorf("Expected %v, got %v", expected, got) 350 } 351 }) 352 } 353 354 func listClientsTest(t *testing.T, ds db.Store) { 355 ctx := context.Background() 356 357 for _, cid := range [...]common.ClientID{clientID, clientID2, clientID3} { 358 if err := ds.AddClient(ctx, cid, &db.ClientData{Key: []byte("test key")}); err != nil { 359 t.Fatalf("AddClient [%v] failed: %v", clientID, err) 360 } 361 } 362 363 if err := ds.BlacklistClient(ctx, clientID3); err != nil { 364 t.Errorf("Unable to blacklist client: %v", err) 365 } 366 Cases: 367 for _, tc := range []struct { 368 name string 369 ids []common.ClientID 370 want map[common.ClientID]bool 371 wantBlacklisted map[common.ClientID]bool 372 }{ 373 { 374 ids: nil, 375 want: map[common.ClientID]bool{clientID: true, clientID2: true, clientID3: true}, 376 wantBlacklisted: map[common.ClientID]bool{clientID3: true}, 377 }, 378 { 379 ids: []common.ClientID{clientID}, 380 want: map[common.ClientID]bool{clientID: true}, 381 wantBlacklisted: map[common.ClientID]bool{}, 382 }, 383 { 384 ids: []common.ClientID{clientID, clientID2}, 385 want: map[common.ClientID]bool{clientID: true, clientID2: true}, 386 wantBlacklisted: map[common.ClientID]bool{}, 387 }, 388 } { 389 clients, err := ds.ListClients(ctx, tc.ids) 390 if err != nil { 391 t.Errorf("unexpected error while listing client ids [%v]: %v", tc.ids, err) 392 continue Cases 393 } 394 got := make(map[common.ClientID]bool) 395 gotBlacklisted := make(map[common.ClientID]bool) 396 for _, c := range clients { 397 id, err := common.BytesToClientID(c.ClientId) 398 if err != nil { 399 t.Errorf("ListClients(%v) returned invalid client_id: %v", tc.ids, err) 400 } 401 if c.LastContactTime == nil { 402 t.Errorf("ListClients(%v) returned nil LastContactTime.", tc.ids) 403 } 404 got[id] = true 405 if c.Blacklisted { 406 gotBlacklisted[id] = true 407 } 408 } 409 if !reflect.DeepEqual(tc.want, got) { 410 t.Errorf("ListClients(%v) returned unexpected set of clients, want [%v], got[%v]", tc.ids, tc.want, got) 411 } 412 if !reflect.DeepEqual(tc.wantBlacklisted, gotBlacklisted) { 413 t.Errorf("ListClients(%v) returned unexpected set of blacklisted clients, want [%v], got[%v]", tc.ids, tc.wantBlacklisted, gotBlacklisted) 414 } 415 } 416 } 417 418 func streamClientIdsTest(t *testing.T, ds db.Store) { 419 ctx := context.Background() 420 421 clientIds := []common.ClientID{clientID, clientID2, clientID3} 422 contactTimes := []time.Time{} 423 424 for idx, cid := range clientIds { 425 contactTimes = append(contactTimes, db.Now()) 426 if err := ds.AddClient(ctx, cid, &db.ClientData{Key: []byte("test key"), Blacklisted: idx%2 != 0}); err != nil { 427 t.Fatalf("AddClient [%v] failed: %v", clientID, err) 428 } 429 } 430 431 t.Run("Stream all clients", func(t *testing.T) { 432 var result []common.ClientID 433 434 callback := func(id common.ClientID) error { 435 result = append(result, id) 436 return nil 437 } 438 439 err := ds.StreamClientIds(ctx, true, nil, callback) 440 if err != nil { 441 t.Fatalf("StreamClientIds failed: %v", err) 442 } 443 444 sort.Slice(result, func(i int, j int) bool { 445 return bytes.Compare(result[i].Bytes(), result[j].Bytes()) < 0 446 }) 447 448 if !reflect.DeepEqual(result, clientIds) { 449 t.Errorf("StreamClientIds returned unexpected result. Got: [%v]. Want: [%v].", result, clientIds) 450 } 451 }) 452 453 t.Run("Stream non-blacklisted clients only", func(t *testing.T) { 454 var result []common.ClientID 455 456 callback := func(id common.ClientID) error { 457 result = append(result, id) 458 return nil 459 } 460 461 err := ds.StreamClientIds(ctx, false, nil, callback) 462 if err != nil { 463 t.Fatalf("StreamClientIds failed: %v", err) 464 } 465 466 sort.Slice(result, func(i int, j int) bool { 467 return bytes.Compare(result[i].Bytes(), result[j].Bytes()) < 0 468 }) 469 470 expected := []common.ClientID{clientID, clientID3} 471 if !reflect.DeepEqual(result, expected) { 472 t.Errorf("StreamClientIds returned unexpected result. Got: [%v]. Want: [%v].", result, expected) 473 } 474 }) 475 476 t.Run("Stream all clients with time filter", func(t *testing.T) { 477 var result []common.ClientID 478 479 callback := func(id common.ClientID) error { 480 result = append(result, id) 481 return nil 482 } 483 484 err := ds.StreamClientIds(ctx, true, &contactTimes[1], callback) 485 if err != nil { 486 t.Fatalf("StreamClientIds failed: %v", err) 487 } 488 489 sort.Slice(result, func(i int, j int) bool { 490 return bytes.Compare(result[i].Bytes(), result[j].Bytes()) < 0 491 }) 492 493 expected := []common.ClientID{clientID2, clientID3} 494 if !reflect.DeepEqual(result, expected) { 495 t.Errorf("StreamClientIds returned unexpected result. Got: [%v]. Want: [%v].", result, expected) 496 } 497 }) 498 } 499 500 func fetchResourceUsageRecordsTest(t *testing.T, ds db.Store) { 501 ctx := context.Background() 502 key := []byte("Test key") 503 err := ds.AddClient(ctx, clientID, &db.ClientData{ 504 Key: key}) 505 if err != nil { 506 t.Errorf("add client: got unexpected error performing op: %v", err) 507 } 508 509 meanRAM, maxRAM := 190, 200 510 rud := mpb.ResourceUsageData{ 511 Scope: "test-scope", 512 Pid: 1234, 513 ProcessStartTime: &tspb.Timestamp{Seconds: 1234567890, Nanos: 98765}, 514 DataTimestamp: &tspb.Timestamp{Seconds: 1234567891, Nanos: 98765}, 515 ProcessTerminated: true, 516 ResourceUsage: &mpb.AggregatedResourceUsage{ 517 MeanUserCpuRate: 50.0, 518 MaxUserCpuRate: 60.0, 519 MeanSystemCpuRate: 70.0, 520 MaxSystemCpuRate: 80.0, 521 MeanResidentMemory: float64(meanRAM) * 1024 * 1024, 522 MaxResidentMemory: int64(maxRAM) * 1024 * 1024, 523 MeanNumFds: 13.4, 524 MaxNumFds: 42, 525 }, 526 } 527 528 beforeRecordTime := db.Now() 529 oneMinAfterRecordTime := beforeRecordTime.Add(time.Minute) // This isn't exactly 1 minute after record time. 530 beforeRecordTimestamp := tspb.New(beforeRecordTime) 531 if err := beforeRecordTimestamp.CheckValid(); err != nil { 532 t.Fatalf("Invalid time.Time object cannot be converted to tpb.Timestamp: %v", err) 533 } 534 afterRecordTimestamp := tspb.New(oneMinAfterRecordTime) 535 if err := afterRecordTimestamp.CheckValid(); err != nil { 536 t.Fatalf("Invalid time.Time object cannot be converted to tpb.Timestamp: %v", err) 537 } 538 539 err = ds.RecordResourceUsageData(ctx, clientID, &rud) 540 if err != nil { 541 t.Fatalf("Unexpected error when writing client resource-usage data: %v", err) 542 } 543 544 records, err := ds.FetchResourceUsageRecords(ctx, clientID, beforeRecordTimestamp, afterRecordTimestamp) 545 if err != nil { 546 t.Errorf("Unexpected error when trying to fetch resource-usage data for client: %v", err) 547 } 548 if len(records) != 1 { 549 t.Fatalf("Unexpected number of records returned. Want %d, got %v", 1, len(records)) 550 } 551 552 record := records[0] 553 expected := &spb.ClientResourceUsageRecord{ 554 Scope: "test-scope", 555 Pid: 1234, 556 ProcessStartTime: &tspb.Timestamp{Seconds: 1234567890, Nanos: 98765}, 557 ClientTimestamp: &tspb.Timestamp{Seconds: 1234567891, Nanos: 98765}, 558 ServerTimestamp: record.ServerTimestamp, 559 ProcessTerminated: true, 560 MeanUserCpuRate: 50.0, 561 MaxUserCpuRate: 60.0, 562 MeanSystemCpuRate: 70.0, 563 MaxSystemCpuRate: 80.0, 564 MeanResidentMemoryMib: int32(meanRAM), 565 MaxResidentMemoryMib: int32(maxRAM), 566 MeanNumFds: 13, 567 MaxNumFds: 42, 568 } 569 570 if got, want := record, expected; !proto.Equal(got, want) { 571 t.Errorf("Resource-usage record returned is different from what we expect; got:\n%q\nwant:\n%q", got, want) 572 } 573 574 if err := record.ServerTimestamp.CheckValid(); err != nil { 575 t.Fatalf("Invalid tspb.Timestamp object cannot be converted to time.Time object: %v", err) 576 } 577 recordExactTime := record.ServerTimestamp.AsTime() 578 nanoBeforeRecord := recordExactTime.Add(time.Duration(-1) * time.Nanosecond) 579 nanoBeforeRecordTS := tspb.New(nanoBeforeRecord) 580 if err := nanoBeforeRecordTS.CheckValid(); err != nil { 581 t.Fatalf("Invalid time.Time object cannot be converted to tpb.Timestamp: %v", err) 582 } 583 nanoAfterRecord := recordExactTime.Add(time.Nanosecond) 584 nanoAfterRecordTS := tspb.New(nanoAfterRecord) 585 if err := nanoAfterRecordTS.CheckValid(); err != nil { 586 t.Fatalf("Invalid time.Time object cannot be converted to tpb.Timestamp: %v", err) 587 } 588 589 for _, tr := range []struct { 590 desc string 591 startTs *tspb.Timestamp 592 endTs *tspb.Timestamp 593 shouldErr bool 594 recordsExpected int 595 }{ 596 { 597 desc: "record out of time range", 598 startTs: nanoBeforeRecordTS, 599 endTs: record.ServerTimestamp, 600 recordsExpected: 0, 601 }, 602 { 603 desc: "time range invalid", 604 startTs: record.ServerTimestamp, 605 endTs: nanoBeforeRecordTS, 606 shouldErr: true, 607 }, 608 { 609 desc: "record in time range", 610 startTs: record.ServerTimestamp, 611 endTs: nanoAfterRecordTS, 612 recordsExpected: 1, 613 }, 614 } { 615 records, err = ds.FetchResourceUsageRecords(ctx, clientID, tr.startTs, tr.endTs) 616 if tr.shouldErr { 617 if err == nil { 618 t.Errorf("%s: Should have errored when trying to fetch resource-usage data for client as time range is invalid, but didn't error.", tr.desc) 619 } 620 } else { 621 if err != nil { 622 t.Errorf("%s: Unexpected error when trying to fetch resource-usage data for client: %v", tr.desc, err) 623 } 624 if len(records) != tr.recordsExpected { 625 t.Fatalf("%s: Unexpected number of records returned. Want %d, got %v", tr.desc, tr.recordsExpected, len(records)) 626 } 627 } 628 } 629 } 630 631 func clientStoreTestSuite(t *testing.T, env DbTestEnv) { 632 t.Run("ClientStoreTestSuite", func(t *testing.T) { 633 runTestSuite(t, env, map[string]func(*testing.T, db.Store){ 634 "AddClientsTest": addClientsTest, 635 "ClientStoreTest": clientStoreTest, 636 "ListClientsTest": listClientsTest, 637 "StreamClientIdsTest": streamClientIdsTest, 638 "FetchResourceUsageRecordsTest": fetchResourceUsageRecordsTest, 639 }) 640 }) 641 }