github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/servertests/comms_test.go (about) 1 // Copyright 2017 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package servertests_test 16 17 import ( 18 "bytes" 19 "context" 20 "crypto" 21 "crypto/ecdsa" 22 "crypto/elliptic" 23 "crypto/rand" 24 "crypto/rsa" 25 "errors" 26 "net" 27 "strings" 28 "testing" 29 "time" 30 31 "github.com/google/fleetspeak/fleetspeak/src/common" 32 "github.com/google/fleetspeak/fleetspeak/src/common/anypbtest" 33 "github.com/google/fleetspeak/fleetspeak/src/server/db" 34 "github.com/google/fleetspeak/fleetspeak/src/server/internal/services" 35 "github.com/google/fleetspeak/fleetspeak/src/server/sertesting" 36 "github.com/google/fleetspeak/fleetspeak/src/server/service" 37 "github.com/google/fleetspeak/fleetspeak/src/server/testserver" 38 "google.golang.org/protobuf/proto" 39 tspb "google.golang.org/protobuf/types/known/timestamppb" 40 41 fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak" 42 ) 43 44 func TestCommsContext(t *testing.T) { 45 fakeTime := sertesting.FakeNow(50) 46 defer fakeTime.Revert() 47 48 ts := testserver.Make(t, "server", "CommsContext", nil) 49 defer ts.S.Stop() 50 ctx := context.Background() 51 52 // Verify that we can add clients using different types of keys. 53 privateKey1, err := rsa.GenerateKey(rand.Reader, 2048) 54 if err != nil { 55 t.Fatal(err) 56 } 57 privateKey2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 58 if err != nil { 59 t.Fatal(err) 60 } 61 62 // For each client/key we go through a basic lifecyle - add the client 63 // to the system, check for messages for the client, etc. 64 for _, tc := range []struct { 65 name string 66 pub crypto.PublicKey 67 streaming bool 68 }{ 69 { 70 name: "rsa", 71 pub: privateKey1.Public()}, 72 { 73 name: "ecdsa", 74 pub: privateKey2.Public()}, 75 { 76 name: "rsa-streaming", 77 pub: privateKey1.Public(), 78 streaming: true}, 79 { 80 name: "ecdsa-streaming", 81 pub: privateKey2.Public(), 82 streaming: true}, 83 } { 84 ci, cd, _, err := ts.CC.InitializeConnection( 85 ctx, 86 &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 123}, 87 tc.pub, 88 &fspb.WrappedContactData{}, 89 false) 90 if err != nil { 91 t.Fatal(err) 92 } 93 id, err := common.MakeClientID(tc.pub) 94 if err != nil { 95 t.Fatal(err) 96 } 97 if ci.Addr.Network() != "tcp" || ci.Addr.String() != "127.0.0.1:123" { 98 t.Errorf("%s: InitializeConnection returned ci.Addr of [%s,%v], but expected [tcp,127.0.0.1:123]", tc.name, ci.Addr.Network(), ci.Addr) 99 } 100 if ci.Client.ID != id { 101 t.Errorf("%s: InitializeConnection returned client ID of %v, but expected %v", tc.name, ci.Client.ID, id) 102 } 103 if ci.Client.Key == nil { 104 t.Errorf("%s: InitializeConnection returned empty ci.Client.Key", tc.name) 105 } 106 if ci.ContactID == "" { 107 t.Errorf("%s: InitializeConnection returned empty ci.ContactID", tc.name) 108 } 109 if ci.NonceSent == 0 { 110 t.Errorf("%s: InitializeConnection returned 0 NonceSent", tc.name) 111 } 112 if len(cd.Messages) != 0 { 113 t.Fatalf("%s: Expected no messages, got: %v", tc.name, cd.Messages) 114 } 115 116 // If a client does provide messages, they should end up in the datastore. 117 fakeTime.SetSeconds(1234) 118 cd = &fspb.ContactData{ 119 SequencingNonce: 5, 120 Messages: []*fspb.Message{ 121 { 122 Source: &fspb.Address{ 123 ClientId: id.Bytes(), 124 ServiceName: "TestService", 125 }, 126 Destination: &fspb.Address{ 127 ServiceName: "TestService", 128 }, 129 SourceMessageId: []byte("AAABBBCCC"), 130 MessageType: "TestMessage", 131 }, 132 }, 133 } 134 bcd, err := proto.Marshal(cd) 135 if err != nil { 136 t.Fatalf("%s: Unable to marshal contact data: %v", tc.name, err) 137 } 138 if tc.streaming { 139 if err := ts.CC.HandleMessagesFromClient( 140 ctx, 141 ci, 142 &fspb.WrappedContactData{ContactData: bcd}); err != nil { 143 t.Fatal(err) 144 } 145 cd, _, err := ts.CC.GetMessagesForClient(ctx, ci) 146 if err != nil { 147 t.Fatal(err) 148 } 149 if cd != nil { 150 t.Errorf("%s: Expected nil ContactData, got: %v", tc.name, cd) 151 } 152 } else { 153 if ci, cd, _, err = ts.CC.InitializeConnection( 154 ctx, 155 &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 123}, 156 tc.pub, 157 &fspb.WrappedContactData{ContactData: bcd}, 158 false); err != nil { 159 t.Fatal(err) 160 } 161 } 162 fakeTime.SetSeconds(3000) 163 164 mid := common.MakeMessageID( 165 &fspb.Address{ 166 ClientId: id.Bytes(), 167 ServiceName: "TestService", 168 }, 169 []byte("AAABBBCCC"), 170 ) 171 msgs, err := ts.DS.GetMessages(ctx, []common.MessageID{mid}, false) 172 173 if err != nil { 174 t.Fatal(err) 175 } 176 if len(msgs) != 1 { 177 t.Fatalf("Expected 1 message, got: %v", msgs) 178 } 179 want := &fspb.Message{ 180 MessageId: mid.Bytes(), 181 Source: &fspb.Address{ 182 ClientId: id.Bytes(), 183 ServiceName: "TestService", 184 }, 185 Destination: &fspb.Address{ 186 ServiceName: "TestService", 187 }, 188 SourceMessageId: []byte("AAABBBCCC"), 189 MessageType: "TestMessage", 190 CreationTime: &tspb.Timestamp{Seconds: 1234}, 191 } 192 msgs[0].Result = nil 193 if !proto.Equal(msgs[0], want) { 194 t.Errorf("%s: InitializeConnection(%v)=%v, but want %v", tc.name, id, msgs[0], want) 195 } 196 } 197 } 198 199 func TestBlacklist(t *testing.T) { 200 ts := testserver.Make(t, "server", "Blacklist", nil) 201 defer ts.S.Stop() 202 ctx := context.Background() 203 204 k, err := ts.AddClient() 205 if err != nil { 206 t.Fatal(err) 207 } 208 id, err := common.MakeClientID(k) 209 if err != nil { 210 t.Fatal(err) 211 } 212 213 // Put a message in the database that would otherwise be ready for delivery. 214 mid, err := common.RandomMessageID() 215 if err != nil { 216 t.Fatalf("Unable to create message id: %v", err) 217 } 218 if err := ts.DS.StoreMessages(ctx, []*fspb.Message{ 219 { 220 MessageId: mid.Bytes(), 221 Source: &fspb.Address{ 222 ServiceName: "testService", 223 }, 224 Destination: &fspb.Address{ 225 ServiceName: "testService", 226 ClientId: id.Bytes(), 227 }, 228 MessageType: "TestMessage", 229 CreationTime: db.NowProto(), 230 }}, ""); err != nil { 231 t.Fatalf("Unable to store message: %v", err) 232 } 233 234 // Blacklist the client 235 if err := ts.DS.BlacklistClient(ctx, id); err != nil { 236 t.Fatalf("BlacklistClient returned error: %v", err) 237 } 238 239 msgs, err := ts.SimulateContactFromClient(ctx, k, nil) 240 if err != nil { 241 t.Error(err) 242 } 243 244 if len(msgs) != 1 { 245 t.Fatalf("Expected 1 message, got: %+v", msgs) 246 } 247 msg := msgs[0] 248 249 if msg.MessageType != "RekeyRequest" { 250 t.Errorf("Expected RekeyRequest, got: %+v", msg) 251 } 252 253 // Verify that the RekeyRequest message is in the database. 254 mid, err = common.BytesToMessageID(msg.MessageId) 255 if err != nil { 256 t.Fatalf("Unable to parse RekeyRequest message id: %v", err) 257 } 258 259 msgs, err = ts.DS.GetMessages(ctx, []common.MessageID{mid}, true) 260 if err != nil { 261 t.Fatalf("Error reading rekey message from datastore: %v", err) 262 } 263 if len(msgs) != 1 { 264 t.Fatalf("GetMessages([%v]) returned %d messages, expected 1.", mid, len(msgs)) 265 } 266 if !bytes.Equal(msgs[0].MessageId, msg.MessageId) || msgs[0].MessageType != "RekeyRequest" { 267 t.Errorf("GetMessage([%v]) did not return expected RekeyRequest, want: %+v got: %+v", mid, msg, msgs[0]) 268 } 269 } 270 271 // blocklistService is a Fleetspeak service.Service that counts blocklisted 272 // and non-blocklisted messages. 273 type blocklistService struct { 274 blocklistedCount uint 275 nonBlocklistedCount uint 276 } 277 278 func (s *blocklistService) Start(sctx service.Context) error { return nil } 279 func (s *blocklistService) ProcessMessage(ctx context.Context, m *fspb.Message) error { 280 if m.IsBlocklistedSource { 281 s.blocklistedCount++ 282 } else { 283 s.nonBlocklistedCount++ 284 } 285 return nil 286 } 287 func (s *blocklistService) Stop() error { return nil } 288 289 func TestStoredMessagesFromBlocklistedClient(t *testing.T) { 290 fin := sertesting.SetServerRetryTime(func(_ uint32) time.Time { 291 return db.Now().Add(time.Second) 292 }) 293 defer fin() 294 295 ctx := context.Background() 296 testService := &blocklistService{} 297 298 ts := testserver.MakeWithService(t, "server", "Blocklist", testService) 299 defer ts.S.Stop() 300 301 k, err := ts.AddClient() 302 if err != nil { 303 t.Fatal(err) 304 } 305 id, err := common.MakeClientID(k) 306 if err != nil { 307 t.Fatal(err) 308 } 309 310 // Blacklist the client 311 if err := ts.DS.BlacklistClient(ctx, id); err != nil { 312 t.Fatalf("BlacklistClient returned error: %v", err) 313 } 314 315 // Put a message in the database that would otherwise be ready for delivery. 316 mID, err := common.RandomMessageID() 317 if err != nil { 318 t.Fatalf("Unable to create message id: %v", err) 319 } 320 clientMessage := &fspb.Message{ 321 MessageId: mID.Bytes(), 322 SourceMessageId: []byte("AAABBBCCC"), 323 Source: &fspb.Address{ 324 ServiceName: "TestService", 325 ClientId: id.Bytes(), 326 }, 327 Destination: &fspb.Address{ 328 ServiceName: "TestService", 329 }, 330 MessageType: "TestMessage", 331 CreationTime: db.NowProto(), 332 } 333 334 if err := ts.DS.StoreMessages(ctx, []*fspb.Message{clientMessage}, ""); err != nil { 335 t.Fatalf("Unable to store message: %v", err) 336 } 337 338 tctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 339 defer cancel() 340 for { 341 msgs, err := ts.DS.GetMessages(tctx, []common.MessageID{mID}, true) 342 if err != nil { 343 t.Logf("GetMessages failed: %v", err) 344 goto Skip 345 } 346 if len(msgs) != 1 { 347 t.Fatalf("Expected 1 message, got: %v", msgs) 348 } 349 350 t.Logf("message %v", msgs[0]) 351 if msgs[0].Result != nil { 352 break 353 } 354 Skip: 355 if tctx.Err() != nil { 356 t.Fatal(tctx.Err()) 357 } 358 time.Sleep(100 * time.Millisecond) 359 } 360 361 messageResult, err := ts.DS.GetMessageResult(ctx, mID) 362 if err != nil { 363 t.Fatalf("GetMessageResult(%v) failed unexpectedly: %v", mID, err) 364 } 365 if messageResult == nil { 366 t.Fatalf("GetMessageResult(%v) returned empty result, want non-empty.", mID) 367 } 368 369 if testService.nonBlocklistedCount != 0 { 370 t.Errorf("Got %d non-blocklisted messages, want 0", testService.nonBlocklistedCount) 371 } 372 373 if testService.blocklistedCount != 1 { 374 t.Errorf("Got %d blocklisted messages, want 1", testService.blocklistedCount) 375 } 376 } 377 378 func TestDie(t *testing.T) { 379 ts := testserver.Make(t, "server", "Die", nil) 380 defer ts.S.Stop() 381 ctx := context.Background() 382 383 k, err := ts.AddClient() 384 if err != nil { 385 t.Fatal(err) 386 } 387 id, err := common.MakeClientID(k) 388 if err != nil { 389 t.Fatal(err) 390 } 391 392 // Create a Die message and a Foo message for the client 393 394 midDie, err := common.RandomMessageID() 395 if err != nil { 396 t.Fatal(err) 397 } 398 midFoo, err := common.RandomMessageID() 399 if err != nil { 400 t.Fatal(err) 401 } 402 err = ts.DS.StoreMessages(ctx, []*fspb.Message{ 403 { 404 MessageId: midDie.Bytes(), 405 Source: &fspb.Address{ 406 ServiceName: "system", 407 }, 408 Destination: &fspb.Address{ 409 ServiceName: "system", 410 ClientId: id.Bytes(), 411 }, 412 MessageType: "Die", 413 CreationTime: db.NowProto(), 414 }, 415 { 416 MessageId: midFoo.Bytes(), 417 Source: &fspb.Address{ 418 ServiceName: "foo", 419 }, 420 Destination: &fspb.Address{ 421 ServiceName: "foo", 422 ClientId: id.Bytes(), 423 }, 424 MessageType: "Foo", 425 CreationTime: db.NowProto(), 426 }, 427 }, "") 428 if err != nil { 429 t.Fatalf("Unable to store message: %v", err) 430 } 431 432 // Simulate contact from client 433 434 cd := fspb.ContactData{AllowedMessages: map[string]uint64{"foo": 20, "system": 20}} 435 cdb, err := proto.Marshal(&cd) 436 if err != nil { 437 t.Error(err) 438 } 439 ci, rcd, _, err := ts.CC.InitializeConnection( 440 ctx, 441 &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 123}, 442 k, 443 &fspb.WrappedContactData{ContactData: cdb}, 444 false) 445 if err != nil { 446 t.Error(err) 447 } 448 msgs := rcd.Messages 449 450 if len(msgs) != 2 { 451 t.Fatalf("Expected 2 messages, got: %+v", msgs) 452 } 453 454 // Check tokens 455 // The Die message should not have consumed any token. 456 457 if ci.MessageTokens()["foo"] != 19 { 458 t.Fatalf("Service foo should have 19 tokens left.") 459 } 460 461 if ci.MessageTokens()["system"] != 20 { 462 t.Fatalf("Service system should have all 20 tokens left.") 463 } 464 465 // The Die message should be acked automatically 466 467 m := ts.GetMessage(ctx, midDie) 468 if m.Result == nil || m.Result.Failed { 469 t.Error("Expected result of Die message to be success.") 470 } 471 472 // The Foo message should not be acked 473 474 m = ts.GetMessage(ctx, midFoo) 475 if m.Result != nil { 476 t.Error("Expected no result for Foo message.") 477 } 478 479 // The client sends a MessageAck for the Foo message 480 m = &fspb.Message{ 481 Source: &fspb.Address{ 482 ClientId: id.Bytes(), 483 ServiceName: "system", 484 }, 485 Destination: &fspb.Address{ 486 ServiceName: "system", 487 }, 488 SourceMessageId: []byte("1"), 489 MessageType: "MessageAck", 490 Data: anypbtest.New(t, &fspb.MessageAckData{ 491 MessageIds: [][]byte{midFoo.Bytes()}, 492 }), 493 } 494 m.MessageId = common.MakeMessageID(m.Source, m.SourceMessageId).Bytes() 495 496 err = ts.ProcessMessageFromClient(k, m) 497 if err != nil { 498 t.Fatal(err) 499 } 500 501 // Both the Foo and Die messages should be acked. 502 503 m = ts.GetMessage(ctx, midDie) 504 if m.Result == nil || m.Result.Failed { 505 t.Error("Expected result of Die message to be success.") 506 } 507 m = ts.GetMessage(ctx, midFoo) 508 if m.Result == nil || m.Result.Failed { 509 t.Error("Expected result of Foo message to be success.") 510 } 511 } 512 513 // errorService is a Fleetspeak service.Service that returns a specified 514 // error every time Service.ProcessMessage() is called. 515 type errorService struct { 516 err error 517 } 518 519 func (s *errorService) Start(sctx service.Context) error { return nil } 520 func (s *errorService) ProcessMessage(ctx context.Context, m *fspb.Message) error { return s.err } 521 func (s *errorService) Stop() error { return nil } 522 523 func TestServiceError(t *testing.T) { 524 ctx := context.Background() 525 testService := &errorService{errors.New(strings.Repeat("a", services.MaxServiceFailureReasonLength+1))} 526 serverWrapper := testserver.MakeWithService(t, "server", "ServiceError", testService) 527 defer serverWrapper.S.Stop() 528 529 clientPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 530 if err != nil { 531 t.Fatal(err) 532 } 533 clientPublicKey := clientPrivateKey.Public() 534 535 clientID, err := common.MakeClientID(clientPublicKey) 536 if err != nil { 537 t.Fatal(err) 538 } 539 clientMessage := &fspb.Message{ 540 Source: &fspb.Address{ 541 ClientId: clientID.Bytes(), 542 ServiceName: "TestService", 543 }, 544 Destination: &fspb.Address{ 545 ServiceName: "TestService", 546 }, 547 SourceMessageId: []byte("AAABBBCCC"), 548 MessageType: "TestMessage", 549 } 550 contactData := &fspb.ContactData{ 551 SequencingNonce: 5, 552 Messages: []*fspb.Message{clientMessage}, 553 } 554 serializedContactData, err := proto.Marshal(contactData) 555 if err != nil { 556 t.Fatalf("Unable to marshal contact data: %v", err) 557 } 558 559 if _, _, _, err = serverWrapper.CC.InitializeConnection( 560 ctx, 561 &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 123}, 562 clientPublicKey, 563 &fspb.WrappedContactData{ContactData: serializedContactData}, 564 false); err != nil { 565 t.Fatalf("InitializeConnection() failed: %v", err) 566 } 567 568 messageID := common.MakeMessageID(clientMessage.Source, clientMessage.SourceMessageId) 569 messageResult, err := serverWrapper.DS.GetMessageResult(ctx, messageID) 570 if err != nil { 571 t.Fatalf("Failed to get message result: %v", err) 572 } 573 574 expectedFailedReason := strings.Repeat("a", services.MaxServiceFailureReasonLength-3) + "..." 575 if messageResult.FailedReason != expectedFailedReason { 576 t.Errorf("Unexpected failure reason: got [%v], want [%v]", messageResult.FailedReason, expectedFailedReason) 577 } 578 }