google.golang.org/grpc@v1.62.1/interop/stress/client/main.go (about) 1 /* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 // client starts an interop client to do stress test and a metrics server to report qps. 20 package main 21 22 import ( 23 "context" 24 "flag" 25 "fmt" 26 "math/rand" 27 "net" 28 "os" 29 "strconv" 30 "strings" 31 "sync" 32 "sync/atomic" 33 "time" 34 35 "google.golang.org/grpc" 36 "google.golang.org/grpc/codes" 37 "google.golang.org/grpc/credentials" 38 "google.golang.org/grpc/credentials/google" 39 "google.golang.org/grpc/credentials/insecure" 40 "google.golang.org/grpc/grpclog" 41 "google.golang.org/grpc/interop" 42 "google.golang.org/grpc/resolver" 43 "google.golang.org/grpc/status" 44 "google.golang.org/grpc/testdata" 45 46 _ "google.golang.org/grpc/xds/googledirectpath" // Register xDS resolver required for c2p directpath. 47 48 testgrpc "google.golang.org/grpc/interop/grpc_testing" 49 metricspb "google.golang.org/grpc/interop/stress/grpc_testing" 50 ) 51 52 const ( 53 googleDefaultCredsName = "google_default_credentials" 54 computeEngineCredsName = "compute_engine_channel_creds" 55 ) 56 57 var ( 58 serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses") 59 testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights") 60 testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds") 61 numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server") 62 numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server") 63 metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics") 64 useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP") 65 testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root") 66 tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") 67 caFile = flag.String("ca_file", "", "The file containing the CA root cert file") 68 customCredentialsType = flag.String("custom_credentials_type", "", "Custom credentials type to use") 69 70 totalNumCalls int64 71 logger = grpclog.Component("stress") 72 ) 73 74 // testCaseWithWeight contains the test case type and its weight. 75 type testCaseWithWeight struct { 76 name string 77 weight int 78 } 79 80 // parseTestCases converts test case string to a list of struct testCaseWithWeight. 81 func parseTestCases(testCaseString string) []testCaseWithWeight { 82 testCaseStrings := strings.Split(testCaseString, ",") 83 testCases := make([]testCaseWithWeight, len(testCaseStrings)) 84 for i, str := range testCaseStrings { 85 testCaseNameAndWeight := strings.Split(str, ":") 86 if len(testCaseNameAndWeight) != 2 { 87 panic(fmt.Sprintf("invalid test case with weight: %s", str)) 88 } 89 // Check if test case is supported. 90 testCaseName := strings.ToLower(testCaseNameAndWeight[0]) 91 switch testCaseName { 92 case 93 "empty_unary", 94 "large_unary", 95 "client_streaming", 96 "server_streaming", 97 "ping_pong", 98 "empty_stream", 99 "timeout_on_sleeping_server", 100 "cancel_after_begin", 101 "cancel_after_first_response", 102 "status_code_and_message", 103 "custom_metadata": 104 default: 105 panic(fmt.Sprintf("unknown test type: %s", testCaseNameAndWeight[0])) 106 } 107 testCases[i].name = testCaseName 108 w, err := strconv.Atoi(testCaseNameAndWeight[1]) 109 if err != nil { 110 panic(fmt.Sprintf("%v", err)) 111 } 112 testCases[i].weight = w 113 } 114 return testCases 115 } 116 117 // weightedRandomTestSelector defines a weighted random selector for test case types. 118 type weightedRandomTestSelector struct { 119 tests []testCaseWithWeight 120 totalWeight int 121 } 122 123 // newWeightedRandomTestSelector constructs a weightedRandomTestSelector with the given list of testCaseWithWeight. 124 func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector { 125 var totalWeight int 126 for _, t := range tests { 127 totalWeight += t.weight 128 } 129 rand.Seed(time.Now().UnixNano()) 130 return &weightedRandomTestSelector{tests, totalWeight} 131 } 132 133 func (selector weightedRandomTestSelector) getNextTest() string { 134 random := rand.Intn(selector.totalWeight) 135 var weightSofar int 136 for _, test := range selector.tests { 137 weightSofar += test.weight 138 if random < weightSofar { 139 return test.name 140 } 141 } 142 panic("no test case selected by weightedRandomTestSelector") 143 } 144 145 // gauge stores the qps of one interop client (one stub). 146 type gauge struct { 147 mutex sync.RWMutex 148 val int64 149 } 150 151 func (g *gauge) set(v int64) { 152 g.mutex.Lock() 153 defer g.mutex.Unlock() 154 g.val = v 155 } 156 157 func (g *gauge) get() int64 { 158 g.mutex.RLock() 159 defer g.mutex.RUnlock() 160 return g.val 161 } 162 163 // server implements metrics server functions. 164 type server struct { 165 metricspb.UnimplementedMetricsServiceServer 166 mutex sync.RWMutex 167 // gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge. 168 gauges map[string]*gauge 169 } 170 171 // newMetricsServer returns a new metrics server. 172 func newMetricsServer() *server { 173 return &server{gauges: make(map[string]*gauge)} 174 } 175 176 // GetAllGauges returns all gauges. 177 func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error { 178 s.mutex.RLock() 179 defer s.mutex.RUnlock() 180 181 for name, gauge := range s.gauges { 182 if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil { 183 return err 184 } 185 } 186 return nil 187 } 188 189 // GetGauge returns the gauge for the given name. 190 func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) { 191 s.mutex.RLock() 192 defer s.mutex.RUnlock() 193 194 if g, ok := s.gauges[in.Name]; ok { 195 return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil 196 } 197 return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name) 198 } 199 200 // createGauge creates a gauge using the given name in metrics server. 201 func (s *server) createGauge(name string) *gauge { 202 s.mutex.Lock() 203 defer s.mutex.Unlock() 204 205 if _, ok := s.gauges[name]; ok { 206 // gauge already exists. 207 panic(fmt.Sprintf("gauge %s already exists", name)) 208 } 209 var g gauge 210 s.gauges[name] = &g 211 return &g 212 } 213 214 func startServer(server *server, port int) { 215 lis, err := net.Listen("tcp", ":"+strconv.Itoa(port)) 216 if err != nil { 217 logger.Fatalf("failed to listen: %v", err) 218 } 219 220 s := grpc.NewServer() 221 metricspb.RegisterMetricsServiceServer(s, server) 222 s.Serve(lis) 223 } 224 225 // performRPCs uses weightedRandomTestSelector to select test case and runs the tests. 226 func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) { 227 client := testgrpc.NewTestServiceClient(conn) 228 var numCalls int64 229 ctx := context.Background() 230 startTime := time.Now() 231 for { 232 test := selector.getNextTest() 233 switch test { 234 case "empty_unary": 235 interop.DoEmptyUnaryCall(ctx, client) 236 case "large_unary": 237 interop.DoLargeUnaryCall(ctx, client) 238 case "client_streaming": 239 interop.DoClientStreaming(ctx, client) 240 case "server_streaming": 241 interop.DoServerStreaming(ctx, client) 242 case "ping_pong": 243 interop.DoPingPong(ctx, client) 244 case "empty_stream": 245 interop.DoEmptyStream(ctx, client) 246 case "timeout_on_sleeping_server": 247 interop.DoTimeoutOnSleepingServer(ctx, client) 248 case "cancel_after_begin": 249 interop.DoCancelAfterBegin(ctx, client) 250 case "cancel_after_first_response": 251 interop.DoCancelAfterFirstResponse(ctx, client) 252 case "status_code_and_message": 253 interop.DoStatusCodeAndMessage(ctx, client) 254 case "custom_metadata": 255 interop.DoCustomMetadata(ctx, client) 256 } 257 numCalls++ 258 defer func() { atomic.AddInt64(&totalNumCalls, numCalls) }() 259 gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds())) 260 261 select { 262 case <-stop: 263 return 264 default: 265 } 266 } 267 } 268 269 func logParameterInfo(addresses []string, tests []testCaseWithWeight) { 270 logger.Infof("server_addresses: %s", *serverAddresses) 271 logger.Infof("test_cases: %s", *testCases) 272 logger.Infof("test_duration_secs: %d", *testDurationSecs) 273 logger.Infof("num_channels_per_server: %d", *numChannelsPerServer) 274 logger.Infof("num_stubs_per_channel: %d", *numStubsPerChannel) 275 logger.Infof("metrics_port: %d", *metricsPort) 276 logger.Infof("use_tls: %t", *useTLS) 277 logger.Infof("use_test_ca: %t", *testCA) 278 logger.Infof("server_host_override: %s", *tlsServerName) 279 logger.Infof("custom_credentials_type: %s", *customCredentialsType) 280 281 logger.Infoln("addresses:") 282 for i, addr := range addresses { 283 logger.Infof("%d. %s\n", i+1, addr) 284 } 285 logger.Infoln("tests:") 286 for i, test := range tests { 287 logger.Infof("%d. %v\n", i+1, test) 288 } 289 } 290 291 func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) { 292 var opts []grpc.DialOption 293 if *customCredentialsType != "" { 294 if *customCredentialsType == googleDefaultCredsName { 295 opts = append(opts, grpc.WithCredentialsBundle(google.NewDefaultCredentials())) 296 } else if *customCredentialsType == computeEngineCredsName { 297 opts = append(opts, grpc.WithCredentialsBundle(google.NewComputeEngineCredentials())) 298 } else { 299 logger.Fatalf("Unknown custom credentials: %v", *customCredentialsType) 300 } 301 } else if useTLS { 302 var sn string 303 if tlsServerName != "" { 304 sn = tlsServerName 305 } 306 var creds credentials.TransportCredentials 307 if testCA { 308 var err error 309 if *caFile == "" { 310 *caFile = testdata.Path("x509/server_ca_cert.pem") 311 } 312 creds, err = credentials.NewClientTLSFromFile(*caFile, sn) 313 if err != nil { 314 logger.Fatalf("Failed to create TLS credentials: %v", err) 315 } 316 } else { 317 creds = credentials.NewClientTLSFromCert(nil, sn) 318 } 319 opts = append(opts, grpc.WithTransportCredentials(creds)) 320 } else { 321 opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) 322 } 323 return grpc.Dial(address, opts...) 324 } 325 326 func main() { 327 flag.Parse() 328 resolver.SetDefaultScheme("dns") 329 addresses := strings.Split(*serverAddresses, ",") 330 tests := parseTestCases(*testCases) 331 logParameterInfo(addresses, tests) 332 testSelector := newWeightedRandomTestSelector(tests) 333 metricsServer := newMetricsServer() 334 335 var wg sync.WaitGroup 336 wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel) 337 stop := make(chan bool) 338 339 for serverIndex, address := range addresses { 340 for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ { 341 conn, err := newConn(address, *useTLS, *testCA, *tlsServerName) 342 if err != nil { 343 logger.Fatalf("Fail to dial: %v", err) 344 } 345 defer conn.Close() 346 for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ { 347 name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1) 348 go func() { 349 defer wg.Done() 350 g := metricsServer.createGauge(name) 351 performRPCs(g, conn, testSelector, stop) 352 }() 353 } 354 355 } 356 } 357 go startServer(metricsServer, *metricsPort) 358 if *testDurationSecs > 0 { 359 time.Sleep(time.Duration(*testDurationSecs) * time.Second) 360 close(stop) 361 } 362 wg.Wait() 363 fmt.Fprintf(os.Stdout, "Total calls made: %v\n", totalNumCalls) 364 logger.Infof(" ===== ALL DONE ===== ") 365 }