github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/servertests/admin_test.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package servertests_test
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"reflect"
    21  	"sort"
    22  	"strings"
    23  	"testing"
    24  
    25  	"google.golang.org/grpc"
    26  
    27  	"google.golang.org/protobuf/proto"
    28  
    29  	"github.com/google/fleetspeak/fleetspeak/src/common"
    30  	"github.com/google/fleetspeak/fleetspeak/src/common/anypbtest"
    31  	"github.com/google/fleetspeak/fleetspeak/src/server/admin"
    32  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    33  	"github.com/google/fleetspeak/fleetspeak/src/server/ids"
    34  	"github.com/google/fleetspeak/fleetspeak/src/server/testserver"
    35  
    36  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    37  	spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server"
    38  	anypb "google.golang.org/protobuf/types/known/anypb"
    39  	tspb "google.golang.org/protobuf/types/known/timestamppb"
    40  )
    41  
    42  func TestBroadcastsAPI(t *testing.T) {
    43  	ctx := context.Background()
    44  
    45  	ts := testserver.Make(t, "server", "AdminServer", nil)
    46  	defer ts.S.Stop()
    47  
    48  	as := admin.NewServer(ts.DS, nil)
    49  
    50  	bid, err := ids.BytesToBroadcastID([]byte{0, 0, 0, 0, 0, 0, 0, 1})
    51  	if err != nil {
    52  		t.Fatal(err)
    53  	}
    54  
    55  	br := spb.Broadcast{
    56  		BroadcastId: bid.Bytes(),
    57  		Source: &fspb.Address{
    58  			ServiceName: "TestService",
    59  		},
    60  		MessageType: "TestMessage",
    61  	}
    62  
    63  	if _, err := as.CreateBroadcast(ctx, &spb.CreateBroadcastRequest{Broadcast: &br, Limit: 100}); err != nil {
    64  		t.Errorf("CreateBroadcast returned error: %v", err)
    65  	}
    66  
    67  	res, err := as.ListActiveBroadcasts(ctx, &spb.ListActiveBroadcastsRequest{ServiceName: "TestService"})
    68  	if err != nil {
    69  		t.Errorf("ListActiveBroadcasts returned error: %v", err)
    70  	}
    71  
    72  	wantResp := &spb.ListActiveBroadcastsResponse{Broadcasts: []*spb.Broadcast{&br}}
    73  	if !proto.Equal(res, wantResp) {
    74  		t.Errorf("ListActiveBroadcasts error: want [%v] got [%v]", wantResp, res)
    75  	}
    76  }
    77  
    78  func TestMessageStatusAPI(t *testing.T) {
    79  	ctx := context.Background()
    80  
    81  	ts := testserver.Make(t, "server", "AdminServer", nil)
    82  	defer ts.S.Stop()
    83  
    84  	as := admin.NewServer(ts.DS, nil)
    85  
    86  	// A byte slice representing a message id, with 32 zeros.
    87  	bmid0 := make([]byte, 32)
    88  
    89  	gmsRes, err := as.GetMessageStatus(ctx, &spb.GetMessageStatusRequest{
    90  		MessageId: bmid0,
    91  	})
    92  	if err != nil {
    93  		t.Errorf("GetMessageStatus returned an error: %v", err)
    94  	}
    95  
    96  	gmsWant := &spb.GetMessageStatusResponse{}
    97  	if !proto.Equal(gmsRes, gmsWant) {
    98  		t.Errorf("GetMessageStatus error: want [%v] got [%v]", gmsWant, gmsRes)
    99  	}
   100  
   101  	addr := &fspb.Address{
   102  		ServiceName: "TestService",
   103  	}
   104  
   105  	m := &fspb.Message{
   106  		MessageId:    bmid0,
   107  		Source:       addr,
   108  		Destination:  addr,
   109  		CreationTime: &tspb.Timestamp{Seconds: 42},
   110  	}
   111  
   112  	if err := ts.DS.StoreMessages(ctx, []*fspb.Message{m}, ""); err != nil {
   113  		t.Errorf("StoreMessage (Message: [%v]) error: %v", m, err)
   114  	}
   115  
   116  	gmsRes, err = as.GetMessageStatus(ctx, &spb.GetMessageStatusRequest{
   117  		MessageId: bmid0,
   118  	})
   119  	if err != nil {
   120  		t.Errorf("GetMessageStatus returned an error: %v", err)
   121  	}
   122  
   123  	gmsWant = &spb.GetMessageStatusResponse{
   124  		CreationTime: &tspb.Timestamp{Seconds: 42},
   125  	}
   126  	if !proto.Equal(gmsRes, gmsWant) {
   127  		t.Errorf("GetMessageStatus error: want [%v] got [%v]", gmsWant, gmsRes)
   128  	}
   129  }
   130  
   131  type mockStreamClientIdsServer struct {
   132  	grpc.ServerStream
   133  	responses []*spb.StreamClientIdsResponse
   134  }
   135  
   136  func (m *mockStreamClientIdsServer) Send(response *spb.StreamClientIdsResponse) error {
   137  	m.responses = append(m.responses, response)
   138  	return nil
   139  }
   140  
   141  func (m *mockStreamClientIdsServer) Context() context.Context {
   142  	return context.Background()
   143  }
   144  
   145  func TestListClientsAPI(t *testing.T) {
   146  	ctx := context.Background()
   147  
   148  	ts := testserver.Make(t, "server", "AdminServer", nil)
   149  	defer ts.S.Stop()
   150  
   151  	as := admin.NewServer(ts.DS, nil)
   152  
   153  	id0 := []byte{0, 0, 0, 0, 0, 0, 0, 0}
   154  	cid0, err := common.BytesToClientID(id0)
   155  	if err != nil {
   156  		t.Errorf("BytesToClientID(%v) = %v, want nil", cid0, err)
   157  	}
   158  
   159  	if err = ts.DS.AddClient(ctx, cid0, &db.ClientData{}); err != nil {
   160  		t.Fatalf("AddClient returned an error: %v", err)
   161  	}
   162  
   163  	id1 := []byte{0, 0, 0, 0, 0, 0, 0, 1}
   164  	cid1, err := common.BytesToClientID(id1)
   165  	if err != nil {
   166  		t.Errorf("BytesToClientID(%v) = %v, want nil", cid1, err)
   167  	}
   168  
   169  	lab1 := []*fspb.Label{
   170  		{
   171  			ServiceName: "BarService",
   172  			Label:       "BarLabel",
   173  		},
   174  		{
   175  			ServiceName: "FooService",
   176  			Label:       "FooLabel",
   177  		},
   178  	}
   179  
   180  	if err = ts.DS.AddClient(ctx, cid1, &db.ClientData{
   181  		Labels: lab1,
   182  	}); err != nil {
   183  		t.Errorf("AddClient returned an error: %v", err)
   184  	}
   185  
   186  	t.Run("ListClients", func(t *testing.T) {
   187  		lcRes, err := as.ListClients(ctx, &spb.ListClientsRequest{})
   188  		if err != nil {
   189  			t.Errorf("ListClients returned an error: %v", err)
   190  		}
   191  
   192  		lcWant := &spb.ListClientsResponse{
   193  			Clients: []*spb.Client{
   194  				{
   195  					ClientId: id0,
   196  				},
   197  				{
   198  					ClientId: id1,
   199  					Labels:   lab1,
   200  				},
   201  			},
   202  		}
   203  
   204  		// The result's order is arbitrary, so let's sort it.
   205  		sort.Slice(lcRes.Clients, func(i, j int) bool {
   206  			return bytes.Compare(lcRes.Clients[i].ClientId, lcRes.Clients[j].ClientId) < 0
   207  		})
   208  
   209  		for _, c := range lcRes.Clients {
   210  			if c.LastContactTime == nil {
   211  				t.Errorf("ListClients error: LastSeenTimestamp is nil")
   212  			}
   213  			c.LastContactTime = nil
   214  		}
   215  
   216  		if !proto.Equal(lcRes, lcWant) {
   217  			t.Errorf("ListClients error: want [%v], got [%v]", lcWant, lcRes)
   218  		}
   219  	})
   220  
   221  	t.Run("StreamClientIds", func(t *testing.T) {
   222  		var m mockStreamClientIdsServer
   223  		req := &spb.StreamClientIdsRequest{}
   224  		err := as.StreamClientIds(req, &m)
   225  		if err != nil {
   226  			t.Errorf("StreamClientIds returned an error: %v", err)
   227  		}
   228  		var ids [][]byte
   229  		for _, response := range m.responses {
   230  			ids = append(ids, response.ClientId)
   231  		}
   232  		sort.Slice(ids, func(i, j int) bool {
   233  			return bytes.Compare(ids[i], ids[j]) < 0
   234  		})
   235  		expected := [][]byte{id0, id1}
   236  		if !reflect.DeepEqual(ids, expected) {
   237  			t.Errorf("StreamClientIds error: want [%v], got [%v].", expected, ids)
   238  		}
   239  	})
   240  }
   241  
   242  func TestInsertMessageAPI(t *testing.T) {
   243  	mid, err := common.RandomMessageID()
   244  	if err != nil {
   245  		t.Fatalf("Unable to create message id: %v", err)
   246  	}
   247  	ctx := context.Background()
   248  
   249  	ts := testserver.Make(t, "server", "AdminServer", nil)
   250  	defer ts.S.Stop()
   251  
   252  	key, err := ts.AddClient()
   253  	if err != nil {
   254  		t.Fatalf("Unable to add client: %v", err)
   255  	}
   256  	id, err := common.MakeClientID(key)
   257  	if err != nil {
   258  		t.Fatalf("Unable to make ClientID: %v", err)
   259  	}
   260  
   261  	as := admin.NewServer(ts.DS, nil)
   262  
   263  	m := fspb.Message{
   264  		MessageId:    mid.Bytes(),
   265  		Source:       &fspb.Address{ServiceName: "TestService"},
   266  		Destination:  &fspb.Address{ServiceName: "TestService", ClientId: id.Bytes()},
   267  		MessageType:  "DummyType",
   268  		CreationTime: db.NowProto(),
   269  	}
   270  
   271  	if _, err := as.InsertMessage(ctx, proto.Clone(&m).(*fspb.Message)); err != nil {
   272  		t.Fatalf("InsertMessage returned error: %v", err)
   273  	}
   274  	// m should now be available for processing:
   275  	msgs, err := ts.DS.ClientMessagesForProcessing(ctx, id, 10, nil)
   276  	if err != nil {
   277  		t.Fatalf("ClientMessagesForProcessing(%v) returned error: %v", id, err)
   278  	}
   279  	if len(msgs) != 1 || !proto.Equal(msgs[0], &m) {
   280  		t.Errorf("ClientMessagesForProcessing(%v) returned unexpected value, got: %v, want [%v]", id, msgs, m.String())
   281  	}
   282  }
   283  
   284  func TestInsertMessageAPI_LargeMessages(t *testing.T) {
   285  	mid, err := common.RandomMessageID()
   286  	if err != nil {
   287  		t.Fatalf("Unable to create message id: %v", err)
   288  	}
   289  	ctx := context.Background()
   290  
   291  	server := testserver.Make(t, "server", "AdminServer", nil)
   292  	defer server.S.Stop()
   293  
   294  	key, err := server.AddClient()
   295  	if err != nil {
   296  		t.Fatalf("Unable to add client: %v", err)
   297  	}
   298  	id, err := common.MakeClientID(key)
   299  	if err != nil {
   300  		t.Fatalf("Unable to make ClientID: %v", err)
   301  	}
   302  
   303  	adminServer := admin.NewServer(server.DS, nil)
   304  
   305  	msg := fspb.Message{
   306  		MessageId:    mid.Bytes(),
   307  		Source:       &fspb.Address{ServiceName: "TestService"},
   308  		Destination:  &fspb.Address{ServiceName: "TestService", ClientId: id.Bytes()},
   309  		MessageType:  "DummyType",
   310  		CreationTime: db.NowProto(),
   311  		Data: anypbtest.New(t, &fspb.Signature{
   312  			Signature: bytes.Repeat([]byte{0xa}, 2<<20+1),
   313  		}),
   314  	}
   315  
   316  	if _, err := adminServer.InsertMessage(ctx, &msg); err == nil {
   317  		t.Fatal("Expected InsertMessage to return an error.")
   318  	} else if !strings.Contains(err.Error(), "exceeds the 2097152-byte limit") {
   319  		t.Errorf("Unexpected error: [%v].", err)
   320  	}
   321  }
   322  
   323  func TestPendingMessages(t *testing.T) {
   324  	mid0, _ := common.BytesToMessageID([]byte("91234567890123456789012345678900"))
   325  	mid1, _ := common.BytesToMessageID([]byte("91234567890123456789012345678901"))
   326  	mid2, _ := common.BytesToMessageID([]byte("91234567890123456789012345678902"))
   327  	mid3, _ := common.BytesToMessageID([]byte("91234567890123456789012345678903"))
   328  	mid4, _ := common.BytesToMessageID([]byte("91234567890123456789012345678904"))
   329  
   330  	ctx := context.Background()
   331  
   332  	ts := testserver.Make(t, "server", "TestPendingMessages", nil)
   333  	defer ts.S.Stop()
   334  
   335  	key, err := ts.AddClient()
   336  	if err != nil {
   337  		t.Fatalf("Unable to add client: %v", err)
   338  	}
   339  	id, err := common.MakeClientID(key)
   340  	if err != nil {
   341  		t.Fatalf("Unable to make ClientID: %v", err)
   342  	}
   343  
   344  	as := admin.NewServer(ts.DS, nil)
   345  
   346  	msgs := []*fspb.Message{
   347  		{
   348  			MessageId:    mid0.Bytes(),
   349  			Source:       &fspb.Address{ServiceName: "TestService"},
   350  			Destination:  &fspb.Address{ServiceName: "TestService", ClientId: id.Bytes()},
   351  			MessageType:  "DummyType",
   352  			CreationTime: db.NowProto(),
   353  			Data: &anypb.Any{
   354  				TypeUrl: "test data proto urn 0",
   355  				Value:   []byte("Test data proto 0")},
   356  		},
   357  		{
   358  			MessageId:    mid1.Bytes(),
   359  			Source:       &fspb.Address{ServiceName: "TestService"},
   360  			Destination:  &fspb.Address{ServiceName: "TestService", ClientId: id.Bytes()},
   361  			MessageType:  "DummyType",
   362  			CreationTime: db.NowProto(),
   363  			Data: &anypb.Any{
   364  				TypeUrl: "test data proto urn 1",
   365  				Value:   []byte("Test data proto 1")},
   366  		},
   367  		{
   368  			MessageId:    mid2.Bytes(),
   369  			Source:       &fspb.Address{ServiceName: "TestService"},
   370  			Destination:  &fspb.Address{ServiceName: "TestService", ClientId: id.Bytes()},
   371  			MessageType:  "DummyType",
   372  			CreationTime: db.NowProto(),
   373  			Data: &anypb.Any{
   374  				TypeUrl: "test data proto urn 2",
   375  				Value:   []byte("Test data proto 2")},
   376  		},
   377  		{
   378  			MessageId:    mid3.Bytes(),
   379  			Source:       &fspb.Address{ServiceName: "TestService"},
   380  			Destination:  &fspb.Address{ServiceName: "TestService", ClientId: id.Bytes()},
   381  			MessageType:  "DummyType",
   382  			CreationTime: db.NowProto(),
   383  			Data: &anypb.Any{
   384  				TypeUrl: "test data proto urn 3",
   385  				Value:   []byte("Test data proto 3")},
   386  		},
   387  		{
   388  			MessageId:    mid4.Bytes(),
   389  			Source:       &fspb.Address{ServiceName: "TestService"},
   390  			Destination:  &fspb.Address{ServiceName: "TestService", ClientId: id.Bytes()},
   391  			MessageType:  "DummyType",
   392  			CreationTime: db.NowProto(),
   393  			Data: &anypb.Any{
   394  				TypeUrl: "test data proto urn 4",
   395  				Value:   []byte("Test data proto 4")},
   396  		},
   397  	}
   398  
   399  	for _, m := range msgs {
   400  		if _, err := as.InsertMessage(ctx, proto.Clone(m).(*fspb.Message)); err != nil {
   401  			t.Fatalf("InsertMessage returned error: %v", err)
   402  		}
   403  	}
   404  
   405  	// Get the message from the pending messages count
   406  
   407  	t.Run("GetPendingMessageCount", func(t *testing.T) {
   408  		greq := &spb.GetPendingMessageCountRequest{
   409  			ClientIds: [][]byte{id.Bytes()},
   410  		}
   411  		gresp, err := as.GetPendingMessageCount(ctx, greq)
   412  		if err != nil {
   413  			t.Fatalf("GetPendingMessageCount returned error: %v", err)
   414  		}
   415  		if gresp.Count != uint64(len(msgs)) {
   416  			t.Fatalf("Bad resul.t Expected: %v. Got: %v.", len(msgs), gresp.Count)
   417  		}
   418  	})
   419  
   420  	// Get the message from the pending messages list, with data
   421  
   422  	t.Run("GetPendingMessages/WantData=true", func(t *testing.T) {
   423  		greq := &spb.GetPendingMessagesRequest{
   424  			ClientIds: [][]byte{id.Bytes()},
   425  			WantData:  true,
   426  		}
   427  		gresp, err := as.GetPendingMessages(ctx, greq)
   428  		if err != nil {
   429  			t.Fatalf("GetPendingMessages returned error: %v", err)
   430  		}
   431  		if len(gresp.Messages) != len(msgs) {
   432  			t.Fatalf("Bad size of returned messages. Expected: %v. Got: %v.", len(msgs), len(gresp.Messages))
   433  		}
   434  		for i, msg := range msgs {
   435  			if !proto.Equal(gresp.Messages[i], msg) {
   436  				t.Fatalf("Got bad message. Expected: [%v]. Got: [%v].", msg, gresp.Messages[i])
   437  			}
   438  		}
   439  	})
   440  
   441  	// Get the message from the pending messages list, with data, limit and offset
   442  
   443  	t.Run("GetPendingMessages/WantData=true/limit/offset", func(t *testing.T) {
   444  		greq := &spb.GetPendingMessagesRequest{
   445  			ClientIds: [][]byte{id.Bytes()},
   446  			WantData:  true,
   447  			Offset:    1,
   448  			Limit:     2,
   449  		}
   450  		gresp, err := as.GetPendingMessages(ctx, greq)
   451  		if err != nil {
   452  			t.Fatalf("GetPendingMessages returned error: %v", err)
   453  		}
   454  		if len(gresp.Messages) != 2 {
   455  			t.Fatalf("Bad size of returned messages. Expected: %v. Got: %v.", len(msgs), len(gresp.Messages))
   456  		}
   457  		for i := range 2 {
   458  			if !proto.Equal(gresp.Messages[i], msgs[1+i]) {
   459  				t.Fatalf("Got bad message. Expected: [%v]. Got: [%v].", msgs[1+i], gresp.Messages[i])
   460  			}
   461  		}
   462  	})
   463  
   464  	// Get the message from the pending messages list, without data
   465  
   466  	t.Run("GetPendingMessages/WantData=false", func(t *testing.T) {
   467  		greq := &spb.GetPendingMessagesRequest{
   468  			ClientIds: [][]byte{id.Bytes()},
   469  			WantData:  false,
   470  		}
   471  		gresp, err := as.GetPendingMessages(ctx, greq)
   472  
   473  		if err != nil {
   474  			t.Fatalf("GetPendingMessages returned error: %v", err)
   475  		}
   476  		if len(gresp.Messages) != len(msgs) {
   477  			t.Fatalf("Bad size of returned messages. Expected: %v. Got: %v.", len(msgs), len(gresp.Messages))
   478  		}
   479  		for i, msg := range msgs {
   480  			expectedMessage := proto.Clone(msg).(*fspb.Message)
   481  			expectedMessage.Data = nil
   482  			if !proto.Equal(gresp.Messages[i], expectedMessage) {
   483  				t.Fatalf("Got bad message. Expected: [%v]. Got: [%v].", expectedMessage, gresp.Messages[i])
   484  			}
   485  		}
   486  	})
   487  
   488  	// Delete the message from the pending messages list.
   489  
   490  	t.Run("DeletePendingMessages", func(t *testing.T) {
   491  		req := &spb.DeletePendingMessagesRequest{
   492  			ClientIds: [][]byte{id.Bytes()},
   493  		}
   494  		if _, err := as.DeletePendingMessages(ctx, req); err != nil {
   495  			t.Fatalf("DeletePendingMessages returned error: %v", err)
   496  		}
   497  
   498  		// ClientMessagesForProcessing should return nothing, since the message is
   499  		// supposed to be deleted from the pending mesages list.
   500  		msgs, err := ts.DS.ClientMessagesForProcessing(ctx, id, 10, nil)
   501  		if err != nil {
   502  			t.Fatalf("ClientMessagesForProcessing(%v) returned error: %v", id, err)
   503  		}
   504  		if len(msgs) != 0 {
   505  			t.Errorf("ClientMessagesForProcessing(%v) was expected to return 0 messages, got: %v", id, msgs)
   506  		}
   507  	})
   508  }
   509  
   510  type mockStreamClientContactsServer struct {
   511  	grpc.ServerStream
   512  	responses []*spb.StreamClientContactsResponse
   513  }
   514  
   515  func (m *mockStreamClientContactsServer) Send(response *spb.StreamClientContactsResponse) error {
   516  	m.responses = append(m.responses, response)
   517  	return nil
   518  }
   519  
   520  func (m *mockStreamClientContactsServer) Context() context.Context {
   521  	return context.Background()
   522  }
   523  
   524  func TestClientContacts(t *testing.T) {
   525  	ctx := context.Background()
   526  
   527  	ts := testserver.Make(t, "server", "TestPendingMessages", nil)
   528  	defer ts.S.Stop()
   529  
   530  	as := admin.NewServer(ts.DS, nil)
   531  
   532  	id := []byte{0, 0, 0, 0, 0, 0, 0, 0}
   533  	cid, err := common.BytesToClientID(id)
   534  	if err != nil {
   535  		t.Fatalf("Failed to convert ID: %v.", err)
   536  	}
   537  
   538  	err = ts.DS.AddClient(ctx, cid, &db.ClientData{})
   539  	if err != nil {
   540  		t.Fatalf("Failed to add client: %v.", err)
   541  	}
   542  
   543  	for _, data := range []db.ContactData{
   544  		{
   545  			ClientID: cid,
   546  			Addr:     "a1",
   547  		},
   548  		{
   549  			ClientID: cid,
   550  			Addr:     "a2",
   551  		},
   552  	} {
   553  		_, err = ts.DS.RecordClientContact(ctx, data)
   554  		if err != nil {
   555  			t.Fatalf("Failed to record client contact: %v.", err)
   556  		}
   557  	}
   558  
   559  	t.Run("StreamClientContacts", func(t *testing.T) {
   560  		req := &spb.StreamClientContactsRequest{
   561  			ClientId: id,
   562  		}
   563  		var m mockStreamClientContactsServer
   564  		err := as.StreamClientContacts(req, &m)
   565  		if err != nil {
   566  			t.Fatalf("StreamClientContacts failed: %v.", err)
   567  		}
   568  
   569  		var addrs []string
   570  		for _, response := range m.responses {
   571  			addrs = append(addrs, response.Contact.ObservedAddress)
   572  		}
   573  		sort.Strings(addrs)
   574  
   575  		expected := []string{"a1", "a2"}
   576  
   577  		if !reflect.DeepEqual(addrs, expected) {
   578  			t.Errorf("StreamClientContacts error: want [%v], got [%v].", expected, addrs)
   579  		}
   580  	})
   581  
   582  }