/*
 *
 * Copyright 2016 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package stats_test

import (
	"context"
	"fmt"
	"io"
	"net"
	"reflect"
	"sync"
	"testing"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/internal"
	"google.golang.org/grpc/internal/grpctest"
	"google.golang.org/grpc/internal/stubserver"
	"google.golang.org/grpc/internal/testutils"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/stats"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"

	testgrpc "google.golang.org/grpc/interop/grpc_testing"
	testpb "google.golang.org/grpc/interop/grpc_testing"
)

// defaultTestTimeout bounds every RPC issued by these tests.
const defaultTestTimeout = 10 * time.Second

type s struct {
	grpctest.Tester
}

func Test(t *testing.T) {
	grpctest.RunSubTests(t, s{})
}

func init() {
	grpc.EnableTracing = false
}

// Context keys under which the stats handler stores the tag info it receives
// in TagConn / TagRPC, so the check functions can assert it is propagated.
type connCtxKey struct{}
type rpcCtxKey struct{}

var (
	// For headers sent to server:
	testMetadata = metadata.MD{
		"key1":       []string{"value1"},
		"key2":       []string{"value2"},
		"user-agent": []string{fmt.Sprintf("test/0.0.1 grpc-go/%s", grpc.Version)},
	}
	// For headers sent from server:
	testHeaderMetadata = metadata.MD{
		"hkey1": []string{"headerValue1"},
		"hkey2": []string{"headerValue2"},
	}
	// For trailers sent from server:
	testTrailerMetadata = metadata.MD{
		"tkey1": []string{"trailerValue1"},
		"tkey2": []string{"trailerValue2"},
	}
	// The id for which the service handler should return error.
	errorID int32 = 32202
)

// idToPayload encodes id into a 4-byte little-endian payload body.
func idToPayload(id int32) *testpb.Payload {
	return &testpb.Payload{Body: []byte{byte(id), byte(id >> 8), byte(id >> 16), byte(id >> 24)}}
}

// payloadToID decodes the 4-byte little-endian body written by idToPayload.
// It panics on a malformed payload, since that indicates a bug in the test
// itself rather than in the code under test.
func payloadToID(p *testpb.Payload) int32 {
	if p == nil || len(p.Body) != 4 {
		panic("invalid payload")
	}
	return int32(p.Body[0]) + int32(p.Body[1])<<8 + int32(p.Body[2])<<16 + int32(p.Body[3])<<24
}

type testServer struct {
	testgrpc.UnimplementedTestServiceServer
}

// UnaryCall echoes the request payload after sending the test header and
// trailer metadata. It fails the RPC when the payload encodes errorID.
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
	if err := grpc.SendHeader(ctx, testHeaderMetadata); err != nil {
		return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", testHeaderMetadata, err)
	}
	if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
		return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
	}

	if id := payloadToID(in.Payload); id == errorID {
		return nil, fmt.Errorf("got error id: %v", id)
	}

	return &testpb.SimpleResponse{Payload: in.Payload}, nil
}

// FullDuplexCall echoes each received payload until the client half-closes,
// failing the RPC when a payload encodes errorID.
func (s *testServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallServer) error {
	if err := stream.SendHeader(testHeaderMetadata); err != nil {
		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
	}
	stream.SetTrailer(testTrailerMetadata)
	for {
		in, err := stream.Recv()
		if err == io.EOF {
			// read done.
			return nil
		}
		if err != nil {
			return err
		}

		if id := payloadToID(in.Payload); id == errorID {
			return fmt.Errorf("got error id: %v", id)
		}

		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
			return err
		}
	}
}

// StreamingInputCall drains the client stream, failing the RPC when a
// payload encodes errorID, and replies with an aggregated size of 0.
func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error {
	if err := stream.SendHeader(testHeaderMetadata); err != nil {
		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
	}
	stream.SetTrailer(testTrailerMetadata)
	for {
		in, err := stream.Recv()
		if err == io.EOF {
			// read done.
			return stream.SendAndClose(&testpb.StreamingInputCallResponse{AggregatedPayloadSize: 0})
		}
		if err != nil {
			return err
		}

		if id := payloadToID(in.Payload); id == errorID {
			return fmt.Errorf("got error id: %v", id)
		}
	}
}

// StreamingOutputCall sends five copies of the request payload, failing the
// RPC when the payload encodes errorID.
func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error {
	if err := stream.SendHeader(testHeaderMetadata); err != nil {
		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
	}
	stream.SetTrailer(testTrailerMetadata)

	if id := payloadToID(in.Payload); id == errorID {
		return fmt.Errorf("got error id: %v", id)
	}

	for i := 0; i < 5; i++ {
		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
			return err
		}
	}
	return nil
}

// test is an end-to-end test. It should be created with the newTest
// func, modified as needed, and then started with its startServer method.
// It should be cleaned up with the tearDown method.
type test struct {
	t                   *testing.T
	compress            string
	clientStatsHandlers []stats.Handler
	serverStatsHandlers []stats.Handler

	testServer testgrpc.TestServiceServer // nil means none
	// srv and srvAddr are set once startServer is called.
	srv     *grpc.Server
	srvAddr string

	cc *grpc.ClientConn // nil until requested via clientConn
}

// tearDown closes the client connection (if dialed) and stops the server.
func (te *test) tearDown() {
	if te.cc != nil {
		te.cc.Close()
		te.cc = nil
	}
	te.srv.Stop()
}

type testConfig struct {
	compress string
}

// newTest returns a new test using the provided testing.T and
// environment. It is returned with default values. Tests should
// modify it before calling its startServer and clientConn methods.
func newTest(t *testing.T, tc *testConfig, chs []stats.Handler, shs []stats.Handler) *test {
	te := &test{
		t:                   t,
		compress:            tc.compress,
		clientStatsHandlers: chs,
		serverStatsHandlers: shs,
	}
	return te
}

// startServer starts a gRPC server listening. Callers should defer a
// call to te.tearDown to clean up.
func (te *test) startServer(ts testgrpc.TestServiceServer) {
	te.testServer = ts
	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		te.t.Fatalf("Failed to listen: %v", err)
	}
	var opts []grpc.ServerOption
	if te.compress == "gzip" {
		opts = append(opts,
			grpc.RPCCompressor(grpc.NewGZIPCompressor()),
			grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
		)
	}
	// Install every configured server-side stats handler.
	for _, sh := range te.serverStatsHandlers {
		opts = append(opts, grpc.StatsHandler(sh))
	}
	s := grpc.NewServer(opts...)
	te.srv = s
	if te.testServer != nil {
		testgrpc.RegisterTestServiceServer(s, te.testServer)
	}

	go s.Serve(lis)
	te.srvAddr = lis.Addr().String()
}

// clientConn lazily dials the started server and caches the connection.
func (te *test) clientConn() *grpc.ClientConn {
	if te.cc != nil {
		return te.cc
	}
	opts := []grpc.DialOption{
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
		grpc.WithUserAgent("test/0.0.1"),
	}
	if te.compress == "gzip" {
		opts = append(opts,
			grpc.WithCompressor(grpc.NewGZIPCompressor()),
			grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
		)
	}
	// Install every configured client-side stats handler.
	for _, sh := range te.clientStatsHandlers {
		opts = append(opts, grpc.WithStatsHandler(sh))
	}

	var err error
	te.cc, err = grpc.Dial(te.srvAddr, opts...)
	if err != nil {
		te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err)
	}
	return te.cc
}

type rpcType int

const (
	unaryRPC rpcType = iota
	clientStreamRPC
	serverStreamRPC
	fullDuplexStreamRPC
)

// rpcConfig describes a single RPC to issue during a test.
type rpcConfig struct {
	count    int  // Number of requests and responses for streaming RPCs.
	success  bool // Whether the RPC should succeed or return error.
	failfast bool
	callType rpcType // Type of RPC.
}

// doUnaryCall performs one unary RPC, using a payload that triggers the
// server-side error path when c.success is false.
func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
	var (
		resp *testpb.SimpleResponse
		req  *testpb.SimpleRequest
		err  error
	)
	tc := testgrpc.NewTestServiceClient(te.clientConn())
	if c.success {
		req = &testpb.SimpleRequest{Payload: idToPayload(errorID + 1)}
	} else {
		req = &testpb.SimpleRequest{Payload: idToPayload(errorID)}
	}

	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	resp, err = tc.UnaryCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast))
	return req, resp, err
}

// doFullDuplexCallRoundtrip sends c.count requests on a bidi stream, reading
// one response per request, then closes the send side and drains the stream.
// The first payload triggers the server error path when c.success is false.
func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]proto.Message, []proto.Message, error) {
	var (
		reqs  []proto.Message
		resps []proto.Message
		err   error
	)
	tc := testgrpc.NewTestServiceClient(te.clientConn())
	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast))
	if err != nil {
		return reqs, resps, err
	}
	var startID int32
	if !c.success {
		startID = errorID
	}
	for i := 0; i < c.count; i++ {
		req := &testpb.StreamingOutputCallRequest{
			Payload: idToPayload(int32(i) + startID),
		}
		reqs = append(reqs, req)
		if err = stream.Send(req); err != nil {
			return reqs, resps, err
		}
		var resp *testpb.StreamingOutputCallResponse
		if resp, err = stream.Recv(); err != nil {
			return reqs, resps, err
		}
		resps = append(resps, resp)
	}
	if err = stream.CloseSend(); err != nil && err != io.EOF {
		return reqs, resps, err
	}
	if _, err = stream.Recv(); err != io.EOF {
		return reqs, resps, err
	}

	return reqs, resps, nil
}

// doClientStreamCall sends c.count requests on a client stream and closes it
// for the single response. The first payload triggers the server error path
// when c.success is false.
func (te *test) doClientStreamCall(c *rpcConfig) ([]proto.Message, *testpb.StreamingInputCallResponse, error) {
	var (
		reqs []proto.Message
		resp *testpb.StreamingInputCallResponse
		err  error
	)
	tc := testgrpc.NewTestServiceClient(te.clientConn())
	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	stream, err := tc.StreamingInputCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast))
	if err != nil {
		return reqs, resp, err
	}
	var startID int32
	if !c.success {
		startID = errorID
	}
	for i := 0; i < c.count; i++ {
		req := &testpb.StreamingInputCallRequest{
			Payload: idToPayload(int32(i) + startID),
		}
		reqs = append(reqs, req)
		if err = stream.Send(req); err != nil {
			return reqs, resp, err
		}
	}
	resp, err = stream.CloseAndRecv()
	return reqs, resp, err
}

// doServerStreamCall issues one server-streaming RPC and drains all
// responses. The payload triggers the server error path when c.success is
// false.
func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.StreamingOutputCallRequest, []proto.Message, error) {
	var (
		req   *testpb.StreamingOutputCallRequest
		resps []proto.Message
		err   error
	)

	tc := testgrpc.NewTestServiceClient(te.clientConn())

	var startID int32
	if !c.success {
		startID = errorID
	}
	req = &testpb.StreamingOutputCallRequest{Payload: idToPayload(startID)}
	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	stream, err := tc.StreamingOutputCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast))
	if err != nil {
		return req, resps, err
	}
	for {
		var resp *testpb.StreamingOutputCallResponse
		// NOTE(review): ":=" here shadows the outer err; harmless because
		// every exit happens inside this loop, but worth confirming.
		resp, err := stream.Recv()
		if err == io.EOF {
			return req, resps, nil
		} else if err != nil {
			return req, resps, err
		}
		resps = append(resps, resp)
	}
}

// expectedData holds the values the check functions assert the collected
// stats events against. reqIdx/respIdx are advanced by the payload checkers
// as messages are consumed.
type expectedData struct {
	method         string
	isClientStream bool
	isServerStream bool
	serverAddr     string
	compression    string
	reqIdx         int
	requests       []proto.Message
	respIdx        int
	responses      []proto.Message
	err            error
	failfast       bool
}

// gotData is one recorded stats callback: the context it arrived with,
// whether it was a client-side event, and the stats value itself.
type gotData struct {
	ctx    context.Context
	client bool
	s      any // This could be RPCStats or ConnStats.
}

// Indices into the per-event-kind check-function map used by the client
// stats tests.
const (
	begin int = iota
	end
	inPayload
	inHeader
	inTrailer
	outPayload
	outHeader
	// TODO: test outTrailer ?
	connBegin
	connEnd
)

// checkBegin verifies a stats.Begin event: non-nil context, non-zero begin
// time, FailFast (client side only), and stream-direction flags.
func checkBegin(t *testing.T, d *gotData, e *expectedData) {
	var (
		ok bool
		st *stats.Begin
	)
	if st, ok = d.s.(*stats.Begin); !ok {
		t.Fatalf("got %T, want Begin", d.s)
	}
	if d.ctx == nil {
		t.Fatalf("d.ctx = nil, want <non-nil>")
	}
	if st.BeginTime.IsZero() {
		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
	}
	if d.client {
		if st.FailFast != e.failfast {
			t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast)
		}
	}
	if st.IsClientStream != e.isClientStream {
		t.Fatalf("st.IsClientStream = %v, want %v", st.IsClientStream, e.isClientStream)
	}
	if st.IsServerStream != e.isServerStream {
		t.Fatalf("st.IsServerStream = %v, want %v", st.IsServerStream, e.isServerStream)
	}
}

// checkInHeader verifies a stats.InHeader event. On the client it checks the
// server-sent header metadata; on the server it additionally checks method,
// local address, and the tag info stored in the context by TagConn/TagRPC.
func checkInHeader(t *testing.T, d *gotData, e *expectedData) {
	var (
		ok bool
		st *stats.InHeader
	)
	if st, ok = d.s.(*stats.InHeader); !ok {
		t.Fatalf("got %T, want InHeader", d.s)
	}
	if d.ctx == nil {
		t.Fatalf("d.ctx = nil, want <non-nil>")
	}
	if st.Compression != e.compression {
		t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
	}
	if d.client {
		// additional headers might be injected so instead of testing equality, test that all the
		// expected headers keys have the expected header values.
490 for key := range testHeaderMetadata { 491 if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) { 492 t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key)) 493 } 494 } 495 } else { 496 if st.FullMethod != e.method { 497 t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) 498 } 499 if st.LocalAddr.String() != e.serverAddr { 500 t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr) 501 } 502 // additional headers might be injected so instead of testing equality, test that all the 503 // expected headers keys have the expected header values. 504 for key := range testMetadata { 505 if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) { 506 t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key)) 507 } 508 } 509 510 if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok { 511 if connInfo.RemoteAddr != st.RemoteAddr { 512 t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr) 513 } 514 if connInfo.LocalAddr != st.LocalAddr { 515 t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr) 516 } 517 } else { 518 t.Fatalf("got context %v, want one with connCtxKey", d.ctx) 519 } 520 if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok { 521 if rpcInfo.FullMethodName != st.FullMethod { 522 t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod) 523 } 524 } else { 525 t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx) 526 } 527 } 528 } 529 530 func checkInPayload(t *testing.T, d *gotData, e *expectedData) { 531 var ( 532 ok bool 533 st *stats.InPayload 534 ) 535 if st, ok = d.s.(*stats.InPayload); !ok { 536 t.Fatalf("got %T, want InPayload", d.s) 537 } 538 if d.ctx == nil { 539 t.Fatalf("d.ctx = nil, want <non-nil>") 540 } 541 if d.client { 542 b, err := proto.Marshal(e.responses[e.respIdx]) 543 if err != nil { 544 t.Fatalf("failed to 
marshal message: %v", err) 545 } 546 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) { 547 t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx]) 548 } 549 e.respIdx++ 550 if string(st.Data) != string(b) { 551 t.Fatalf("st.Data = %v, want %v", st.Data, b) 552 } 553 if st.Length != len(b) { 554 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 555 } 556 } else { 557 b, err := proto.Marshal(e.requests[e.reqIdx]) 558 if err != nil { 559 t.Fatalf("failed to marshal message: %v", err) 560 } 561 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) { 562 t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx]) 563 } 564 e.reqIdx++ 565 if string(st.Data) != string(b) { 566 t.Fatalf("st.Data = %v, want %v", st.Data, b) 567 } 568 if st.Length != len(b) { 569 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 570 } 571 } 572 // Below are sanity checks that WireLength and RecvTime are populated. 573 // TODO: check values of WireLength and RecvTime. 
574 if len(st.Data) > 0 && st.CompressedLength == 0 { 575 t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>", 576 st.CompressedLength) 577 } 578 if st.RecvTime.IsZero() { 579 t.Fatalf("st.ReceivedTime = %v, want <non-zero>", st.RecvTime) 580 } 581 } 582 583 func checkInTrailer(t *testing.T, d *gotData, e *expectedData) { 584 var ( 585 ok bool 586 st *stats.InTrailer 587 ) 588 if st, ok = d.s.(*stats.InTrailer); !ok { 589 t.Fatalf("got %T, want InTrailer", d.s) 590 } 591 if d.ctx == nil { 592 t.Fatalf("d.ctx = nil, want <non-nil>") 593 } 594 if !st.Client { 595 t.Fatalf("st IsClient = false, want true") 596 } 597 if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) { 598 t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata) 599 } 600 } 601 602 func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { 603 var ( 604 ok bool 605 st *stats.OutHeader 606 ) 607 if st, ok = d.s.(*stats.OutHeader); !ok { 608 t.Fatalf("got %T, want OutHeader", d.s) 609 } 610 if d.ctx == nil { 611 t.Fatalf("d.ctx = nil, want <non-nil>") 612 } 613 if st.Compression != e.compression { 614 t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression) 615 } 616 if d.client { 617 if st.FullMethod != e.method { 618 t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) 619 } 620 if st.RemoteAddr.String() != e.serverAddr { 621 t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr) 622 } 623 // additional headers might be injected so instead of testing equality, test that all the 624 // expected headers keys have the expected header values. 
625 for key := range testMetadata { 626 if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) { 627 t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key)) 628 } 629 } 630 631 if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok { 632 if rpcInfo.FullMethodName != st.FullMethod { 633 t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod) 634 } 635 } else { 636 t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx) 637 } 638 } else { 639 // additional headers might be injected so instead of testing equality, test that all the 640 // expected headers keys have the expected header values. 641 for key := range testHeaderMetadata { 642 if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) { 643 t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key)) 644 } 645 } 646 } 647 } 648 649 func checkOutPayload(t *testing.T, d *gotData, e *expectedData) { 650 var ( 651 ok bool 652 st *stats.OutPayload 653 ) 654 if st, ok = d.s.(*stats.OutPayload); !ok { 655 t.Fatalf("got %T, want OutPayload", d.s) 656 } 657 if d.ctx == nil { 658 t.Fatalf("d.ctx = nil, want <non-nil>") 659 } 660 if d.client { 661 b, err := proto.Marshal(e.requests[e.reqIdx]) 662 if err != nil { 663 t.Fatalf("failed to marshal message: %v", err) 664 } 665 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) { 666 t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx]) 667 } 668 e.reqIdx++ 669 if string(st.Data) != string(b) { 670 t.Fatalf("st.Data = %v, want %v", st.Data, b) 671 } 672 if st.Length != len(b) { 673 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 674 } 675 } else { 676 b, err := proto.Marshal(e.responses[e.respIdx]) 677 if err != nil { 678 t.Fatalf("failed to marshal message: %v", err) 679 } 680 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) { 681 t.Fatalf("st.Payload = %T, want %T", st.Payload, 
e.responses[e.respIdx]) 682 } 683 e.respIdx++ 684 if string(st.Data) != string(b) { 685 t.Fatalf("st.Data = %v, want %v", st.Data, b) 686 } 687 if st.Length != len(b) { 688 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 689 } 690 } 691 // Below are sanity checks that WireLength and SentTime are populated. 692 // TODO: check values of WireLength and SentTime. 693 if len(st.Data) > 0 && st.WireLength == 0 { 694 t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>", 695 st.WireLength) 696 } 697 if st.SentTime.IsZero() { 698 t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime) 699 } 700 } 701 702 func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) { 703 var ( 704 ok bool 705 st *stats.OutTrailer 706 ) 707 if st, ok = d.s.(*stats.OutTrailer); !ok { 708 t.Fatalf("got %T, want OutTrailer", d.s) 709 } 710 if d.ctx == nil { 711 t.Fatalf("d.ctx = nil, want <non-nil>") 712 } 713 if st.Client { 714 t.Fatalf("st IsClient = true, want false") 715 } 716 if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) { 717 t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata) 718 } 719 } 720 721 func checkEnd(t *testing.T, d *gotData, e *expectedData) { 722 var ( 723 ok bool 724 st *stats.End 725 ) 726 if st, ok = d.s.(*stats.End); !ok { 727 t.Fatalf("got %T, want End", d.s) 728 } 729 if d.ctx == nil { 730 t.Fatalf("d.ctx = nil, want <non-nil>") 731 } 732 if st.BeginTime.IsZero() { 733 t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime) 734 } 735 if st.EndTime.IsZero() { 736 t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime) 737 } 738 739 actual, ok := status.FromError(st.Error) 740 if !ok { 741 t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error) 742 } 743 744 expectedStatus, _ := status.FromError(e.err) 745 if actual.Code() != expectedStatus.Code() || actual.Message() != expectedStatus.Message() { 746 t.Fatalf("st.Error = %v, want %v", st.Error, e.err) 747 } 748 749 if st.Client { 750 
		if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
			t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
		}
	} else {
		if st.Trailer != nil {
			t.Fatalf("st.Trailer = %v, want nil", st.Trailer)
		}
	}
}

// checkConnBegin verifies a stats.ConnBegin event with a non-nil context.
func checkConnBegin(t *testing.T, d *gotData) {
	var (
		ok bool
		st *stats.ConnBegin
	)
	if st, ok = d.s.(*stats.ConnBegin); !ok {
		t.Fatalf("got %T, want ConnBegin", d.s)
	}
	if d.ctx == nil {
		t.Fatalf("d.ctx = nil, want <non-nil>")
	}
	st.IsClient() // TODO remove this.
}

// checkConnEnd verifies a stats.ConnEnd event with a non-nil context.
func checkConnEnd(t *testing.T, d *gotData) {
	var (
		ok bool
		st *stats.ConnEnd
	)
	if st, ok = d.s.(*stats.ConnEnd); !ok {
		t.Fatalf("got %T, want ConnEnd", d.s)
	}
	if d.ctx == nil {
		t.Fatalf("d.ctx = nil, want <non-nil>")
	}
	st.IsClient() // TODO remove this.
}

// statshandler records every conn and RPC stats event it receives. mu guards
// gotRPC and gotConn, which are appended to from gRPC's callback goroutines.
type statshandler struct {
	mu      sync.Mutex
	gotRPC  []*gotData
	gotConn []*gotData
}

// TagConn stashes the conn tag info in the context for later verification.
func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
	return context.WithValue(ctx, connCtxKey{}, info)
}

// TagRPC stashes the RPC tag info in the context for later verification.
func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
	return context.WithValue(ctx, rpcCtxKey{}, info)
}

func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s})
}

func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s})
}

// checkConnStats requires an even, positive number of conn events, starting
// with ConnBegin and ending with ConnEnd.
func checkConnStats(t *testing.T, got []*gotData) {
	if len(got) <= 0 || len(got)%2 != 0 {
		for i, g := range got {
			t.Errorf(" - %v, %T = %+v, ctx: %v", i, g.s, g.s, g.ctx)
		}
		t.Fatalf("got %v stats, want even positive number", len(got))
	}
	// The first conn stats must be a ConnBegin.
	checkConnBegin(t, got[0])
	// The last conn stats must be a ConnEnd.
	checkConnEnd(t, got[len(got)-1])
}

// checkServerStats runs checkFuncs positionally against the recorded RPC
// events; counts must match exactly.
func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
	if len(got) != len(checkFuncs) {
		for i, g := range got {
			t.Errorf(" - %v, %T", i, g.s)
		}
		t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
	}

	for i, f := range checkFuncs {
		f(t, got[i], expect)
	}
}

// testServerStats issues the RPC described by cc against a server that has a
// recording stats handler installed, waits for all events to arrive, and
// validates them with checkFuncs in order.
func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
	h := &statshandler{}
	te := newTest(t, tc, nil, []stats.Handler{h})
	te.startServer(&testServer{})
	defer te.tearDown()

	var (
		reqs   []proto.Message
		resps  []proto.Message
		err    error
		method string

		isClientStream bool
		isServerStream bool

		req  proto.Message
		resp proto.Message
		e    error
	)

	switch cc.callType {
	case unaryRPC:
		method = "/grpc.testing.TestService/UnaryCall"
		req, resp, e = te.doUnaryCall(cc)
		reqs = []proto.Message{req}
		resps = []proto.Message{resp}
		err = e
	case clientStreamRPC:
		method = "/grpc.testing.TestService/StreamingInputCall"
		reqs, resp, e = te.doClientStreamCall(cc)
		resps = []proto.Message{resp}
		err = e
		isClientStream = true
	case serverStreamRPC:
		method = "/grpc.testing.TestService/StreamingOutputCall"
		req, resps, e = te.doServerStreamCall(cc)
		reqs = []proto.Message{req}
		err = e
		isServerStream = true
	case fullDuplexStreamRPC:
		method = "/grpc.testing.TestService/FullDuplexCall"
		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
		isClientStream = true
		isServerStream = true
	}
	if cc.success != (err == nil) {
		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
	}
	te.cc.Close()
	te.srv.GracefulStop() // Wait for the server to stop.

	// Poll until all expected RPC events have been recorded.
	for {
		h.mu.Lock()
		if len(h.gotRPC) >= len(checkFuncs) {
			h.mu.Unlock()
			break
		}
		h.mu.Unlock()
		time.Sleep(10 * time.Millisecond)
	}

	// Poll until the connection teardown event has been recorded.
	for {
		h.mu.Lock()
		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
			h.mu.Unlock()
			break
		}
		h.mu.Unlock()
		time.Sleep(10 * time.Millisecond)
	}

	expect := &expectedData{
		serverAddr:     te.srvAddr,
		compression:    tc.compress,
		method:         method,
		requests:       reqs,
		responses:      resps,
		err:            err,
		isClientStream: isClientStream,
		isServerStream: isServerStream,
	}

	h.mu.Lock()
	checkConnStats(t, h.gotConn)
	h.mu.Unlock()
	checkServerStats(t, h.gotRPC, expect, checkFuncs)
}

func (s) TestServerStatsUnaryRPC(t *testing.T) {
	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
		checkInHeader,
		checkBegin,
		checkInPayload,
		checkOutHeader,
		checkOutPayload,
		checkOutTrailer,
		checkEnd,
	})
}

func (s) TestServerStatsUnaryRPCError(t *testing.T) {
	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
		checkInHeader,
		checkBegin,
		checkInPayload,
		checkOutHeader,
		checkOutTrailer,
		checkEnd,
	})
}

func (s) TestServerStatsClientStreamRPC(t *testing.T) {
	count := 5
	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
		checkInHeader,
		checkBegin,
		checkOutHeader,
	}
	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
		checkInPayload,
	}
	// One incoming payload per streamed request.
	for i := 0; i < count; i++ {
		checkFuncs = append(checkFuncs, ioPayFuncs...)
	}
	checkFuncs = append(checkFuncs,
		checkOutPayload,
		checkOutTrailer,
		checkEnd,
	)
	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs)
}

func (s) TestServerStatsClientStreamRPCError(t *testing.T) {
	count := 1
	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
		checkInHeader,
		checkBegin,
		checkOutHeader,
		checkInPayload,
		checkOutTrailer,
		checkEnd,
	})
}

func (s) TestServerStatsServerStreamRPC(t *testing.T) {
	count := 5
	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
		checkInHeader,
		checkBegin,
		checkInPayload,
		checkOutHeader,
	}
	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
		checkOutPayload,
	}
	// One outgoing payload per streamed response.
	for i := 0; i < count; i++ {
		checkFuncs = append(checkFuncs, ioPayFuncs...)
	}
	checkFuncs = append(checkFuncs,
		checkOutTrailer,
		checkEnd,
	)
	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs)
}

func (s) TestServerStatsServerStreamRPCError(t *testing.T) {
	count := 5
	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
		checkInHeader,
		checkBegin,
		checkInPayload,
		checkOutHeader,
		checkOutTrailer,
		checkEnd,
	})
}

func (s) TestServerStatsFullDuplexRPC(t *testing.T) {
	count := 5
	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
		checkInHeader,
		checkBegin,
		checkOutHeader,
	}
	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
		checkInPayload,
		checkOutPayload,
	}
	// One incoming and one outgoing payload per round trip.
	for i := 0; i < count; i++ {
		checkFuncs = append(checkFuncs, ioPayFuncs...)
	}
	checkFuncs = append(checkFuncs,
		checkOutTrailer,
		checkEnd,
	)
	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs)
}

func (s) TestServerStatsFullDuplexRPCError(t *testing.T) {
	count := 5
	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
		checkInHeader,
		checkBegin,
		checkOutHeader,
		checkInPayload,
		checkOutTrailer,
		checkEnd,
	})
}

// checkFuncWithCount pairs a check function with the number of events of its
// kind that are expected.
type checkFuncWithCount struct {
	f func(t *testing.T, d *gotData, e *expectedData)
	c int // expected count
}

// checkClientStats validates client-side events by kind rather than by
// position (client event ordering is not fully deterministic): the total
// count must match, all RPC events must carry the same tag info, and each
// event consumes one expected occurrence of its kind.
func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) {
	var expectLen int
	for _, v := range checkFuncs {
		expectLen += v.c
	}
	if len(got) != expectLen {
		for i, g := range got {
			t.Errorf(" - %v, %T", i, g.s)
		}
		t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
	}

	// All RPC events must share the same tag info pointer placed by TagRPC.
	var tagInfoInCtx *stats.RPCTagInfo
	for i := 0; i < len(got); i++ {
		if _, ok := got[i].s.(stats.RPCStats); ok {
			tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo)
			if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew {
				t.Fatalf("got context containing different tagInfo with stats %T", got[i].s)
			}
			tagInfoInCtx = tagInfoInCtxNew
		}
	}

	for _, s := range got {
		switch s.s.(type) {
		case *stats.Begin:
			if checkFuncs[begin].c <= 0 {
				t.Fatalf("unexpected stats: %T", s.s)
			}
			checkFuncs[begin].f(t, s, expect)
			checkFuncs[begin].c--
		case *stats.OutHeader:
			if checkFuncs[outHeader].c <= 0 {
				t.Fatalf("unexpected stats: %T", s.s)
			}
			checkFuncs[outHeader].f(t, s, expect)
			checkFuncs[outHeader].c--
		case *stats.OutPayload:
			if checkFuncs[outPayload].c <= 0 {
				t.Fatalf("unexpected stats: %T", s.s)
			}
			checkFuncs[outPayload].f(t, s, expect)
			checkFuncs[outPayload].c--
		case *stats.InHeader:
			if checkFuncs[inHeader].c <= 0 {
				t.Fatalf("unexpected stats: %T", s.s)
			}
			checkFuncs[inHeader].f(t, s, expect)
			checkFuncs[inHeader].c--
		case *stats.InPayload:
			if checkFuncs[inPayload].c <= 0 {
				t.Fatalf("unexpected stats: %T", s.s)
			}
			checkFuncs[inPayload].f(t, s, expect)
			checkFuncs[inPayload].c--
		case *stats.InTrailer:
			if checkFuncs[inTrailer].c <= 0 {
				t.Fatalf("unexpected stats: %T", s.s)
			}
			checkFuncs[inTrailer].f(t, s, expect)
			checkFuncs[inTrailer].c--
		case *stats.End:
			if checkFuncs[end].c <= 0 {
				t.Fatalf("unexpected stats: %T", s.s)
			}
			checkFuncs[end].f(t, s, expect)
			checkFuncs[end].c--
		case *stats.ConnBegin:
			if checkFuncs[connBegin].c <= 0 {
				t.Fatalf("unexpected stats: %T", s.s)
			}
			checkFuncs[connBegin].f(t, s, expect)
			checkFuncs[connBegin].c--
		case *stats.ConnEnd:
			if checkFuncs[connEnd].c <= 0 {
				t.Fatalf("unexpected stats: %T", s.s)
			}
			checkFuncs[connEnd].f(t, s, expect)
			checkFuncs[connEnd].c--
		default:
			t.Fatalf("unexpected stats: %T", s.s)
		}
	}
}

// testClientStats issues the RPC described by cc with a recording stats
// handler installed on the client and validates the recorded client-side
// events with checkFuncs.
func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
	h := &statshandler{}
	te := newTest(t, tc, []stats.Handler{h}, nil)
	te.startServer(&testServer{})
	defer te.tearDown()

	var (
		reqs   []proto.Message
		resps  []proto.Message
		method string
		err    error

		isClientStream bool
		isServerStream bool

		req  proto.Message
		resp proto.Message
		e    error
	)
	switch cc.callType {
	case unaryRPC:
		method = "/grpc.testing.TestService/UnaryCall"
		req, resp, e = te.doUnaryCall(cc)
reqs = []proto.Message{req} 1164 resps = []proto.Message{resp} 1165 err = e 1166 case clientStreamRPC: 1167 method = "/grpc.testing.TestService/StreamingInputCall" 1168 reqs, resp, e = te.doClientStreamCall(cc) 1169 resps = []proto.Message{resp} 1170 err = e 1171 isClientStream = true 1172 case serverStreamRPC: 1173 method = "/grpc.testing.TestService/StreamingOutputCall" 1174 req, resps, e = te.doServerStreamCall(cc) 1175 reqs = []proto.Message{req} 1176 err = e 1177 isServerStream = true 1178 case fullDuplexStreamRPC: 1179 method = "/grpc.testing.TestService/FullDuplexCall" 1180 reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) 1181 isClientStream = true 1182 isServerStream = true 1183 } 1184 if cc.success != (err == nil) { 1185 t.Fatalf("cc.success: %v, got error: %v", cc.success, err) 1186 } 1187 te.cc.Close() 1188 te.srv.GracefulStop() // Wait for the server to stop. 1189 1190 lenRPCStats := 0 1191 for _, v := range checkFuncs { 1192 lenRPCStats += v.c 1193 } 1194 for { 1195 h.mu.Lock() 1196 if len(h.gotRPC) >= lenRPCStats { 1197 h.mu.Unlock() 1198 break 1199 } 1200 h.mu.Unlock() 1201 time.Sleep(10 * time.Millisecond) 1202 } 1203 1204 for { 1205 h.mu.Lock() 1206 if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok { 1207 h.mu.Unlock() 1208 break 1209 } 1210 h.mu.Unlock() 1211 time.Sleep(10 * time.Millisecond) 1212 } 1213 1214 expect := &expectedData{ 1215 serverAddr: te.srvAddr, 1216 compression: tc.compress, 1217 method: method, 1218 requests: reqs, 1219 responses: resps, 1220 failfast: cc.failfast, 1221 err: err, 1222 isClientStream: isClientStream, 1223 isServerStream: isServerStream, 1224 } 1225 1226 h.mu.Lock() 1227 checkConnStats(t, h.gotConn) 1228 h.mu.Unlock() 1229 checkClientStats(t, h.gotRPC, expect, checkFuncs) 1230 } 1231 1232 func (s) TestClientStatsUnaryRPC(t *testing.T) { 1233 testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{ 1234 begin: 
{checkBegin, 1}, 1235 outHeader: {checkOutHeader, 1}, 1236 outPayload: {checkOutPayload, 1}, 1237 inHeader: {checkInHeader, 1}, 1238 inPayload: {checkInPayload, 1}, 1239 inTrailer: {checkInTrailer, 1}, 1240 end: {checkEnd, 1}, 1241 }) 1242 } 1243 1244 func (s) TestClientStatsUnaryRPCError(t *testing.T) { 1245 testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{ 1246 begin: {checkBegin, 1}, 1247 outHeader: {checkOutHeader, 1}, 1248 outPayload: {checkOutPayload, 1}, 1249 inHeader: {checkInHeader, 1}, 1250 inTrailer: {checkInTrailer, 1}, 1251 end: {checkEnd, 1}, 1252 }) 1253 } 1254 1255 func (s) TestClientStatsClientStreamRPC(t *testing.T) { 1256 count := 5 1257 testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{ 1258 begin: {checkBegin, 1}, 1259 outHeader: {checkOutHeader, 1}, 1260 inHeader: {checkInHeader, 1}, 1261 outPayload: {checkOutPayload, count}, 1262 inTrailer: {checkInTrailer, 1}, 1263 inPayload: {checkInPayload, 1}, 1264 end: {checkEnd, 1}, 1265 }) 1266 } 1267 1268 func (s) TestClientStatsClientStreamRPCError(t *testing.T) { 1269 count := 1 1270 testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{ 1271 begin: {checkBegin, 1}, 1272 outHeader: {checkOutHeader, 1}, 1273 inHeader: {checkInHeader, 1}, 1274 outPayload: {checkOutPayload, 1}, 1275 inTrailer: {checkInTrailer, 1}, 1276 end: {checkEnd, 1}, 1277 }) 1278 } 1279 1280 func (s) TestClientStatsServerStreamRPC(t *testing.T) { 1281 count := 5 1282 testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{ 1283 begin: {checkBegin, 1}, 1284 outHeader: {checkOutHeader, 1}, 1285 outPayload: 
{checkOutPayload, 1}, 1286 inHeader: {checkInHeader, 1}, 1287 inPayload: {checkInPayload, count}, 1288 inTrailer: {checkInTrailer, 1}, 1289 end: {checkEnd, 1}, 1290 }) 1291 } 1292 1293 func (s) TestClientStatsServerStreamRPCError(t *testing.T) { 1294 count := 5 1295 testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{ 1296 begin: {checkBegin, 1}, 1297 outHeader: {checkOutHeader, 1}, 1298 outPayload: {checkOutPayload, 1}, 1299 inHeader: {checkInHeader, 1}, 1300 inTrailer: {checkInTrailer, 1}, 1301 end: {checkEnd, 1}, 1302 }) 1303 } 1304 1305 func (s) TestClientStatsFullDuplexRPC(t *testing.T) { 1306 count := 5 1307 testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{ 1308 begin: {checkBegin, 1}, 1309 outHeader: {checkOutHeader, 1}, 1310 outPayload: {checkOutPayload, count}, 1311 inHeader: {checkInHeader, 1}, 1312 inPayload: {checkInPayload, count}, 1313 inTrailer: {checkInTrailer, 1}, 1314 end: {checkEnd, 1}, 1315 }) 1316 } 1317 1318 func (s) TestClientStatsFullDuplexRPCError(t *testing.T) { 1319 count := 5 1320 testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{ 1321 begin: {checkBegin, 1}, 1322 outHeader: {checkOutHeader, 1}, 1323 outPayload: {checkOutPayload, 1}, 1324 inHeader: {checkInHeader, 1}, 1325 inTrailer: {checkInTrailer, 1}, 1326 end: {checkEnd, 1}, 1327 }) 1328 } 1329 1330 func (s) TestTags(t *testing.T) { 1331 b := []byte{5, 2, 4, 3, 1} 1332 tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1333 defer cancel() 1334 ctx := stats.SetTags(tCtx, b) 1335 if tg := stats.OutgoingTags(ctx); !reflect.DeepEqual(tg, b) { 1336 t.Errorf("OutgoingTags(%v) = %v; want %v", ctx, tg, b) 1337 } 1338 if tg 
:= stats.Tags(ctx); tg != nil { 1339 t.Errorf("Tags(%v) = %v; want nil", ctx, tg) 1340 } 1341 1342 ctx = stats.SetIncomingTags(tCtx, b) 1343 if tg := stats.Tags(ctx); !reflect.DeepEqual(tg, b) { 1344 t.Errorf("Tags(%v) = %v; want %v", ctx, tg, b) 1345 } 1346 if tg := stats.OutgoingTags(ctx); tg != nil { 1347 t.Errorf("OutgoingTags(%v) = %v; want nil", ctx, tg) 1348 } 1349 } 1350 1351 func (s) TestTrace(t *testing.T) { 1352 b := []byte{5, 2, 4, 3, 1} 1353 tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1354 defer cancel() 1355 ctx := stats.SetTrace(tCtx, b) 1356 if tr := stats.OutgoingTrace(ctx); !reflect.DeepEqual(tr, b) { 1357 t.Errorf("OutgoingTrace(%v) = %v; want %v", ctx, tr, b) 1358 } 1359 if tr := stats.Trace(ctx); tr != nil { 1360 t.Errorf("Trace(%v) = %v; want nil", ctx, tr) 1361 } 1362 1363 ctx = stats.SetIncomingTrace(tCtx, b) 1364 if tr := stats.Trace(ctx); !reflect.DeepEqual(tr, b) { 1365 t.Errorf("Trace(%v) = %v; want %v", ctx, tr, b) 1366 } 1367 if tr := stats.OutgoingTrace(ctx); tr != nil { 1368 t.Errorf("OutgoingTrace(%v) = %v; want nil", ctx, tr) 1369 } 1370 } 1371 1372 func (s) TestMultipleClientStatsHandler(t *testing.T) { 1373 h := &statshandler{} 1374 tc := &testConfig{compress: ""} 1375 te := newTest(t, tc, []stats.Handler{h, h}, nil) 1376 te.startServer(&testServer{}) 1377 defer te.tearDown() 1378 1379 cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC} 1380 _, _, err := te.doUnaryCall(cc) 1381 if cc.success != (err == nil) { 1382 t.Fatalf("cc.success: %v, got error: %v", cc.success, err) 1383 } 1384 te.cc.Close() 1385 te.srv.GracefulStop() // Wait for the server to stop. 
1386 1387 for start := time.Now(); time.Since(start) < defaultTestTimeout; { 1388 h.mu.Lock() 1389 if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok && len(h.gotRPC) == 12 { 1390 h.mu.Unlock() 1391 break 1392 } 1393 h.mu.Unlock() 1394 time.Sleep(10 * time.Millisecond) 1395 } 1396 1397 for start := time.Now(); time.Since(start) < defaultTestTimeout; { 1398 h.mu.Lock() 1399 if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok && len(h.gotConn) == 4 { 1400 h.mu.Unlock() 1401 break 1402 } 1403 h.mu.Unlock() 1404 time.Sleep(10 * time.Millisecond) 1405 } 1406 1407 // Each RPC generates 6 stats events on the client-side, times 2 StatsHandler 1408 if len(h.gotRPC) != 12 { 1409 t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12) 1410 } 1411 1412 // Each connection generates 4 conn events on the client-side, times 2 StatsHandler 1413 if len(h.gotConn) != 4 { 1414 t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4) 1415 } 1416 } 1417 1418 func (s) TestMultipleServerStatsHandler(t *testing.T) { 1419 h := &statshandler{} 1420 tc := &testConfig{compress: ""} 1421 te := newTest(t, tc, nil, []stats.Handler{h, h}) 1422 te.startServer(&testServer{}) 1423 defer te.tearDown() 1424 1425 cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC} 1426 _, _, err := te.doUnaryCall(cc) 1427 if cc.success != (err == nil) { 1428 t.Fatalf("cc.success: %v, got error: %v", cc.success, err) 1429 } 1430 te.cc.Close() 1431 te.srv.GracefulStop() // Wait for the server to stop. 
1432 1433 for start := time.Now(); time.Since(start) < defaultTestTimeout; { 1434 h.mu.Lock() 1435 if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok { 1436 h.mu.Unlock() 1437 break 1438 } 1439 h.mu.Unlock() 1440 time.Sleep(10 * time.Millisecond) 1441 } 1442 1443 for start := time.Now(); time.Since(start) < defaultTestTimeout; { 1444 h.mu.Lock() 1445 if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok { 1446 h.mu.Unlock() 1447 break 1448 } 1449 h.mu.Unlock() 1450 time.Sleep(10 * time.Millisecond) 1451 } 1452 1453 // Each RPC generates 6 stats events on the server-side, times 2 StatsHandler 1454 if len(h.gotRPC) != 12 { 1455 t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12) 1456 } 1457 1458 // Each connection generates 4 conn events on the server-side, times 2 StatsHandler 1459 if len(h.gotConn) != 4 { 1460 t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4) 1461 } 1462 } 1463 1464 // TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler 1465 // gets access to a Server on the server side, and thus the method that the 1466 // server owns which specifies whether a method is made or not. The test sets up 1467 // a server with a unary call and full duplex call configured, and makes an RPC. 1468 // Within the stats handler, asking the server whether unary or duplex method 1469 // names are registered should return true, and any other query should return 1470 // false. 1471 func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) { 1472 wg := sync.WaitGroup{} 1473 wg.Add(1) 1474 stubStatsHandler := &testutils.StubStatsHandler{ 1475 TagRPCF: func(ctx context.Context, _ *stats.RPCTagInfo) context.Context { 1476 // OpenTelemetry instrumentation needs the passed in Server to determine if 1477 // methods are registered in different handle calls in to record metrics. 
1478 // This tag RPC call context gets passed into every handle call, so can 1479 // assert once here, since it maps to all the handle RPC calls that come 1480 // after. These internal calls will be how the OpenTelemetry instrumentation 1481 // component accesses this server and the subsequent helper on the server. 1482 server := internal.ServerFromContext.(func(context.Context) *grpc.Server)(ctx) 1483 if server == nil { 1484 t.Errorf("stats handler received ctx has no server present") 1485 } 1486 isRegisteredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool) 1487 // /s/m and s/m are valid. 1488 if !isRegisteredMethod(server, "/grpc.testing.TestService/UnaryCall") { 1489 t.Errorf("UnaryCall should be a registered method according to server") 1490 } 1491 if !isRegisteredMethod(server, "grpc.testing.TestService/FullDuplexCall") { 1492 t.Errorf("FullDuplexCall should be a registered method according to server") 1493 } 1494 if isRegisteredMethod(server, "/grpc.testing.TestService/DoesNotExistCall") { 1495 t.Errorf("DoesNotExistCall should not be a registered method according to server") 1496 } 1497 if isRegisteredMethod(server, "/unknownService/UnaryCall") { 1498 t.Errorf("/unknownService/UnaryCall should not be a registered method according to server") 1499 } 1500 wg.Done() 1501 return ctx 1502 }, 1503 } 1504 ss := &stubserver.StubServer{ 1505 UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { 1506 return &testpb.SimpleResponse{}, nil 1507 }, 1508 } 1509 if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(stubStatsHandler)}); err != nil { 1510 t.Fatalf("Error starting endpoint server: %v", err) 1511 } 1512 defer ss.Stop() 1513 1514 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 1515 defer cancel() 1516 if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil { 1517 t.Fatalf("Unexpected error from UnaryCall: %v", err) 1518 } 
1519 wg.Wait() 1520 }