google.golang.org/grpc@v1.72.2/interop/stress/client/main.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // client starts an interop client to do stress test and a metrics server to report qps.
    20  package main
    21  
    22  import (
    23  	"context"
    24  	"flag"
    25  	"fmt"
    26  	rand "math/rand/v2"
    27  	"net"
    28  	"os"
    29  	"strconv"
    30  	"strings"
    31  	"sync"
    32  	"sync/atomic"
    33  	"time"
    34  
    35  	"google.golang.org/grpc"
    36  	"google.golang.org/grpc/codes"
    37  	"google.golang.org/grpc/credentials"
    38  	"google.golang.org/grpc/credentials/google"
    39  	"google.golang.org/grpc/credentials/insecure"
    40  	"google.golang.org/grpc/grpclog"
    41  	"google.golang.org/grpc/interop"
    42  	"google.golang.org/grpc/resolver"
    43  	"google.golang.org/grpc/status"
    44  	"google.golang.org/grpc/testdata"
    45  
    46  	_ "google.golang.org/grpc/xds/googledirectpath" // Register xDS resolver required for c2p directpath.
    47  
    48  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    49  	metricspb "google.golang.org/grpc/interop/stress/grpc_testing"
    50  )
    51  
    52  const (
    53  	googleDefaultCredsName = "google_default_credentials"
    54  	computeEngineCredsName = "compute_engine_channel_creds"
    55  )
    56  
    57  var (
    58  	serverAddresses       = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
    59  	testCases             = flag.String("test_cases", "", "a list of test cases along with the relative weights")
    60  	testDurationSecs      = flag.Int("test_duration_secs", -1, "test duration in seconds")
    61  	numChannelsPerServer  = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
    62  	numStubsPerChannel    = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
    63  	metricsPort           = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
    64  	useTLS                = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
    65  	testCA                = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
    66  	tlsServerName         = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
    67  	caFile                = flag.String("ca_file", "", "The file containing the CA root cert file")
    68  	customCredentialsType = flag.String("custom_credentials_type", "", "Custom credentials type to use")
    69  
    70  	totalNumCalls int64
    71  	logger        = grpclog.Component("stress")
    72  )
    73  
    74  // testCaseWithWeight contains the test case type and its weight.
    75  type testCaseWithWeight struct {
    76  	name   string
    77  	weight int
    78  }
    79  
    80  // parseTestCases converts test case string to a list of struct testCaseWithWeight.
    81  func parseTestCases(testCaseString string) []testCaseWithWeight {
    82  	testCaseStrings := strings.Split(testCaseString, ",")
    83  	testCases := make([]testCaseWithWeight, len(testCaseStrings))
    84  	for i, str := range testCaseStrings {
    85  		testCaseNameAndWeight := strings.Split(str, ":")
    86  		if len(testCaseNameAndWeight) != 2 {
    87  			panic(fmt.Sprintf("invalid test case with weight: %s", str))
    88  		}
    89  		// Check if test case is supported.
    90  		testCaseName := strings.ToLower(testCaseNameAndWeight[0])
    91  		switch testCaseName {
    92  		case
    93  			"empty_unary",
    94  			"large_unary",
    95  			"client_streaming",
    96  			"server_streaming",
    97  			"ping_pong",
    98  			"empty_stream",
    99  			"timeout_on_sleeping_server",
   100  			"cancel_after_begin",
   101  			"cancel_after_first_response",
   102  			"status_code_and_message",
   103  			"custom_metadata":
   104  		default:
   105  			panic(fmt.Sprintf("unknown test type: %s", testCaseNameAndWeight[0]))
   106  		}
   107  		testCases[i].name = testCaseName
   108  		w, err := strconv.Atoi(testCaseNameAndWeight[1])
   109  		if err != nil {
   110  			panic(fmt.Sprintf("%v", err))
   111  		}
   112  		testCases[i].weight = w
   113  	}
   114  	return testCases
   115  }
   116  
   117  // weightedRandomTestSelector defines a weighted random selector for test case types.
   118  type weightedRandomTestSelector struct {
   119  	tests       []testCaseWithWeight
   120  	totalWeight int
   121  }
   122  
   123  // newWeightedRandomTestSelector constructs a weightedRandomTestSelector with the given list of testCaseWithWeight.
   124  func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector {
   125  	var totalWeight int
   126  	for _, t := range tests {
   127  		totalWeight += t.weight
   128  	}
   129  	return &weightedRandomTestSelector{tests, totalWeight}
   130  }
   131  
   132  func (selector weightedRandomTestSelector) getNextTest() string {
   133  	random := rand.IntN(selector.totalWeight)
   134  	var weightSofar int
   135  	for _, test := range selector.tests {
   136  		weightSofar += test.weight
   137  		if random < weightSofar {
   138  			return test.name
   139  		}
   140  	}
   141  	panic("no test case selected by weightedRandomTestSelector")
   142  }
   143  
   144  // gauge stores the qps of one interop client (one stub).
   145  type gauge struct {
   146  	mutex sync.RWMutex
   147  	val   int64
   148  }
   149  
   150  func (g *gauge) set(v int64) {
   151  	g.mutex.Lock()
   152  	defer g.mutex.Unlock()
   153  	g.val = v
   154  }
   155  
   156  func (g *gauge) get() int64 {
   157  	g.mutex.RLock()
   158  	defer g.mutex.RUnlock()
   159  	return g.val
   160  }
   161  
   162  // server implements metrics server functions.
   163  type server struct {
   164  	metricspb.UnimplementedMetricsServiceServer
   165  	mutex sync.RWMutex
   166  	// gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge.
   167  	gauges map[string]*gauge
   168  }
   169  
   170  // newMetricsServer returns a new metrics server.
   171  func newMetricsServer() *server {
   172  	return &server{gauges: make(map[string]*gauge)}
   173  }
   174  
   175  // GetAllGauges returns all gauges.
   176  func (s *server) GetAllGauges(_ *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error {
   177  	s.mutex.RLock()
   178  	defer s.mutex.RUnlock()
   179  
   180  	for name, gauge := range s.gauges {
   181  		if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
   182  			return err
   183  		}
   184  	}
   185  	return nil
   186  }
   187  
   188  // GetGauge returns the gauge for the given name.
   189  func (s *server) GetGauge(_ context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) {
   190  	s.mutex.RLock()
   191  	defer s.mutex.RUnlock()
   192  
   193  	if g, ok := s.gauges[in.Name]; ok {
   194  		return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil
   195  	}
   196  	return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
   197  }
   198  
   199  // createGauge creates a gauge using the given name in metrics server.
   200  func (s *server) createGauge(name string) *gauge {
   201  	s.mutex.Lock()
   202  	defer s.mutex.Unlock()
   203  
   204  	if _, ok := s.gauges[name]; ok {
   205  		// gauge already exists.
   206  		panic(fmt.Sprintf("gauge %s already exists", name))
   207  	}
   208  	var g gauge
   209  	s.gauges[name] = &g
   210  	return &g
   211  }
   212  
   213  func startServer(server *server, port int) {
   214  	lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
   215  	if err != nil {
   216  		logger.Fatalf("failed to listen: %v", err)
   217  	}
   218  
   219  	s := grpc.NewServer()
   220  	metricspb.RegisterMetricsServiceServer(s, server)
   221  	s.Serve(lis)
   222  }
   223  
   224  // performRPCs uses weightedRandomTestSelector to select test case and runs the tests.
   225  func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) {
   226  	client := testgrpc.NewTestServiceClient(conn)
   227  	var numCalls int64
   228  	ctx := context.Background()
   229  	startTime := time.Now()
   230  	for {
   231  		test := selector.getNextTest()
   232  		switch test {
   233  		case "empty_unary":
   234  			interop.DoEmptyUnaryCall(ctx, client)
   235  		case "large_unary":
   236  			interop.DoLargeUnaryCall(ctx, client)
   237  		case "client_streaming":
   238  			interop.DoClientStreaming(ctx, client)
   239  		case "server_streaming":
   240  			interop.DoServerStreaming(ctx, client)
   241  		case "ping_pong":
   242  			interop.DoPingPong(ctx, client)
   243  		case "empty_stream":
   244  			interop.DoEmptyStream(ctx, client)
   245  		case "timeout_on_sleeping_server":
   246  			interop.DoTimeoutOnSleepingServer(ctx, client)
   247  		case "cancel_after_begin":
   248  			interop.DoCancelAfterBegin(ctx, client)
   249  		case "cancel_after_first_response":
   250  			interop.DoCancelAfterFirstResponse(ctx, client)
   251  		case "status_code_and_message":
   252  			interop.DoStatusCodeAndMessage(ctx, client)
   253  		case "custom_metadata":
   254  			interop.DoCustomMetadata(ctx, client)
   255  		}
   256  		numCalls++
   257  		defer func() { atomic.AddInt64(&totalNumCalls, numCalls) }()
   258  		gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds()))
   259  
   260  		select {
   261  		case <-stop:
   262  			return
   263  		default:
   264  		}
   265  	}
   266  }
   267  
   268  func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
   269  	logger.Infof("server_addresses: %s", *serverAddresses)
   270  	logger.Infof("test_cases: %s", *testCases)
   271  	logger.Infof("test_duration_secs: %d", *testDurationSecs)
   272  	logger.Infof("num_channels_per_server: %d", *numChannelsPerServer)
   273  	logger.Infof("num_stubs_per_channel: %d", *numStubsPerChannel)
   274  	logger.Infof("metrics_port: %d", *metricsPort)
   275  	logger.Infof("use_tls: %t", *useTLS)
   276  	logger.Infof("use_test_ca: %t", *testCA)
   277  	logger.Infof("server_host_override: %s", *tlsServerName)
   278  	logger.Infof("custom_credentials_type: %s", *customCredentialsType)
   279  
   280  	logger.Infoln("addresses:")
   281  	for i, addr := range addresses {
   282  		logger.Infof("%d. %s\n", i+1, addr)
   283  	}
   284  	logger.Infoln("tests:")
   285  	for i, test := range tests {
   286  		logger.Infof("%d. %v\n", i+1, test)
   287  	}
   288  }
   289  
   290  func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
   291  	var opts []grpc.DialOption
   292  	if *customCredentialsType != "" {
   293  		if *customCredentialsType == googleDefaultCredsName {
   294  			opts = append(opts, grpc.WithCredentialsBundle(google.NewDefaultCredentials()))
   295  		} else if *customCredentialsType == computeEngineCredsName {
   296  			opts = append(opts, grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()))
   297  		} else {
   298  			logger.Fatalf("Unknown custom credentials: %v", *customCredentialsType)
   299  		}
   300  	} else if useTLS {
   301  		var sn string
   302  		if tlsServerName != "" {
   303  			sn = tlsServerName
   304  		}
   305  		var creds credentials.TransportCredentials
   306  		if testCA {
   307  			var err error
   308  			if *caFile == "" {
   309  				*caFile = testdata.Path("x509/server_ca_cert.pem")
   310  			}
   311  			creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
   312  			if err != nil {
   313  				logger.Fatalf("Failed to create TLS credentials: %v", err)
   314  			}
   315  		} else {
   316  			creds = credentials.NewClientTLSFromCert(nil, sn)
   317  		}
   318  		opts = append(opts, grpc.WithTransportCredentials(creds))
   319  	} else {
   320  		opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
   321  	}
   322  	return grpc.NewClient(address, opts...)
   323  }
   324  
   325  func main() {
   326  	flag.Parse()
   327  	resolver.SetDefaultScheme("dns")
   328  	addresses := strings.Split(*serverAddresses, ",")
   329  	tests := parseTestCases(*testCases)
   330  	logParameterInfo(addresses, tests)
   331  	testSelector := newWeightedRandomTestSelector(tests)
   332  	metricsServer := newMetricsServer()
   333  
   334  	var wg sync.WaitGroup
   335  	wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel)
   336  	stop := make(chan bool)
   337  
   338  	for serverIndex, address := range addresses {
   339  		for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ {
   340  			conn, err := newConn(address, *useTLS, *testCA, *tlsServerName)
   341  			if err != nil {
   342  				logger.Fatalf("Fail to dial: %v", err)
   343  			}
   344  			defer conn.Close()
   345  			for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ {
   346  				name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1)
   347  				go func() {
   348  					defer wg.Done()
   349  					g := metricsServer.createGauge(name)
   350  					performRPCs(g, conn, testSelector, stop)
   351  				}()
   352  			}
   353  
   354  		}
   355  	}
   356  	go startServer(metricsServer, *metricsPort)
   357  	if *testDurationSecs > 0 {
   358  		time.Sleep(time.Duration(*testDurationSecs) * time.Second)
   359  		close(stop)
   360  	}
   361  	wg.Wait()
   362  	fmt.Fprintf(os.Stdout, "Total calls made: %v\n", totalNumCalls)
   363  	logger.Infof(" ===== ALL DONE ===== ")
   364  }