github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/servertests/comms_test.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package servertests_test
    16  
import (
	"bytes"
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"errors"
	"net"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/google/fleetspeak/fleetspeak/src/common"
	"github.com/google/fleetspeak/fleetspeak/src/common/anypbtest"
	"github.com/google/fleetspeak/fleetspeak/src/server/db"
	"github.com/google/fleetspeak/fleetspeak/src/server/internal/services"
	"github.com/google/fleetspeak/fleetspeak/src/server/sertesting"
	"github.com/google/fleetspeak/fleetspeak/src/server/service"
	"github.com/google/fleetspeak/fleetspeak/src/server/testserver"
	"google.golang.org/protobuf/proto"
	tspb "google.golang.org/protobuf/types/known/timestamppb"

	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
)
    43  
    44  func TestCommsContext(t *testing.T) {
    45  	fakeTime := sertesting.FakeNow(50)
    46  	defer fakeTime.Revert()
    47  
    48  	ts := testserver.Make(t, "server", "CommsContext", nil)
    49  	defer ts.S.Stop()
    50  	ctx := context.Background()
    51  
    52  	// Verify that we can add clients using different types of keys.
    53  	privateKey1, err := rsa.GenerateKey(rand.Reader, 2048)
    54  	if err != nil {
    55  		t.Fatal(err)
    56  	}
    57  	privateKey2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  
    62  	// For each client/key we go through a basic lifecyle - add the client
    63  	// to the system, check for messages for the client, etc.
    64  	for _, tc := range []struct {
    65  		name      string
    66  		pub       crypto.PublicKey
    67  		streaming bool
    68  	}{
    69  		{
    70  			name: "rsa",
    71  			pub:  privateKey1.Public()},
    72  		{
    73  			name: "ecdsa",
    74  			pub:  privateKey2.Public()},
    75  		{
    76  			name:      "rsa-streaming",
    77  			pub:       privateKey1.Public(),
    78  			streaming: true},
    79  		{
    80  			name:      "ecdsa-streaming",
    81  			pub:       privateKey2.Public(),
    82  			streaming: true},
    83  	} {
    84  		ci, cd, _, err := ts.CC.InitializeConnection(
    85  			ctx,
    86  			&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 123},
    87  			tc.pub,
    88  			&fspb.WrappedContactData{},
    89  			false)
    90  		if err != nil {
    91  			t.Fatal(err)
    92  		}
    93  		id, err := common.MakeClientID(tc.pub)
    94  		if err != nil {
    95  			t.Fatal(err)
    96  		}
    97  		if ci.Addr.Network() != "tcp" || ci.Addr.String() != "127.0.0.1:123" {
    98  			t.Errorf("%s: InitializeConnection returned ci.Addr of [%s,%v], but expected [tcp,127.0.0.1:123]", tc.name, ci.Addr.Network(), ci.Addr)
    99  		}
   100  		if ci.Client.ID != id {
   101  			t.Errorf("%s: InitializeConnection returned client ID of %v, but expected %v", tc.name, ci.Client.ID, id)
   102  		}
   103  		if ci.Client.Key == nil {
   104  			t.Errorf("%s: InitializeConnection returned empty ci.Client.Key", tc.name)
   105  		}
   106  		if ci.ContactID == "" {
   107  			t.Errorf("%s: InitializeConnection returned empty ci.ContactID", tc.name)
   108  		}
   109  		if ci.NonceSent == 0 {
   110  			t.Errorf("%s: InitializeConnection returned 0 NonceSent", tc.name)
   111  		}
   112  		if len(cd.Messages) != 0 {
   113  			t.Fatalf("%s: Expected no messages, got: %v", tc.name, cd.Messages)
   114  		}
   115  
   116  		// If a client does provide messages, they should end up in the datastore.
   117  		fakeTime.SetSeconds(1234)
   118  		cd = &fspb.ContactData{
   119  			SequencingNonce: 5,
   120  			Messages: []*fspb.Message{
   121  				{
   122  					Source: &fspb.Address{
   123  						ClientId:    id.Bytes(),
   124  						ServiceName: "TestService",
   125  					},
   126  					Destination: &fspb.Address{
   127  						ServiceName: "TestService",
   128  					},
   129  					SourceMessageId: []byte("AAABBBCCC"),
   130  					MessageType:     "TestMessage",
   131  				},
   132  			},
   133  		}
   134  		bcd, err := proto.Marshal(cd)
   135  		if err != nil {
   136  			t.Fatalf("%s: Unable to marshal contact data: %v", tc.name, err)
   137  		}
   138  		if tc.streaming {
   139  			if err := ts.CC.HandleMessagesFromClient(
   140  				ctx,
   141  				ci,
   142  				&fspb.WrappedContactData{ContactData: bcd}); err != nil {
   143  				t.Fatal(err)
   144  			}
   145  			cd, _, err := ts.CC.GetMessagesForClient(ctx, ci)
   146  			if err != nil {
   147  				t.Fatal(err)
   148  			}
   149  			if cd != nil {
   150  				t.Errorf("%s: Expected nil ContactData, got: %v", tc.name, cd)
   151  			}
   152  		} else {
   153  			if ci, cd, _, err = ts.CC.InitializeConnection(
   154  				ctx,
   155  				&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 123},
   156  				tc.pub,
   157  				&fspb.WrappedContactData{ContactData: bcd},
   158  				false); err != nil {
   159  				t.Fatal(err)
   160  			}
   161  		}
   162  		fakeTime.SetSeconds(3000)
   163  
   164  		mid := common.MakeMessageID(
   165  			&fspb.Address{
   166  				ClientId:    id.Bytes(),
   167  				ServiceName: "TestService",
   168  			},
   169  			[]byte("AAABBBCCC"),
   170  		)
   171  		msgs, err := ts.DS.GetMessages(ctx, []common.MessageID{mid}, false)
   172  
   173  		if err != nil {
   174  			t.Fatal(err)
   175  		}
   176  		if len(msgs) != 1 {
   177  			t.Fatalf("Expected 1 message, got: %v", msgs)
   178  		}
   179  		want := &fspb.Message{
   180  			MessageId: mid.Bytes(),
   181  			Source: &fspb.Address{
   182  				ClientId:    id.Bytes(),
   183  				ServiceName: "TestService",
   184  			},
   185  			Destination: &fspb.Address{
   186  				ServiceName: "TestService",
   187  			},
   188  			SourceMessageId: []byte("AAABBBCCC"),
   189  			MessageType:     "TestMessage",
   190  			CreationTime:    &tspb.Timestamp{Seconds: 1234},
   191  		}
   192  		msgs[0].Result = nil
   193  		if !proto.Equal(msgs[0], want) {
   194  			t.Errorf("%s: InitializeConnection(%v)=%v, but want %v", tc.name, id, msgs[0], want)
   195  		}
   196  	}
   197  }
   198  
   199  func TestBlacklist(t *testing.T) {
   200  	ts := testserver.Make(t, "server", "Blacklist", nil)
   201  	defer ts.S.Stop()
   202  	ctx := context.Background()
   203  
   204  	k, err := ts.AddClient()
   205  	if err != nil {
   206  		t.Fatal(err)
   207  	}
   208  	id, err := common.MakeClientID(k)
   209  	if err != nil {
   210  		t.Fatal(err)
   211  	}
   212  
   213  	// Put a message in the database that would otherwise be ready for delivery.
   214  	mid, err := common.RandomMessageID()
   215  	if err != nil {
   216  		t.Fatalf("Unable to create message id: %v", err)
   217  	}
   218  	if err := ts.DS.StoreMessages(ctx, []*fspb.Message{
   219  		{
   220  			MessageId: mid.Bytes(),
   221  			Source: &fspb.Address{
   222  				ServiceName: "testService",
   223  			},
   224  			Destination: &fspb.Address{
   225  				ServiceName: "testService",
   226  				ClientId:    id.Bytes(),
   227  			},
   228  			MessageType:  "TestMessage",
   229  			CreationTime: db.NowProto(),
   230  		}}, ""); err != nil {
   231  		t.Fatalf("Unable to store message: %v", err)
   232  	}
   233  
   234  	// Blacklist the client
   235  	if err := ts.DS.BlacklistClient(ctx, id); err != nil {
   236  		t.Fatalf("BlacklistClient returned error: %v", err)
   237  	}
   238  
   239  	msgs, err := ts.SimulateContactFromClient(ctx, k, nil)
   240  	if err != nil {
   241  		t.Error(err)
   242  	}
   243  
   244  	if len(msgs) != 1 {
   245  		t.Fatalf("Expected 1 message, got: %+v", msgs)
   246  	}
   247  	msg := msgs[0]
   248  
   249  	if msg.MessageType != "RekeyRequest" {
   250  		t.Errorf("Expected RekeyRequest, got: %+v", msg)
   251  	}
   252  
   253  	// Verify that the RekeyRequest message is in the database.
   254  	mid, err = common.BytesToMessageID(msg.MessageId)
   255  	if err != nil {
   256  		t.Fatalf("Unable to parse RekeyRequest message id: %v", err)
   257  	}
   258  
   259  	msgs, err = ts.DS.GetMessages(ctx, []common.MessageID{mid}, true)
   260  	if err != nil {
   261  		t.Fatalf("Error reading rekey message from datastore: %v", err)
   262  	}
   263  	if len(msgs) != 1 {
   264  		t.Fatalf("GetMessages([%v]) returned %d messages, expected 1.", mid, len(msgs))
   265  	}
   266  	if !bytes.Equal(msgs[0].MessageId, msg.MessageId) || msgs[0].MessageType != "RekeyRequest" {
   267  		t.Errorf("GetMessage([%v]) did not return expected RekeyRequest, want: %+v got: %+v", mid, msg, msgs[0])
   268  	}
   269  }
   270  
   271  // blocklistService is a Fleetspeak service.Service that counts blocklisted
   272  // and non-blocklisted messages.
   273  type blocklistService struct {
   274  	blocklistedCount    uint
   275  	nonBlocklistedCount uint
   276  }
   277  
   278  func (s *blocklistService) Start(sctx service.Context) error { return nil }
   279  func (s *blocklistService) ProcessMessage(ctx context.Context, m *fspb.Message) error {
   280  	if m.IsBlocklistedSource {
   281  		s.blocklistedCount++
   282  	} else {
   283  		s.nonBlocklistedCount++
   284  	}
   285  	return nil
   286  }
   287  func (s *blocklistService) Stop() error { return nil }
   288  
   289  func TestStoredMessagesFromBlocklistedClient(t *testing.T) {
   290  	fin := sertesting.SetServerRetryTime(func(_ uint32) time.Time {
   291  		return db.Now().Add(time.Second)
   292  	})
   293  	defer fin()
   294  
   295  	ctx := context.Background()
   296  	testService := &blocklistService{}
   297  
   298  	ts := testserver.MakeWithService(t, "server", "Blocklist", testService)
   299  	defer ts.S.Stop()
   300  
   301  	k, err := ts.AddClient()
   302  	if err != nil {
   303  		t.Fatal(err)
   304  	}
   305  	id, err := common.MakeClientID(k)
   306  	if err != nil {
   307  		t.Fatal(err)
   308  	}
   309  
   310  	// Blacklist the client
   311  	if err := ts.DS.BlacklistClient(ctx, id); err != nil {
   312  		t.Fatalf("BlacklistClient returned error: %v", err)
   313  	}
   314  
   315  	// Put a message in the database that would otherwise be ready for delivery.
   316  	mID, err := common.RandomMessageID()
   317  	if err != nil {
   318  		t.Fatalf("Unable to create message id: %v", err)
   319  	}
   320  	clientMessage := &fspb.Message{
   321  		MessageId:       mID.Bytes(),
   322  		SourceMessageId: []byte("AAABBBCCC"),
   323  		Source: &fspb.Address{
   324  			ServiceName: "TestService",
   325  			ClientId:    id.Bytes(),
   326  		},
   327  		Destination: &fspb.Address{
   328  			ServiceName: "TestService",
   329  		},
   330  		MessageType:  "TestMessage",
   331  		CreationTime: db.NowProto(),
   332  	}
   333  
   334  	if err := ts.DS.StoreMessages(ctx, []*fspb.Message{clientMessage}, ""); err != nil {
   335  		t.Fatalf("Unable to store message: %v", err)
   336  	}
   337  
   338  	tctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
   339  	defer cancel()
   340  	for {
   341  		msgs, err := ts.DS.GetMessages(tctx, []common.MessageID{mID}, true)
   342  		if err != nil {
   343  			t.Logf("GetMessages failed: %v", err)
   344  			goto Skip
   345  		}
   346  		if len(msgs) != 1 {
   347  			t.Fatalf("Expected 1 message, got: %v", msgs)
   348  		}
   349  
   350  		t.Logf("message %v", msgs[0])
   351  		if msgs[0].Result != nil {
   352  			break
   353  		}
   354  	Skip:
   355  		if tctx.Err() != nil {
   356  			t.Fatal(tctx.Err())
   357  		}
   358  		time.Sleep(100 * time.Millisecond)
   359  	}
   360  
   361  	messageResult, err := ts.DS.GetMessageResult(ctx, mID)
   362  	if err != nil {
   363  		t.Fatalf("GetMessageResult(%v) failed unexpectedly: %v", mID, err)
   364  	}
   365  	if messageResult == nil {
   366  		t.Fatalf("GetMessageResult(%v) returned empty result, want non-empty.", mID)
   367  	}
   368  
   369  	if testService.nonBlocklistedCount != 0 {
   370  		t.Errorf("Got %d non-blocklisted messages, want 0", testService.nonBlocklistedCount)
   371  	}
   372  
   373  	if testService.blocklistedCount != 1 {
   374  		t.Errorf("Got %d blocklisted messages, want 1", testService.blocklistedCount)
   375  	}
   376  }
   377  
   378  func TestDie(t *testing.T) {
   379  	ts := testserver.Make(t, "server", "Die", nil)
   380  	defer ts.S.Stop()
   381  	ctx := context.Background()
   382  
   383  	k, err := ts.AddClient()
   384  	if err != nil {
   385  		t.Fatal(err)
   386  	}
   387  	id, err := common.MakeClientID(k)
   388  	if err != nil {
   389  		t.Fatal(err)
   390  	}
   391  
   392  	// Create a Die message and a Foo message for the client
   393  
   394  	midDie, err := common.RandomMessageID()
   395  	if err != nil {
   396  		t.Fatal(err)
   397  	}
   398  	midFoo, err := common.RandomMessageID()
   399  	if err != nil {
   400  		t.Fatal(err)
   401  	}
   402  	err = ts.DS.StoreMessages(ctx, []*fspb.Message{
   403  		{
   404  			MessageId: midDie.Bytes(),
   405  			Source: &fspb.Address{
   406  				ServiceName: "system",
   407  			},
   408  			Destination: &fspb.Address{
   409  				ServiceName: "system",
   410  				ClientId:    id.Bytes(),
   411  			},
   412  			MessageType:  "Die",
   413  			CreationTime: db.NowProto(),
   414  		},
   415  		{
   416  			MessageId: midFoo.Bytes(),
   417  			Source: &fspb.Address{
   418  				ServiceName: "foo",
   419  			},
   420  			Destination: &fspb.Address{
   421  				ServiceName: "foo",
   422  				ClientId:    id.Bytes(),
   423  			},
   424  			MessageType:  "Foo",
   425  			CreationTime: db.NowProto(),
   426  		},
   427  	}, "")
   428  	if err != nil {
   429  		t.Fatalf("Unable to store message: %v", err)
   430  	}
   431  
   432  	// Simulate contact from client
   433  
   434  	cd := fspb.ContactData{AllowedMessages: map[string]uint64{"foo": 20, "system": 20}}
   435  	cdb, err := proto.Marshal(&cd)
   436  	if err != nil {
   437  		t.Error(err)
   438  	}
   439  	ci, rcd, _, err := ts.CC.InitializeConnection(
   440  		ctx,
   441  		&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 123},
   442  		k,
   443  		&fspb.WrappedContactData{ContactData: cdb},
   444  		false)
   445  	if err != nil {
   446  		t.Error(err)
   447  	}
   448  	msgs := rcd.Messages
   449  
   450  	if len(msgs) != 2 {
   451  		t.Fatalf("Expected 2 messages, got: %+v", msgs)
   452  	}
   453  
   454  	// Check tokens
   455  	// The Die message should not have consumed any token.
   456  
   457  	if ci.MessageTokens()["foo"] != 19 {
   458  		t.Fatalf("Service foo should have 19 tokens left.")
   459  	}
   460  
   461  	if ci.MessageTokens()["system"] != 20 {
   462  		t.Fatalf("Service system should have all 20 tokens left.")
   463  	}
   464  
   465  	// The Die message should be acked automatically
   466  
   467  	m := ts.GetMessage(ctx, midDie)
   468  	if m.Result == nil || m.Result.Failed {
   469  		t.Error("Expected result of Die message to be success.")
   470  	}
   471  
   472  	// The Foo message should not be acked
   473  
   474  	m = ts.GetMessage(ctx, midFoo)
   475  	if m.Result != nil {
   476  		t.Error("Expected no result for Foo message.")
   477  	}
   478  
   479  	// The client sends a MessageAck for the Foo message
   480  	m = &fspb.Message{
   481  		Source: &fspb.Address{
   482  			ClientId:    id.Bytes(),
   483  			ServiceName: "system",
   484  		},
   485  		Destination: &fspb.Address{
   486  			ServiceName: "system",
   487  		},
   488  		SourceMessageId: []byte("1"),
   489  		MessageType:     "MessageAck",
   490  		Data: anypbtest.New(t, &fspb.MessageAckData{
   491  			MessageIds: [][]byte{midFoo.Bytes()},
   492  		}),
   493  	}
   494  	m.MessageId = common.MakeMessageID(m.Source, m.SourceMessageId).Bytes()
   495  
   496  	err = ts.ProcessMessageFromClient(k, m)
   497  	if err != nil {
   498  		t.Fatal(err)
   499  	}
   500  
   501  	// Both the Foo and Die messages should be acked.
   502  
   503  	m = ts.GetMessage(ctx, midDie)
   504  	if m.Result == nil || m.Result.Failed {
   505  		t.Error("Expected result of Die message to be success.")
   506  	}
   507  	m = ts.GetMessage(ctx, midFoo)
   508  	if m.Result == nil || m.Result.Failed {
   509  		t.Error("Expected result of Foo message to be success.")
   510  	}
   511  }
   512  
   513  // errorService is a Fleetspeak service.Service that returns a specified
   514  // error every time Service.ProcessMessage() is called.
   515  type errorService struct {
   516  	err error
   517  }
   518  
   519  func (s *errorService) Start(sctx service.Context) error                          { return nil }
   520  func (s *errorService) ProcessMessage(ctx context.Context, m *fspb.Message) error { return s.err }
   521  func (s *errorService) Stop() error                                               { return nil }
   522  
   523  func TestServiceError(t *testing.T) {
   524  	ctx := context.Background()
   525  	testService := &errorService{errors.New(strings.Repeat("a", services.MaxServiceFailureReasonLength+1))}
   526  	serverWrapper := testserver.MakeWithService(t, "server", "ServiceError", testService)
   527  	defer serverWrapper.S.Stop()
   528  
   529  	clientPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   530  	if err != nil {
   531  		t.Fatal(err)
   532  	}
   533  	clientPublicKey := clientPrivateKey.Public()
   534  
   535  	clientID, err := common.MakeClientID(clientPublicKey)
   536  	if err != nil {
   537  		t.Fatal(err)
   538  	}
   539  	clientMessage := &fspb.Message{
   540  		Source: &fspb.Address{
   541  			ClientId:    clientID.Bytes(),
   542  			ServiceName: "TestService",
   543  		},
   544  		Destination: &fspb.Address{
   545  			ServiceName: "TestService",
   546  		},
   547  		SourceMessageId: []byte("AAABBBCCC"),
   548  		MessageType:     "TestMessage",
   549  	}
   550  	contactData := &fspb.ContactData{
   551  		SequencingNonce: 5,
   552  		Messages:        []*fspb.Message{clientMessage},
   553  	}
   554  	serializedContactData, err := proto.Marshal(contactData)
   555  	if err != nil {
   556  		t.Fatalf("Unable to marshal contact data: %v", err)
   557  	}
   558  
   559  	if _, _, _, err = serverWrapper.CC.InitializeConnection(
   560  		ctx,
   561  		&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 123},
   562  		clientPublicKey,
   563  		&fspb.WrappedContactData{ContactData: serializedContactData},
   564  		false); err != nil {
   565  		t.Fatalf("InitializeConnection() failed: %v", err)
   566  	}
   567  
   568  	messageID := common.MakeMessageID(clientMessage.Source, clientMessage.SourceMessageId)
   569  	messageResult, err := serverWrapper.DS.GetMessageResult(ctx, messageID)
   570  	if err != nil {
   571  		t.Fatalf("Failed to get message result: %v", err)
   572  	}
   573  
   574  	expectedFailedReason := strings.Repeat("a", services.MaxServiceFailureReasonLength-3) + "..."
   575  	if messageResult.FailedReason != expectedFailedReason {
   576  		t.Errorf("Unexpected failure reason: got [%v], want [%v]", messageResult.FailedReason, expectedFailedReason)
   577  	}
   578  }