gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/stress/client/main.go (about) 1 /* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 // client starts an interop client to do stress test and a metrics server to report qps. 20 package main 21 22 import ( 23 "context" 24 "flag" 25 "fmt" 26 "math/rand" 27 "net" 28 "strconv" 29 "strings" 30 "sync" 31 "time" 32 33 grpc "gitee.com/ks-custle/core-gm/grpc" 34 "gitee.com/ks-custle/core-gm/grpc/codes" 35 "gitee.com/ks-custle/core-gm/grpc/credentials" 36 "gitee.com/ks-custle/core-gm/grpc/grpclog" 37 "gitee.com/ks-custle/core-gm/grpc/interop" 38 "gitee.com/ks-custle/core-gm/grpc/status" 39 "gitee.com/ks-custle/core-gm/grpc/testdata" 40 41 testgrpc "gitee.com/ks-custle/core-gm/grpc/interop/grpc_testing" 42 metricspb "gitee.com/ks-custle/core-gm/grpc/stress/grpc_testing" 43 ) 44 45 var ( 46 serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses") 47 testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights") 48 testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds") 49 numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server") 50 numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server") 51 metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics") 52 useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP") 53 testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root") 54 tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") 55 caFile = flag.String("ca_file", "", "The file containing the CA root cert file") 56 57 logger = grpclog.Component("stress") 58 ) 59 60 // testCaseWithWeight contains the test case type and its weight. 61 type testCaseWithWeight struct { 62 name string 63 weight int 64 } 65 66 // parseTestCases converts test case string to a list of struct testCaseWithWeight. 67 func parseTestCases(testCaseString string) []testCaseWithWeight { 68 testCaseStrings := strings.Split(testCaseString, ",") 69 testCases := make([]testCaseWithWeight, len(testCaseStrings)) 70 for i, str := range testCaseStrings { 71 testCase := strings.Split(str, ":") 72 if len(testCase) != 2 { 73 panic(fmt.Sprintf("invalid test case with weight: %s", str)) 74 } 75 // Check if test case is supported. 76 switch testCase[0] { 77 case 78 "empty_unary", 79 "large_unary", 80 "client_streaming", 81 "server_streaming", 82 "ping_pong", 83 "empty_stream", 84 "timeout_on_sleeping_server", 85 "cancel_after_begin", 86 "cancel_after_first_response", 87 "status_code_and_message", 88 "custom_metadata": 89 default: 90 panic(fmt.Sprintf("unknown test type: %s", testCase[0])) 91 } 92 testCases[i].name = testCase[0] 93 w, err := strconv.Atoi(testCase[1]) 94 if err != nil { 95 panic(fmt.Sprintf("%v", err)) 96 } 97 testCases[i].weight = w 98 } 99 return testCases 100 } 101 102 // weightedRandomTestSelector defines a weighted random selector for test case types. 103 type weightedRandomTestSelector struct { 104 tests []testCaseWithWeight 105 totalWeight int 106 } 107 108 // newWeightedRandomTestSelector constructs a weightedRandomTestSelector with the given list of testCaseWithWeight. 109 func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector { 110 var totalWeight int 111 for _, t := range tests { 112 totalWeight += t.weight 113 } 114 rand.Seed(time.Now().UnixNano()) 115 return &weightedRandomTestSelector{tests, totalWeight} 116 } 117 118 func (selector weightedRandomTestSelector) getNextTest() string { 119 random := rand.Intn(selector.totalWeight) 120 var weightSofar int 121 for _, test := range selector.tests { 122 weightSofar += test.weight 123 if random < weightSofar { 124 return test.name 125 } 126 } 127 panic("no test case selected by weightedRandomTestSelector") 128 } 129 130 // gauge stores the qps of one interop client (one stub). 131 type gauge struct { 132 mutex sync.RWMutex 133 val int64 134 } 135 136 func (g *gauge) set(v int64) { 137 g.mutex.Lock() 138 defer g.mutex.Unlock() 139 g.val = v 140 } 141 142 func (g *gauge) get() int64 { 143 g.mutex.RLock() 144 defer g.mutex.RUnlock() 145 return g.val 146 } 147 148 // server implements metrics server functions. 149 type server struct { 150 metricspb.UnimplementedMetricsServiceServer 151 mutex sync.RWMutex 152 // gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge. 153 gauges map[string]*gauge 154 } 155 156 // newMetricsServer returns a new metrics server. 157 func newMetricsServer() *server { 158 return &server{gauges: make(map[string]*gauge)} 159 } 160 161 // GetAllGauges returns all gauges. 162 func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error { 163 s.mutex.RLock() 164 defer s.mutex.RUnlock() 165 166 for name, gauge := range s.gauges { 167 if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil { 168 return err 169 } 170 } 171 return nil 172 } 173 174 // GetGauge returns the gauge for the given name. 175 func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) { 176 s.mutex.RLock() 177 defer s.mutex.RUnlock() 178 179 if g, ok := s.gauges[in.Name]; ok { 180 return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil 181 } 182 return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name) 183 } 184 185 // createGauge creates a gauge using the given name in metrics server. 186 func (s *server) createGauge(name string) *gauge { 187 s.mutex.Lock() 188 defer s.mutex.Unlock() 189 190 if _, ok := s.gauges[name]; ok { 191 // gauge already exists. 192 panic(fmt.Sprintf("gauge %s already exists", name)) 193 } 194 var g gauge 195 s.gauges[name] = &g 196 return &g 197 } 198 199 func startServer(server *server, port int) { 200 lis, err := net.Listen("tcp", ":"+strconv.Itoa(port)) 201 if err != nil { 202 logger.Fatalf("failed to listen: %v", err) 203 } 204 205 s := grpc.NewServer() 206 metricspb.RegisterMetricsServiceServer(s, server) 207 s.Serve(lis) 208 209 } 210 211 // performRPCs uses weightedRandomTestSelector to select test case and runs the tests. 212 func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) { 213 client := testgrpc.NewTestServiceClient(conn) 214 var numCalls int64 215 startTime := time.Now() 216 for { 217 test := selector.getNextTest() 218 switch test { 219 case "empty_unary": 220 interop.DoEmptyUnaryCall(client, grpc.WaitForReady(true)) 221 case "large_unary": 222 interop.DoLargeUnaryCall(client, grpc.WaitForReady(true)) 223 case "client_streaming": 224 interop.DoClientStreaming(client, grpc.WaitForReady(true)) 225 case "server_streaming": 226 interop.DoServerStreaming(client, grpc.WaitForReady(true)) 227 case "ping_pong": 228 interop.DoPingPong(client, grpc.WaitForReady(true)) 229 case "empty_stream": 230 interop.DoEmptyStream(client, grpc.WaitForReady(true)) 231 case "timeout_on_sleeping_server": 232 interop.DoTimeoutOnSleepingServer(client, grpc.WaitForReady(true)) 233 case "cancel_after_begin": 234 interop.DoCancelAfterBegin(client, grpc.WaitForReady(true)) 235 case "cancel_after_first_response": 236 interop.DoCancelAfterFirstResponse(client, grpc.WaitForReady(true)) 237 case "status_code_and_message": 238 interop.DoStatusCodeAndMessage(client, grpc.WaitForReady(true)) 239 case "custom_metadata": 240 interop.DoCustomMetadata(client, grpc.WaitForReady(true)) 241 } 242 numCalls++ 243 gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds())) 244 245 select { 246 case <-stop: 247 return 248 default: 249 } 250 } 251 } 252 253 func logParameterInfo(addresses []string, tests []testCaseWithWeight) { 254 logger.Infof("server_addresses: %s", *serverAddresses) 255 logger.Infof("test_cases: %s", *testCases) 256 logger.Infof("test_duration_secs: %d", *testDurationSecs) 257 logger.Infof("num_channels_per_server: %d", *numChannelsPerServer) 258 logger.Infof("num_stubs_per_channel: %d", *numStubsPerChannel) 259 logger.Infof("metrics_port: %d", *metricsPort) 260 logger.Infof("use_tls: %t", *useTLS) 261 logger.Infof("use_test_ca: %t", *testCA) 262 logger.Infof("server_host_override: %s", *tlsServerName) 263 264 logger.Infoln("addresses:") 265 for i, addr := range addresses { 266 logger.Infof("%d. %s\n", i+1, addr) 267 } 268 logger.Infoln("tests:") 269 for i, test := range tests { 270 logger.Infof("%d. %v\n", i+1, test) 271 } 272 } 273 274 func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) { 275 var opts []grpc.DialOption 276 if useTLS { 277 var sn string 278 if tlsServerName != "" { 279 sn = tlsServerName 280 } 281 var creds credentials.TransportCredentials 282 if testCA { 283 var err error 284 if *caFile == "" { 285 *caFile = testdata.Path("x509/server_ca_cert.pem") 286 } 287 creds, err = credentials.NewClientTLSFromFile(*caFile, sn) 288 if err != nil { 289 logger.Fatalf("Failed to create TLS credentials %v", err) 290 } 291 } else { 292 creds = credentials.NewClientTLSFromCert(nil, sn) 293 } 294 opts = append(opts, grpc.WithTransportCredentials(creds)) 295 } else { 296 opts = append(opts, grpc.WithInsecure()) 297 } 298 return grpc.Dial(address, opts...) 299 } 300 301 func main() { 302 flag.Parse() 303 addresses := strings.Split(*serverAddresses, ",") 304 tests := parseTestCases(*testCases) 305 logParameterInfo(addresses, tests) 306 testSelector := newWeightedRandomTestSelector(tests) 307 metricsServer := newMetricsServer() 308 309 var wg sync.WaitGroup 310 wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel) 311 stop := make(chan bool) 312 313 for serverIndex, address := range addresses { 314 for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ { 315 conn, err := newConn(address, *useTLS, *testCA, *tlsServerName) 316 if err != nil { 317 logger.Fatalf("Fail to dial: %v", err) 318 } 319 defer conn.Close() 320 for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ { 321 name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1) 322 go func() { 323 defer wg.Done() 324 g := metricsServer.createGauge(name) 325 performRPCs(g, conn, testSelector, stop) 326 }() 327 } 328 329 } 330 } 331 go startServer(metricsServer, *metricsPort) 332 if *testDurationSecs > 0 { 333 time.Sleep(time.Duration(*testDurationSecs) * time.Second) 334 close(stop) 335 } 336 wg.Wait() 337 logger.Infof(" ===== ALL DONE ===== ") 338 339 }