gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/stress/client/main.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // client starts an interop client to do stress test and a metrics server to report qps.
    20  package main
    21  
    22  import (
    23  	"context"
    24  	"flag"
    25  	"fmt"
    26  	"math/rand"
    27  	"net"
    28  	"strconv"
    29  	"strings"
    30  	"sync"
    31  	"time"
    32  
    33  	grpc "gitee.com/ks-custle/core-gm/grpc"
    34  	"gitee.com/ks-custle/core-gm/grpc/codes"
    35  	"gitee.com/ks-custle/core-gm/grpc/credentials"
    36  	"gitee.com/ks-custle/core-gm/grpc/grpclog"
    37  	"gitee.com/ks-custle/core-gm/grpc/interop"
    38  	"gitee.com/ks-custle/core-gm/grpc/status"
    39  	"gitee.com/ks-custle/core-gm/grpc/testdata"
    40  
    41  	testgrpc "gitee.com/ks-custle/core-gm/grpc/interop/grpc_testing"
    42  	metricspb "gitee.com/ks-custle/core-gm/grpc/stress/grpc_testing"
    43  )
    44  
    45  var (
    46  	serverAddresses      = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
    47  	testCases            = flag.String("test_cases", "", "a list of test cases along with the relative weights")
    48  	testDurationSecs     = flag.Int("test_duration_secs", -1, "test duration in seconds")
    49  	numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
    50  	numStubsPerChannel   = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
    51  	metricsPort          = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
    52  	useTLS               = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
    53  	testCA               = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
    54  	tlsServerName        = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
    55  	caFile               = flag.String("ca_file", "", "The file containing the CA root cert file")
    56  
    57  	logger = grpclog.Component("stress")
    58  )
    59  
    60  // testCaseWithWeight contains the test case type and its weight.
    61  type testCaseWithWeight struct {
    62  	name   string
    63  	weight int
    64  }
    65  
    66  // parseTestCases converts test case string to a list of struct testCaseWithWeight.
    67  func parseTestCases(testCaseString string) []testCaseWithWeight {
    68  	testCaseStrings := strings.Split(testCaseString, ",")
    69  	testCases := make([]testCaseWithWeight, len(testCaseStrings))
    70  	for i, str := range testCaseStrings {
    71  		testCase := strings.Split(str, ":")
    72  		if len(testCase) != 2 {
    73  			panic(fmt.Sprintf("invalid test case with weight: %s", str))
    74  		}
    75  		// Check if test case is supported.
    76  		switch testCase[0] {
    77  		case
    78  			"empty_unary",
    79  			"large_unary",
    80  			"client_streaming",
    81  			"server_streaming",
    82  			"ping_pong",
    83  			"empty_stream",
    84  			"timeout_on_sleeping_server",
    85  			"cancel_after_begin",
    86  			"cancel_after_first_response",
    87  			"status_code_and_message",
    88  			"custom_metadata":
    89  		default:
    90  			panic(fmt.Sprintf("unknown test type: %s", testCase[0]))
    91  		}
    92  		testCases[i].name = testCase[0]
    93  		w, err := strconv.Atoi(testCase[1])
    94  		if err != nil {
    95  			panic(fmt.Sprintf("%v", err))
    96  		}
    97  		testCases[i].weight = w
    98  	}
    99  	return testCases
   100  }
   101  
   102  // weightedRandomTestSelector defines a weighted random selector for test case types.
   103  type weightedRandomTestSelector struct {
   104  	tests       []testCaseWithWeight
   105  	totalWeight int
   106  }
   107  
   108  // newWeightedRandomTestSelector constructs a weightedRandomTestSelector with the given list of testCaseWithWeight.
   109  func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector {
   110  	var totalWeight int
   111  	for _, t := range tests {
   112  		totalWeight += t.weight
   113  	}
   114  	rand.Seed(time.Now().UnixNano())
   115  	return &weightedRandomTestSelector{tests, totalWeight}
   116  }
   117  
   118  func (selector weightedRandomTestSelector) getNextTest() string {
   119  	random := rand.Intn(selector.totalWeight)
   120  	var weightSofar int
   121  	for _, test := range selector.tests {
   122  		weightSofar += test.weight
   123  		if random < weightSofar {
   124  			return test.name
   125  		}
   126  	}
   127  	panic("no test case selected by weightedRandomTestSelector")
   128  }
   129  
   130  // gauge stores the qps of one interop client (one stub).
   131  type gauge struct {
   132  	mutex sync.RWMutex
   133  	val   int64
   134  }
   135  
   136  func (g *gauge) set(v int64) {
   137  	g.mutex.Lock()
   138  	defer g.mutex.Unlock()
   139  	g.val = v
   140  }
   141  
   142  func (g *gauge) get() int64 {
   143  	g.mutex.RLock()
   144  	defer g.mutex.RUnlock()
   145  	return g.val
   146  }
   147  
   148  // server implements metrics server functions.
   149  type server struct {
   150  	metricspb.UnimplementedMetricsServiceServer
   151  	mutex sync.RWMutex
   152  	// gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge.
   153  	gauges map[string]*gauge
   154  }
   155  
   156  // newMetricsServer returns a new metrics server.
   157  func newMetricsServer() *server {
   158  	return &server{gauges: make(map[string]*gauge)}
   159  }
   160  
   161  // GetAllGauges returns all gauges.
   162  func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error {
   163  	s.mutex.RLock()
   164  	defer s.mutex.RUnlock()
   165  
   166  	for name, gauge := range s.gauges {
   167  		if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
   168  			return err
   169  		}
   170  	}
   171  	return nil
   172  }
   173  
   174  // GetGauge returns the gauge for the given name.
   175  func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) {
   176  	s.mutex.RLock()
   177  	defer s.mutex.RUnlock()
   178  
   179  	if g, ok := s.gauges[in.Name]; ok {
   180  		return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil
   181  	}
   182  	return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
   183  }
   184  
   185  // createGauge creates a gauge using the given name in metrics server.
   186  func (s *server) createGauge(name string) *gauge {
   187  	s.mutex.Lock()
   188  	defer s.mutex.Unlock()
   189  
   190  	if _, ok := s.gauges[name]; ok {
   191  		// gauge already exists.
   192  		panic(fmt.Sprintf("gauge %s already exists", name))
   193  	}
   194  	var g gauge
   195  	s.gauges[name] = &g
   196  	return &g
   197  }
   198  
   199  func startServer(server *server, port int) {
   200  	lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
   201  	if err != nil {
   202  		logger.Fatalf("failed to listen: %v", err)
   203  	}
   204  
   205  	s := grpc.NewServer()
   206  	metricspb.RegisterMetricsServiceServer(s, server)
   207  	s.Serve(lis)
   208  
   209  }
   210  
   211  // performRPCs uses weightedRandomTestSelector to select test case and runs the tests.
   212  func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) {
   213  	client := testgrpc.NewTestServiceClient(conn)
   214  	var numCalls int64
   215  	startTime := time.Now()
   216  	for {
   217  		test := selector.getNextTest()
   218  		switch test {
   219  		case "empty_unary":
   220  			interop.DoEmptyUnaryCall(client, grpc.WaitForReady(true))
   221  		case "large_unary":
   222  			interop.DoLargeUnaryCall(client, grpc.WaitForReady(true))
   223  		case "client_streaming":
   224  			interop.DoClientStreaming(client, grpc.WaitForReady(true))
   225  		case "server_streaming":
   226  			interop.DoServerStreaming(client, grpc.WaitForReady(true))
   227  		case "ping_pong":
   228  			interop.DoPingPong(client, grpc.WaitForReady(true))
   229  		case "empty_stream":
   230  			interop.DoEmptyStream(client, grpc.WaitForReady(true))
   231  		case "timeout_on_sleeping_server":
   232  			interop.DoTimeoutOnSleepingServer(client, grpc.WaitForReady(true))
   233  		case "cancel_after_begin":
   234  			interop.DoCancelAfterBegin(client, grpc.WaitForReady(true))
   235  		case "cancel_after_first_response":
   236  			interop.DoCancelAfterFirstResponse(client, grpc.WaitForReady(true))
   237  		case "status_code_and_message":
   238  			interop.DoStatusCodeAndMessage(client, grpc.WaitForReady(true))
   239  		case "custom_metadata":
   240  			interop.DoCustomMetadata(client, grpc.WaitForReady(true))
   241  		}
   242  		numCalls++
   243  		gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds()))
   244  
   245  		select {
   246  		case <-stop:
   247  			return
   248  		default:
   249  		}
   250  	}
   251  }
   252  
   253  func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
   254  	logger.Infof("server_addresses: %s", *serverAddresses)
   255  	logger.Infof("test_cases: %s", *testCases)
   256  	logger.Infof("test_duration_secs: %d", *testDurationSecs)
   257  	logger.Infof("num_channels_per_server: %d", *numChannelsPerServer)
   258  	logger.Infof("num_stubs_per_channel: %d", *numStubsPerChannel)
   259  	logger.Infof("metrics_port: %d", *metricsPort)
   260  	logger.Infof("use_tls: %t", *useTLS)
   261  	logger.Infof("use_test_ca: %t", *testCA)
   262  	logger.Infof("server_host_override: %s", *tlsServerName)
   263  
   264  	logger.Infoln("addresses:")
   265  	for i, addr := range addresses {
   266  		logger.Infof("%d. %s\n", i+1, addr)
   267  	}
   268  	logger.Infoln("tests:")
   269  	for i, test := range tests {
   270  		logger.Infof("%d. %v\n", i+1, test)
   271  	}
   272  }
   273  
   274  func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
   275  	var opts []grpc.DialOption
   276  	if useTLS {
   277  		var sn string
   278  		if tlsServerName != "" {
   279  			sn = tlsServerName
   280  		}
   281  		var creds credentials.TransportCredentials
   282  		if testCA {
   283  			var err error
   284  			if *caFile == "" {
   285  				*caFile = testdata.Path("x509/server_ca_cert.pem")
   286  			}
   287  			creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
   288  			if err != nil {
   289  				logger.Fatalf("Failed to create TLS credentials %v", err)
   290  			}
   291  		} else {
   292  			creds = credentials.NewClientTLSFromCert(nil, sn)
   293  		}
   294  		opts = append(opts, grpc.WithTransportCredentials(creds))
   295  	} else {
   296  		opts = append(opts, grpc.WithInsecure())
   297  	}
   298  	return grpc.Dial(address, opts...)
   299  }
   300  
   301  func main() {
   302  	flag.Parse()
   303  	addresses := strings.Split(*serverAddresses, ",")
   304  	tests := parseTestCases(*testCases)
   305  	logParameterInfo(addresses, tests)
   306  	testSelector := newWeightedRandomTestSelector(tests)
   307  	metricsServer := newMetricsServer()
   308  
   309  	var wg sync.WaitGroup
   310  	wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel)
   311  	stop := make(chan bool)
   312  
   313  	for serverIndex, address := range addresses {
   314  		for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ {
   315  			conn, err := newConn(address, *useTLS, *testCA, *tlsServerName)
   316  			if err != nil {
   317  				logger.Fatalf("Fail to dial: %v", err)
   318  			}
   319  			defer conn.Close()
   320  			for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ {
   321  				name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1)
   322  				go func() {
   323  					defer wg.Done()
   324  					g := metricsServer.createGauge(name)
   325  					performRPCs(g, conn, testSelector, stop)
   326  				}()
   327  			}
   328  
   329  		}
   330  	}
   331  	go startServer(metricsServer, *metricsPort)
   332  	if *testDurationSecs > 0 {
   333  		time.Sleep(time.Duration(*testDurationSecs) * time.Second)
   334  		close(stop)
   335  	}
   336  	wg.Wait()
   337  	logger.Infof(" ===== ALL DONE ===== ")
   338  
   339  }