github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/inttesting/frr/frr_test.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package frr
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"encoding/hex"
    21  	"fmt"
    22  	"io"
    23  	"net"
    24  	"reflect"
    25  	"sort"
    26  	"testing"
    27  	"time"
    28  
    29  	log "github.com/golang/glog"
    30  	"google.golang.org/grpc"
    31  	"google.golang.org/protobuf/proto"
    32  
    33  	cservice "github.com/google/fleetspeak/fleetspeak/src/client/service"
    34  	"github.com/google/fleetspeak/fleetspeak/src/common"
    35  	"github.com/google/fleetspeak/fleetspeak/src/common/anypbtest"
    36  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    37  	"github.com/google/fleetspeak/fleetspeak/src/server/service"
    38  
    39  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    40  	fgrpc "github.com/google/fleetspeak/fleetspeak/src/inttesting/frr/proto/fleetspeak_frr"
    41  	fpb "github.com/google/fleetspeak/fleetspeak/src/inttesting/frr/proto/fleetspeak_frr"
    42  	srpb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server"
    43  )
    44  
// fakeClientServiceContext is a minimal stand-in for cservice.Context used to
// drive the client-side FRR service in these tests. Only Send, GetLocalInfo and
// GetFileIfModified are overridden. The tests construct it as
// &fakeClientServiceContext{o: out}, which leaves the embedded Context nil, so
// calling any other Context method on it would panic.
type fakeClientServiceContext struct {
	cservice.Context
	// o receives every message the service hands to Send.
	o chan *fspb.Message
}
    49  
    50  func (f fakeClientServiceContext) Send(ctx context.Context, m cservice.AckMessage) error {
    51  	select {
    52  	case f.o <- m.M:
    53  		if m.Ack != nil {
    54  			m.Ack()
    55  		}
    56  		return nil
    57  	case <-ctx.Done():
    58  		return ctx.Err()
    59  	}
    60  }
    61  
    62  func (f fakeClientServiceContext) GetLocalInfo() *cservice.LocalInfo {
    63  	id, err := common.StringToClientID("0000000000000000")
    64  	if err != nil {
    65  		log.Fatal(err)
    66  	}
    67  	return &cservice.LocalInfo{
    68  		ClientID: id,
    69  	}
    70  }
    71  
    72  func (f fakeClientServiceContext) GetFileIfModified(ctx context.Context, name string, modSince time.Time) (io.ReadCloser, time.Time, error) {
    73  	if name == "TestFile" {
    74  		return io.NopCloser(bytes.NewReader([]byte("Test file data."))), db.Now(), nil
    75  	}
    76  	return nil, time.Time{}, fmt.Errorf("File not found: %v", name)
    77  }
    78  
    79  func clearChannel(c chan *fspb.Message) {
    80  	for {
    81  		select {
    82  		case <-c:
    83  			continue
    84  		default:
    85  			return
    86  		}
    87  	}
    88  }
    89  
    90  func TestClientService(t *testing.T) {
    91  	cs, err := ClientServiceFactory(nil)
    92  	defer cs.Stop()
    93  	if err != nil {
    94  		t.Fatalf("Unable to create service: %v", err)
    95  	}
    96  
    97  	out := make(chan *fspb.Message, 10)
    98  	if err := cs.Start(&fakeClientServiceContext{o: out}); err != nil {
    99  		t.Fatalf("Unable to start service: %v", err)
   100  	}
   101  
   102  	for _, tc := range []struct {
   103  		rd        *fpb.TrafficRequestData
   104  		wantCount int
   105  		wantSize  int
   106  	}{
   107  		{
   108  			rd:        &fpb.TrafficRequestData{RequestId: 0},
   109  			wantCount: 1,
   110  			wantSize:  1024,
   111  		},
   112  		{
   113  			rd:        &fpb.TrafficRequestData{RequestId: 1, NumMessages: 5},
   114  			wantCount: 5,
   115  			wantSize:  1024,
   116  		},
   117  	} {
   118  		t.Run(fmt.Sprintf("TrafficRequestData%v", tc.wantCount), func(t *testing.T) {
   119  			if err := cs.ProcessMessage(context.Background(), &fspb.Message{
   120  				MessageType: "TrafficRequest",
   121  				Data:        anypbtest.New(t, tc.rd),
   122  			}); err != nil {
   123  				t.Fatalf("unable to process message [%v]: %v", tc.rd.String(), err)
   124  			}
   125  			for i := range tc.wantCount {
   126  				res := <-out
   127  				d := &fpb.TrafficResponseData{}
   128  				if err := res.Data.UnmarshalTo(d); err != nil {
   129  					t.Errorf("unable to unmarshal data")
   130  					clearChannel(out)
   131  					break
   132  				}
   133  				if d.RequestId != tc.rd.RequestId {
   134  					t.Errorf("expected response for request %v got response for request %v", tc.rd.RequestId, d.RequestId)
   135  					clearChannel(out)
   136  					break
   137  				}
   138  				if len(d.Data) != tc.wantSize {
   139  					t.Errorf("wanted data size of %v got size of %v", tc.wantSize, len(d.Data))
   140  					clearChannel(out)
   141  					break
   142  				}
   143  				// Last message should have the end marker.
   144  				wantFin := i == tc.wantCount-1
   145  				if d.Fin != wantFin {
   146  					t.Errorf("wanted Fin: %v got Fin: %v", wantFin, d.Fin)
   147  					clearChannel(out)
   148  					break
   149  				}
   150  			}
   151  		})
   152  	}
   153  
   154  	if err := cs.ProcessMessage(context.Background(), &fspb.Message{
   155  		MessageType: "FileRequest",
   156  		Data: anypbtest.New(t, &fpb.FileRequestData{
   157  			MasterId: 42,
   158  			Name:     "TestFile",
   159  		}),
   160  	}); err != nil {
   161  		t.Fatalf("Unable to process FileRequest: %v", err)
   162  	}
   163  	res := <-out
   164  	got := &fpb.FileResponseData{}
   165  	if err := res.Data.UnmarshalTo(got); err != nil {
   166  		t.Errorf("Unable unmarshal FileResponse: %v", err)
   167  	}
   168  	want := &fpb.FileResponseData{
   169  		MasterId: 42,
   170  		Name:     "TestFile",
   171  		Size:     15,
   172  	}
   173  	if !proto.Equal(want, got) {
   174  		t.Errorf("Unexpected FileResponse, want: %v, got %v", want, got)
   175  	}
   176  }
   177  
   178  func TestClientServiceEarlyShutdown(t *testing.T) {
   179  	cs, err := ClientServiceFactory(nil)
   180  	if err != nil {
   181  		t.Fatalf("Unable to create service: %v", err)
   182  	}
   183  	// no buffering so Send will block.
   184  	out := make(chan *fspb.Message)
   185  	if err := cs.Start(&fakeClientServiceContext{o: out}); err != nil {
   186  		t.Fatalf("Unable to start service: %v", err)
   187  	}
   188  
   189  	if err := cs.ProcessMessage(context.Background(), &fspb.Message{
   190  		MessageType: "TrafficRequest",
   191  		Data: anypbtest.New(t, &fpb.TrafficRequestData{
   192  			RequestId:   1,
   193  			NumMessages: 5,
   194  		}),
   195  	}); err != nil {
   196  		t.Error("unable to process message")
   197  	}
   198  
   199  	// Service should be blocked trying to put a message into out. However, Stop should
   200  	// shut everything down and return quickly.
   201  	cs.Stop()
   202  }
   203  
// fakeMasterServer is an in-process implementation of the FRR Master gRPC
// service; TestServerService registers it on a local grpc.Server. Traffic and
// file responses it records are forwarded to the test over channels.
// CompletedRequests and CreateHunt return empty responses.
type fakeMasterServer struct {
	// Embedded so that RPCs not overridden below fall back to the generated
	// default implementation (presumably an Unimplemented error — confirm
	// against the generated code).
	fgrpc.UnimplementedMasterServer

	// rec receives every MessageInfo passed to RecordTrafficResponse.
	rec     chan<- *fpb.MessageInfo
	// recFile receives every FileResponseInfo passed to RecordFileResponse.
	// It is left nil in TestServerService.
	recFile chan<- *fpb.FileResponseInfo
}
   210  
   211  func (s fakeMasterServer) RecordTrafficResponse(ctx context.Context, i *fpb.MessageInfo) (*fspb.EmptyMessage, error) {
   212  	log.Infof("recording m: %v", i)
   213  	s.rec <- i
   214  	return &fspb.EmptyMessage{}, nil
   215  }
   216  
   217  func (s fakeMasterServer) RecordFileResponse(ctx context.Context, i *fpb.FileResponseInfo) (*fspb.EmptyMessage, error) {
   218  	log.Infof("recording m: %v", i)
   219  	s.recFile <- i
   220  	return &fspb.EmptyMessage{}, nil
   221  }
   222  
   223  func (s fakeMasterServer) CompletedRequests(ctx context.Context, c *fpb.CompletedRequestsRequest) (*fpb.CompletedRequestsResponse, error) {
   224  	return &fpb.CompletedRequestsResponse{}, nil
   225  }
   226  
   227  func (s fakeMasterServer) CreateHunt(ctx context.Context, hr *fpb.CreateHuntRequest) (*fpb.CreateHuntResponse, error) {
   228  	return &fpb.CreateHuntResponse{}, nil
   229  }
   230  
// fakeServiceContext stands in for service.Context when starting the
// server-side FRR service in TestServerService. That test constructs it as
// &fakeServiceContext{}, leaving the embedded Context nil, so any Context
// method the service calls on it would panic. Presumably the code paths
// exercised by the test never call into it — confirm if that test changes.
type fakeServiceContext struct {
	service.Context
}
   234  
   235  func TestServerService(t *testing.T) {
   236  	// A channel to collect messages received by the fake master server.
   237  	rec := make(chan *fpb.MessageInfo, 10)
   238  
   239  	// Create a fake master server as a real local grpc service.
   240  	s := grpc.NewServer()
   241  	fgrpc.RegisterMasterServer(s, fakeMasterServer{rec: rec})
   242  	ad, err := net.ResolveTCPAddr("tcp", "localhost:0")
   243  	if err != nil {
   244  		t.Fatal(err)
   245  	}
   246  	tl, err := net.ListenTCP("tcp", ad)
   247  	if err != nil {
   248  		t.Fatal(err)
   249  	}
   250  	defer s.Stop()
   251  	go func() {
   252  		log.Infof("Finished with: %v", s.Serve(tl))
   253  	}()
   254  
   255  	// Directly create and start a FRR service.
   256  	se, err := ServerServiceFactory(&srpb.ServiceConfig{
   257  		Config: anypbtest.New(t, &fpb.Config{
   258  			MasterServer: tl.Addr().String(),
   259  		}),
   260  	})
   261  	if err != nil {
   262  		t.Fatal(err)
   263  	}
   264  	if err := se.Start(&fakeServiceContext{}); err != nil {
   265  		t.Fatal(err)
   266  	}
   267  	defer se.Stop()
   268  
   269  	// Build a message.
   270  	id, err := common.BytesToClientID([]byte{0, 0, 0, 0, 0, 0, 0, 1})
   271  	if err != nil {
   272  		t.Fatal(err)
   273  	}
   274  	rd := &fpb.TrafficResponseData{
   275  		RequestId:     42,
   276  		ResponseIndex: 24,
   277  		Data:          []byte("asdf"),
   278  		Fin:           true,
   279  	}
   280  	msg := &fspb.Message{
   281  		Source: &fspb.Address{
   282  			ClientId:    id.Bytes(),
   283  			ServiceName: "FRR",
   284  		},
   285  		Destination: &fspb.Address{
   286  			ServiceName: "FRR",
   287  		},
   288  		MessageType: "TrafficResponse",
   289  		Data:        anypbtest.New(t, rd),
   290  	}
   291  
   292  	// Process it.
   293  	if err := se.ProcessMessage(context.Background(), msg); err != nil {
   294  		t.Fatalf("se.ProcessMessage(%+v) = %v", msg, err)
   295  	}
   296  
   297  	// Check that it was received.
   298  	mi := <-rec
   299  	if !bytes.Equal(mi.ClientId, id.Bytes()) {
   300  		t.Errorf("Unexpected client id, got [%v], want [%v]", hex.EncodeToString(mi.ClientId), hex.EncodeToString(id.Bytes()))
   301  	}
   302  
   303  	rd.Data = nil
   304  	if !proto.Equal(mi.Data, rd) {
   305  		t.Errorf("Unexpected TrafficRequestData, got [%v], want [%v]", mi.Data, rd.String())
   306  	}
   307  }
   308  
   309  type Int64Slice []int64
   310  
   311  func (p Int64Slice) Len() int           { return len(p) }
   312  func (p Int64Slice) Less(i, j int) bool { return p[i] < p[j] }
   313  func (p Int64Slice) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }
   314  
   315  func TestMasterServer(t *testing.T) {
   316  	ctx := context.Background()
   317  	ms := NewMasterServer(nil)
   318  	ch := ms.WatchCompleted()
   319  
   320  	id, err := common.BytesToClientID([]byte{0, 0, 0, 0, 0, 0, 0, 1})
   321  	if err != nil {
   322  		t.Fatal(err)
   323  	}
   324  
   325  	for _, tc := range []struct {
   326  		mi        *fpb.MessageInfo
   327  		completed bool
   328  	}{
   329  		{
   330  			mi: &fpb.MessageInfo{
   331  				ClientId: id.Bytes(),
   332  				Data: &fpb.TrafficResponseData{
   333  					MasterId:      ms.masterID,
   334  					RequestId:     0,
   335  					ResponseIndex: 0,
   336  					Fin:           true,
   337  				},
   338  			},
   339  			completed: true,
   340  		},
   341  		{
   342  			mi: &fpb.MessageInfo{
   343  				ClientId: id.Bytes(),
   344  				Data: &fpb.TrafficResponseData{
   345  					MasterId:      ms.masterID,
   346  					RequestId:     1,
   347  					ResponseIndex: 0,
   348  					Fin:           false,
   349  				},
   350  			},
   351  			completed: false,
   352  		},
   353  		{
   354  			mi: &fpb.MessageInfo{
   355  				ClientId: id.Bytes(),
   356  				Data: &fpb.TrafficResponseData{
   357  					MasterId:      ms.masterID,
   358  					RequestId:     1,
   359  					ResponseIndex: 1,
   360  					Fin:           true,
   361  				},
   362  			},
   363  			completed: true,
   364  		},
   365  		{
   366  			mi: &fpb.MessageInfo{
   367  				ClientId: id.Bytes(),
   368  				Data: &fpb.TrafficResponseData{
   369  					MasterId:      ms.masterID,
   370  					RequestId:     2,
   371  					ResponseIndex: 1,
   372  					Fin:           true,
   373  				},
   374  			},
   375  			completed: false,
   376  		},
   377  		{
   378  			mi: &fpb.MessageInfo{
   379  				ClientId: id.Bytes(),
   380  				Data: &fpb.TrafficResponseData{
   381  					MasterId:      ms.masterID,
   382  					RequestId:     2,
   383  					ResponseIndex: 0,
   384  					Fin:           false,
   385  				},
   386  			},
   387  			completed: true,
   388  		},
   389  		{
   390  			mi: &fpb.MessageInfo{
   391  				ClientId: id.Bytes(),
   392  				Data: &fpb.TrafficResponseData{
   393  					MasterId:      ms.masterID,
   394  					RequestId:     3,
   395  					ResponseIndex: 3,
   396  					Fin:           true,
   397  				},
   398  			},
   399  			completed: false,
   400  		},
   401  		{
   402  			mi: &fpb.MessageInfo{
   403  				ClientId: id.Bytes(),
   404  				Data: &fpb.TrafficResponseData{
   405  					MasterId:      ms.masterID,
   406  					RequestId:     3,
   407  					ResponseIndex: 2,
   408  					Fin:           false,
   409  				},
   410  			},
   411  			completed: false,
   412  		},
   413  		{
   414  			mi: &fpb.MessageInfo{
   415  				ClientId: id.Bytes(),
   416  				Data: &fpb.TrafficResponseData{
   417  					MasterId:      ms.masterID,
   418  					RequestId:     3,
   419  					ResponseIndex: 2,
   420  					Fin:           false,
   421  				},
   422  			},
   423  			completed: false,
   424  		},
   425  		{
   426  			mi: &fpb.MessageInfo{
   427  				ClientId: id.Bytes(),
   428  				Data: &fpb.TrafficResponseData{
   429  					MasterId:      ms.masterID,
   430  					RequestId:     3,
   431  					ResponseIndex: 1,
   432  					Fin:           false,
   433  				},
   434  			},
   435  			completed: false,
   436  		},
   437  		{
   438  			mi: &fpb.MessageInfo{
   439  				ClientId: id.Bytes(),
   440  				Data: &fpb.TrafficResponseData{
   441  					MasterId:      ms.masterID,
   442  					RequestId:     3,
   443  					ResponseIndex: 0,
   444  					Fin:           false,
   445  				},
   446  			},
   447  			completed: true,
   448  		},
   449  		{
   450  			mi: &fpb.MessageInfo{
   451  				ClientId: id.Bytes(),
   452  				Data: &fpb.TrafficResponseData{
   453  					MasterId:      ms.masterID,
   454  					RequestId:     4,
   455  					ResponseIndex: 4,
   456  					Fin:           true,
   457  				},
   458  			},
   459  			completed: false,
   460  		},
   461  	} {
   462  		if _, err := ms.RecordTrafficResponse(ctx, tc.mi); err != nil {
   463  			t.Errorf("Unexpected error recording message [%v]: %v", tc.mi.String(), err)
   464  		}
   465  		var gotComp bool
   466  		select {
   467  		case <-ch:
   468  			gotComp = true
   469  		default:
   470  			gotComp = false
   471  		}
   472  		if tc.completed != gotComp {
   473  			t.Errorf("For request %v, index %v, got completed=%v want %v", tc.mi.Data.RequestId, tc.mi.Data.ResponseIndex, gotComp, tc.completed)
   474  		}
   475  	}
   476  
   477  	c := ms.GetCompletedRequests(id)
   478  	sort.Sort(Int64Slice(c))
   479  	wantCompleted := []int64{0, 1, 2, 3}
   480  	if !reflect.DeepEqual(c, wantCompleted) {
   481  		t.Errorf("Unexpected completed requests list: got %v, want %v", c, wantCompleted)
   482  	}
   483  
   484  	c = ms.AllRequests(id)
   485  	sort.Sort(Int64Slice(c))
   486  	wantAll := []int64{0, 1, 2, 3, 4}
   487  	if !reflect.DeepEqual(c, wantAll) {
   488  		t.Errorf("Unexpected all requests list: got %v, want %v", c, wantAll)
   489  	}
   490  
   491  }