github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/client/comms_test.go (about)

     1  // Copyright 2024 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package client
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"io"
    21  	"sync/atomic"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/google/fleetspeak/fleetspeak/src/client/comms"
    26  	"github.com/google/fleetspeak/fleetspeak/src/client/config"
    27  	"github.com/google/fleetspeak/fleetspeak/src/client/service"
    28  	"github.com/google/fleetspeak/fleetspeak/src/client/stats"
    29  	"github.com/google/fleetspeak/fleetspeak/src/common"
    30  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    31  )
    32  
    33  type testStatsCollector struct {
    34  	stats.NoopCollector
    35  	created, processed atomic.Int32
    36  }
    37  
    38  // ContactDataCreated implements stats.CommsContextCollector
    39  func (c *testStatsCollector) ContactDataCreated(wcd *fspb.WrappedContactData, err error) {
    40  	c.created.Add(1)
    41  }
    42  
    43  // ContactDataCreated implements stats.CommsContextCollector
    44  func (c *testStatsCollector) ContactDataProcessed(cd *fspb.ContactData, streaming bool, err error) {
    45  	c.processed.Add(1)
    46  }
    47  
    48  type testCommunicator struct {
    49  	comms.Communicator
    50  	t *testing.T
    51  
    52  	cctx   comms.Context
    53  	cancel context.CancelFunc
    54  }
    55  
    56  func (c *testCommunicator) Setup(cctx comms.Context) error {
    57  	c.cctx = cctx
    58  	return nil
    59  }
    60  
    61  func (c *testCommunicator) Start() error {
    62  	ctx, cancel := context.WithCancel(context.Background())
    63  	c.cancel = cancel
    64  	go func() {
    65  		for {
    66  			select {
    67  			case <-ctx.Done():
    68  				return
    69  			case msgi := <-c.cctx.Outbox():
    70  				if err := c.processMessage(ctx, msgi); err != nil {
    71  					c.t.Errorf("Error while processing message: %v", err)
    72  					cancel()
    73  				}
    74  			}
    75  		}
    76  	}()
    77  	return nil
    78  }
    79  
    80  func (c *testCommunicator) Stop() {
    81  	c.cancel()
    82  }
    83  
    84  func (c *testCommunicator) GetFileIfModified(ctx context.Context, service, name string, modSince time.Time) (data io.ReadCloser, mod time.Time, err error) {
    85  	// GetFileIfModified gets called by the client to poll list of revoked certs,
    86  	// returning an error causes the client to continue without it.
    87  	return nil, time.Time{}, errors.New("file unavailable")
    88  }
    89  
    90  func (c *testCommunicator) processMessage(ctx context.Context, msgi comms.MessageInfo) error {
    91  	defer msgi.Ack()
    92  	if msgi.M.GetDestination().GetServiceName() != "RemoteService" {
    93  		return nil
    94  	}
    95  	// Simulate sending the message to the server
    96  	if _, _, err := c.cctx.MakeContactData([]*fspb.Message{msgi.M}, nil); err != nil {
    97  		return err
    98  	}
    99  	// A real communicator would send the returned WrappedContactData to the server now
   100  
   101  	// Simulate a response from the server service addressed to a client service.
   102  	mid, err := common.RandomMessageID()
   103  	if err != nil {
   104  		return err
   105  	}
   106  	cd := &fspb.ContactData{
   107  		Messages: []*fspb.Message{&fspb.Message{
   108  			MessageId: mid.Bytes(),
   109  			Source:    &fspb.Address{ServiceName: "RemoteService"},
   110  			Destination: &fspb.Address{
   111  				ClientId:    c.cctx.CurrentID().Bytes(),
   112  				ServiceName: "NOOPService",
   113  			},
   114  		}},
   115  	}
   116  	return c.cctx.ProcessContactData(ctx, cd, false)
   117  }
   118  
   119  func TestCommsContextReportsStats(t *testing.T) {
   120  	sc := &testStatsCollector{}
   121  
   122  	// Create client with the testCommunicator and a NOOP client service.
   123  	cl, err := New(
   124  		config.Configuration{
   125  			FixedServices: []*fspb.ClientServiceConfig{{
   126  				Name:    "NOOPService",
   127  				Factory: "NOOP",
   128  			}},
   129  		},
   130  		Components{
   131  			ServiceFactories: map[string]service.Factory{
   132  				"NOOP": service.NOOPFactory,
   133  			},
   134  			Communicator: &testCommunicator{t: t},
   135  			Stats:        sc,
   136  		})
   137  	if err != nil {
   138  		t.Fatalf("Unable to create client: %v", err)
   139  	}
   140  	defer cl.Stop()
   141  
   142  	mid, err := common.RandomMessageID()
   143  	if err != nil {
   144  		t.Fatalf("Unable to create message id: %v", err)
   145  	}
   146  
   147  	// Simulate a message coming from a client service addressed to a server service,
   148  	// This invokes the testCommunicator.
   149  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   150  	defer cancel()
   151  	msg := service.AckMessage{
   152  		M: &fspb.Message{
   153  			MessageId: mid.Bytes(),
   154  			Source: &fspb.Address{
   155  				ClientId:    cl.config.ClientID().Bytes(),
   156  				ServiceName: "NOOPService",
   157  			},
   158  			Destination: &fspb.Address{ServiceName: "RemoteService"},
   159  		},
   160  		Ack: cancel,
   161  	}
   162  	if err := cl.ProcessMessage(ctx, msg); err != nil {
   163  		t.Fatalf("Unable to process message: %v", err)
   164  	}
   165  
   166  	// Wait for the Ack callback which gets called when testCommunicator.processMessage returns
   167  	<-ctx.Done()
   168  
   169  	created := sc.created.Load()
   170  	if created != 1 {
   171  		t.Errorf("Got %d contact data created, want 1", created)
   172  	}
   173  	processed := sc.processed.Load()
   174  	if processed != 1 {
   175  		t.Errorf("Got %d contact data processed, want 1", processed)
   176  	}
   177  }