google.golang.org/grpc@v1.74.2/stats/stats_test.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package stats_test
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"io"
    25  	"net"
    26  	"reflect"
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/google/go-cmp/cmp"
    32  	"google.golang.org/grpc"
    33  	"google.golang.org/grpc/connectivity"
    34  	"google.golang.org/grpc/credentials/insecure"
    35  	"google.golang.org/grpc/internal"
    36  	"google.golang.org/grpc/internal/grpctest"
    37  	"google.golang.org/grpc/internal/stubserver"
    38  	"google.golang.org/grpc/internal/testutils"
    39  	"google.golang.org/grpc/metadata"
    40  	"google.golang.org/grpc/stats"
    41  	"google.golang.org/grpc/status"
    42  	"google.golang.org/protobuf/proto"
    43  	"google.golang.org/protobuf/testing/protocmp"
    44  
    45  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    46  	testpb "google.golang.org/grpc/interop/grpc_testing"
    47  )
    48  
    49  const defaultTestTimeout = 10 * time.Second
    50  
    51  type s struct {
    52  	grpctest.Tester
    53  }
    54  
    55  func Test(t *testing.T) {
    56  	grpctest.RunSubTests(t, s{})
    57  }
    58  
    59  func init() {
    60  	grpc.EnableTracing = false
    61  }
    62  
    63  type connCtxKey struct{}
    64  type rpcCtxKey struct{}
    65  
    66  var (
    67  	// For headers sent to server:
    68  	testMetadata = metadata.MD{
    69  		"key1":       []string{"value1"},
    70  		"key2":       []string{"value2"},
    71  		"user-agent": []string{fmt.Sprintf("test/0.0.1 grpc-go/%s", grpc.Version)},
    72  	}
    73  	// For headers sent from server:
    74  	testHeaderMetadata = metadata.MD{
    75  		"hkey1": []string{"headerValue1"},
    76  		"hkey2": []string{"headerValue2"},
    77  	}
    78  	// For trailers sent from server:
    79  	testTrailerMetadata = metadata.MD{
    80  		"tkey1": []string{"trailerValue1"},
    81  		"tkey2": []string{"trailerValue2"},
    82  	}
    83  	// The id for which the service handler should return error.
    84  	errorID int32 = 32202
    85  )
    86  
    87  func idToPayload(id int32) *testpb.Payload {
    88  	return &testpb.Payload{Body: []byte{byte(id), byte(id >> 8), byte(id >> 16), byte(id >> 24)}}
    89  }
    90  
    91  func payloadToID(p *testpb.Payload) int32 {
    92  	if p == nil || len(p.Body) != 4 {
    93  		panic("invalid payload")
    94  	}
    95  	return int32(p.Body[0]) + int32(p.Body[1])<<8 + int32(p.Body[2])<<16 + int32(p.Body[3])<<24
    96  }
    97  
    98  func setIncomingStats(ctx context.Context, mdKey string, b []byte) context.Context {
    99  	md, ok := metadata.FromIncomingContext(ctx)
   100  	if !ok {
   101  		md = metadata.MD{}
   102  	}
   103  	md.Set(mdKey, string(b))
   104  	return metadata.NewIncomingContext(ctx, md)
   105  }
   106  
   107  func getOutgoingStats(ctx context.Context, mdKey string) []byte {
   108  	md, ok := metadata.FromOutgoingContext(ctx)
   109  	if !ok {
   110  		return nil
   111  	}
   112  	tagValues := md.Get(mdKey)
   113  	if len(tagValues) == 0 {
   114  		return nil
   115  	}
   116  	return []byte(tagValues[len(tagValues)-1])
   117  }
   118  
   119  type testServer struct {
   120  	testgrpc.UnimplementedTestServiceServer
   121  }
   122  
   123  func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   124  	if err := grpc.SendHeader(ctx, testHeaderMetadata); err != nil {
   125  		return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", testHeaderMetadata, err)
   126  	}
   127  	if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
   128  		return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
   129  	}
   130  
   131  	if id := payloadToID(in.Payload); id == errorID {
   132  		return nil, fmt.Errorf("got error id: %v", id)
   133  	}
   134  
   135  	return &testpb.SimpleResponse{Payload: in.Payload}, nil
   136  }
   137  
   138  func (s *testServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallServer) error {
   139  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   140  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   141  	}
   142  	stream.SetTrailer(testTrailerMetadata)
   143  	for {
   144  		in, err := stream.Recv()
   145  		if err == io.EOF {
   146  			// read done.
   147  			return nil
   148  		}
   149  		if err != nil {
   150  			return err
   151  		}
   152  
   153  		if id := payloadToID(in.Payload); id == errorID {
   154  			return fmt.Errorf("got error id: %v", id)
   155  		}
   156  
   157  		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
   158  			return err
   159  		}
   160  	}
   161  }
   162  
   163  func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error {
   164  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   165  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   166  	}
   167  	stream.SetTrailer(testTrailerMetadata)
   168  	for {
   169  		in, err := stream.Recv()
   170  		if err == io.EOF {
   171  			// read done.
   172  			return stream.SendAndClose(&testpb.StreamingInputCallResponse{AggregatedPayloadSize: 0})
   173  		}
   174  		if err != nil {
   175  			return err
   176  		}
   177  
   178  		if id := payloadToID(in.Payload); id == errorID {
   179  			return fmt.Errorf("got error id: %v", id)
   180  		}
   181  	}
   182  }
   183  
   184  func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error {
   185  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   186  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   187  	}
   188  	stream.SetTrailer(testTrailerMetadata)
   189  
   190  	if id := payloadToID(in.Payload); id == errorID {
   191  		return fmt.Errorf("got error id: %v", id)
   192  	}
   193  
   194  	for i := 0; i < 5; i++ {
   195  		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
   196  			return err
   197  		}
   198  	}
   199  	return nil
   200  }
   201  
   202  // test is an end-to-end test. It should be created with the newTest
   203  // func, modified as needed, and then started with its startServer method.
   204  // It should be cleaned up with the tearDown method.
   205  type test struct {
   206  	t                   *testing.T
   207  	compress            string
   208  	clientStatsHandlers []stats.Handler
   209  	serverStatsHandlers []stats.Handler
   210  
   211  	testServer testgrpc.TestServiceServer // nil means none
   212  	// srv and srvAddr are set once startServer is called.
   213  	srv     *grpc.Server
   214  	srvAddr string
   215  
   216  	cc *grpc.ClientConn // nil until requested via clientConn
   217  }
   218  
   219  func (te *test) tearDown() {
   220  	if te.cc != nil {
   221  		te.cc.Close()
   222  		te.cc = nil
   223  	}
   224  	te.srv.Stop()
   225  }
   226  
   227  type testConfig struct {
   228  	compress string
   229  }
   230  
   231  // newTest returns a new test using the provided testing.T and
   232  // environment.  It is returned with default values. Tests should
   233  // modify it before calling its startServer and clientConn methods.
   234  func newTest(t *testing.T, tc *testConfig, chs []stats.Handler, shs []stats.Handler) *test {
   235  	te := &test{
   236  		t:                   t,
   237  		compress:            tc.compress,
   238  		clientStatsHandlers: chs,
   239  		serverStatsHandlers: shs,
   240  	}
   241  	return te
   242  }
   243  
   244  // startServer starts a gRPC server listening. Callers should defer a
   245  // call to te.tearDown to clean up.
   246  func (te *test) startServer(ts testgrpc.TestServiceServer) {
   247  	te.testServer = ts
   248  	lis, err := net.Listen("tcp", "localhost:0")
   249  	if err != nil {
   250  		te.t.Fatalf("Failed to listen: %v", err)
   251  	}
   252  	var opts []grpc.ServerOption
   253  	if te.compress == "gzip" {
   254  		opts = append(opts,
   255  			grpc.RPCCompressor(grpc.NewGZIPCompressor()),
   256  			grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
   257  		)
   258  	}
   259  	for _, sh := range te.serverStatsHandlers {
   260  		opts = append(opts, grpc.StatsHandler(sh))
   261  	}
   262  	s := grpc.NewServer(opts...)
   263  	te.srv = s
   264  	if te.testServer != nil {
   265  		testgrpc.RegisterTestServiceServer(s, te.testServer)
   266  	}
   267  
   268  	go s.Serve(lis)
   269  	te.srvAddr = lis.Addr().String()
   270  }
   271  
   272  func (te *test) clientConn(ctx context.Context) *grpc.ClientConn {
   273  	if te.cc != nil {
   274  		return te.cc
   275  	}
   276  	opts := []grpc.DialOption{
   277  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   278  		grpc.WithUserAgent("test/0.0.1"),
   279  	}
   280  	if te.compress == "gzip" {
   281  		opts = append(opts,
   282  			grpc.WithCompressor(grpc.NewGZIPCompressor()),
   283  			grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
   284  		)
   285  	}
   286  	for _, sh := range te.clientStatsHandlers {
   287  		opts = append(opts, grpc.WithStatsHandler(sh))
   288  	}
   289  
   290  	var err error
   291  	te.cc, err = grpc.NewClient(te.srvAddr, opts...)
   292  	if err != nil {
   293  		te.t.Fatalf("grpc.NewClient(%q) failed: %v", te.srvAddr, err)
   294  	}
   295  	te.cc.Connect()
   296  	testutils.AwaitState(ctx, te.t, te.cc, connectivity.Ready)
   297  	return te.cc
   298  }
   299  
   300  type rpcType int
   301  
   302  const (
   303  	unaryRPC rpcType = iota
   304  	clientStreamRPC
   305  	serverStreamRPC
   306  	fullDuplexStreamRPC
   307  )
   308  
   309  type rpcConfig struct {
   310  	count    int  // Number of requests and responses for streaming RPCs.
   311  	success  bool // Whether the RPC should succeed or return error.
   312  	failfast bool
   313  	callType rpcType // Type of RPC.
   314  }
   315  
   316  func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
   317  	var (
   318  		resp *testpb.SimpleResponse
   319  		req  *testpb.SimpleRequest
   320  		err  error
   321  	)
   322  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   323  	defer cancel()
   324  	tc := testgrpc.NewTestServiceClient(te.clientConn(tCtx))
   325  	if c.success {
   326  		req = &testpb.SimpleRequest{Payload: idToPayload(errorID + 1)}
   327  	} else {
   328  		req = &testpb.SimpleRequest{Payload: idToPayload(errorID)}
   329  	}
   330  
   331  	resp, err = tc.UnaryCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast))
   332  	return req, resp, err
   333  }
   334  
   335  func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]proto.Message, []proto.Message, error) {
   336  	var (
   337  		reqs  []proto.Message
   338  		resps []proto.Message
   339  		err   error
   340  	)
   341  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   342  	defer cancel()
   343  	tc := testgrpc.NewTestServiceClient(te.clientConn(tCtx))
   344  	stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast))
   345  	if err != nil {
   346  		return reqs, resps, err
   347  	}
   348  	var startID int32
   349  	if !c.success {
   350  		startID = errorID
   351  	}
   352  	for i := 0; i < c.count; i++ {
   353  		req := &testpb.StreamingOutputCallRequest{
   354  			Payload: idToPayload(int32(i) + startID),
   355  		}
   356  		reqs = append(reqs, req)
   357  		if err = stream.Send(req); err != nil {
   358  			return reqs, resps, err
   359  		}
   360  		var resp *testpb.StreamingOutputCallResponse
   361  		if resp, err = stream.Recv(); err != nil {
   362  			return reqs, resps, err
   363  		}
   364  		resps = append(resps, resp)
   365  	}
   366  	if err = stream.CloseSend(); err != nil && err != io.EOF {
   367  		return reqs, resps, err
   368  	}
   369  	if _, err = stream.Recv(); err != io.EOF {
   370  		return reqs, resps, err
   371  	}
   372  
   373  	return reqs, resps, nil
   374  }
   375  
   376  func (te *test) doClientStreamCall(c *rpcConfig) ([]proto.Message, *testpb.StreamingInputCallResponse, error) {
   377  	var (
   378  		reqs []proto.Message
   379  		resp *testpb.StreamingInputCallResponse
   380  		err  error
   381  	)
   382  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   383  	defer cancel()
   384  	tc := testgrpc.NewTestServiceClient(te.clientConn(tCtx))
   385  	stream, err := tc.StreamingInputCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast))
   386  	if err != nil {
   387  		return reqs, resp, err
   388  	}
   389  	var startID int32
   390  	if !c.success {
   391  		startID = errorID
   392  	}
   393  	for i := 0; i < c.count; i++ {
   394  		req := &testpb.StreamingInputCallRequest{
   395  			Payload: idToPayload(int32(i) + startID),
   396  		}
   397  		reqs = append(reqs, req)
   398  		if err = stream.Send(req); err != nil {
   399  			return reqs, resp, err
   400  		}
   401  	}
   402  	resp, err = stream.CloseAndRecv()
   403  	return reqs, resp, err
   404  }
   405  
   406  func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.StreamingOutputCallRequest, []proto.Message, error) {
   407  	var (
   408  		req   *testpb.StreamingOutputCallRequest
   409  		resps []proto.Message
   410  		err   error
   411  	)
   412  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   413  	defer cancel()
   414  	tc := testgrpc.NewTestServiceClient(te.clientConn(tCtx))
   415  
   416  	var startID int32
   417  	if !c.success {
   418  		startID = errorID
   419  	}
   420  	req = &testpb.StreamingOutputCallRequest{Payload: idToPayload(startID)}
   421  	stream, err := tc.StreamingOutputCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast))
   422  	if err != nil {
   423  		return req, resps, err
   424  	}
   425  	for {
   426  		var resp *testpb.StreamingOutputCallResponse
   427  		resp, err := stream.Recv()
   428  		if err == io.EOF {
   429  			return req, resps, nil
   430  		} else if err != nil {
   431  			return req, resps, err
   432  		}
   433  		resps = append(resps, resp)
   434  	}
   435  }
   436  
   437  type expectedData struct {
   438  	method         string
   439  	isClientStream bool
   440  	isServerStream bool
   441  	serverAddr     string
   442  	compression    string
   443  	reqIdx         int
   444  	requests       []proto.Message
   445  	respIdx        int
   446  	responses      []proto.Message
   447  	err            error
   448  	failfast       bool
   449  }
   450  
   451  type gotData struct {
   452  	ctx    context.Context
   453  	client bool
   454  	s      any // This could be RPCStats or ConnStats.
   455  }
   456  
   457  const (
   458  	begin int = iota
   459  	end
   460  	inPayload
   461  	inHeader
   462  	inTrailer
   463  	outPayload
   464  	outHeader
   465  	// TODO: test outTrailer ?
   466  	connBegin
   467  	connEnd
   468  )
   469  
   470  func checkBegin(t *testing.T, d *gotData, e *expectedData) {
   471  	var (
   472  		ok bool
   473  		st *stats.Begin
   474  	)
   475  	if st, ok = d.s.(*stats.Begin); !ok {
   476  		t.Fatalf("got %T, want Begin", d.s)
   477  	}
   478  	if d.ctx == nil {
   479  		t.Fatalf("d.ctx = nil, want <non-nil>")
   480  	}
   481  	if st.BeginTime.IsZero() {
   482  		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
   483  	}
   484  	if d.client {
   485  		if st.FailFast != e.failfast {
   486  			t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast)
   487  		}
   488  	}
   489  	if st.IsClientStream != e.isClientStream {
   490  		t.Fatalf("st.IsClientStream = %v, want %v", st.IsClientStream, e.isClientStream)
   491  	}
   492  	if st.IsServerStream != e.isServerStream {
   493  		t.Fatalf("st.IsServerStream = %v, want %v", st.IsServerStream, e.isServerStream)
   494  	}
   495  }
   496  
   497  func checkInHeader(t *testing.T, d *gotData, e *expectedData) {
   498  	var (
   499  		ok bool
   500  		st *stats.InHeader
   501  	)
   502  	if st, ok = d.s.(*stats.InHeader); !ok {
   503  		t.Fatalf("got %T, want InHeader", d.s)
   504  	}
   505  	if d.ctx == nil {
   506  		t.Fatalf("d.ctx = nil, want <non-nil>")
   507  	}
   508  	if st.Compression != e.compression {
   509  		t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
   510  	}
   511  	if d.client {
   512  		// additional headers might be injected so instead of testing equality, test that all the
   513  		// expected headers keys have the expected header values.
   514  		for key := range testHeaderMetadata {
   515  			if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) {
   516  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key))
   517  			}
   518  		}
   519  	} else {
   520  		if st.FullMethod != e.method {
   521  			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
   522  		}
   523  		if st.LocalAddr.String() != e.serverAddr {
   524  			t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr)
   525  		}
   526  		// additional headers might be injected so instead of testing equality, test that all the
   527  		// expected headers keys have the expected header values.
   528  		for key := range testMetadata {
   529  			if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) {
   530  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key))
   531  			}
   532  		}
   533  
   534  		if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok {
   535  			if connInfo.RemoteAddr != st.RemoteAddr {
   536  				t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr)
   537  			}
   538  			if connInfo.LocalAddr != st.LocalAddr {
   539  				t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr)
   540  			}
   541  		} else {
   542  			t.Fatalf("got context %v, want one with connCtxKey", d.ctx)
   543  		}
   544  		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
   545  			if rpcInfo.FullMethodName != st.FullMethod {
   546  				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
   547  			}
   548  		} else {
   549  			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
   550  		}
   551  	}
   552  }
   553  
   554  func checkInPayload(t *testing.T, d *gotData, e *expectedData) {
   555  	var (
   556  		ok bool
   557  		st *stats.InPayload
   558  	)
   559  	if st, ok = d.s.(*stats.InPayload); !ok {
   560  		t.Fatalf("got %T, want InPayload", d.s)
   561  	}
   562  	if d.ctx == nil {
   563  		t.Fatalf("d.ctx = nil, want <non-nil>")
   564  	}
   565  
   566  	var idx *int
   567  	var payloads []proto.Message
   568  	if d.client {
   569  		idx = &e.respIdx
   570  		payloads = e.responses
   571  	} else {
   572  		idx = &e.reqIdx
   573  		payloads = e.requests
   574  	}
   575  
   576  	wantPayload := payloads[*idx]
   577  	if diff := cmp.Diff(wantPayload, st.Payload.(proto.Message), protocmp.Transform()); diff != "" {
   578  		t.Fatalf("unexpected difference in st.Payload (-want +got):\n%s", diff)
   579  	}
   580  	*idx++
   581  	if st.Length != proto.Size(wantPayload) {
   582  		t.Fatalf("st.Length = %v, want %v", st.Length, proto.Size(wantPayload))
   583  	}
   584  
   585  	// Below are sanity checks that WireLength and RecvTime are populated.
   586  	// TODO: check values of WireLength and RecvTime.
   587  	if st.Length > 0 && st.CompressedLength == 0 {
   588  		t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
   589  			st.CompressedLength)
   590  	}
   591  	if st.RecvTime.IsZero() {
   592  		t.Fatalf("st.ReceivedTime = %v, want <non-zero>", st.RecvTime)
   593  	}
   594  }
   595  
   596  func checkInTrailer(t *testing.T, d *gotData, _ *expectedData) {
   597  	var (
   598  		ok bool
   599  		st *stats.InTrailer
   600  	)
   601  	if st, ok = d.s.(*stats.InTrailer); !ok {
   602  		t.Fatalf("got %T, want InTrailer", d.s)
   603  	}
   604  	if d.ctx == nil {
   605  		t.Fatalf("d.ctx = nil, want <non-nil>")
   606  	}
   607  	if !st.Client {
   608  		t.Fatalf("st IsClient = false, want true")
   609  	}
   610  	if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   611  		t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   612  	}
   613  }
   614  
   615  func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
   616  	var (
   617  		ok bool
   618  		st *stats.OutHeader
   619  	)
   620  	if st, ok = d.s.(*stats.OutHeader); !ok {
   621  		t.Fatalf("got %T, want OutHeader", d.s)
   622  	}
   623  	if d.ctx == nil {
   624  		t.Fatalf("d.ctx = nil, want <non-nil>")
   625  	}
   626  	if st.Compression != e.compression {
   627  		t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
   628  	}
   629  	if d.client {
   630  		if st.FullMethod != e.method {
   631  			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
   632  		}
   633  		if st.RemoteAddr.String() != e.serverAddr {
   634  			t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr)
   635  		}
   636  		// additional headers might be injected so instead of testing equality, test that all the
   637  		// expected headers keys have the expected header values.
   638  		for key := range testMetadata {
   639  			if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) {
   640  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key))
   641  			}
   642  		}
   643  
   644  		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
   645  			if rpcInfo.FullMethodName != st.FullMethod {
   646  				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
   647  			}
   648  		} else {
   649  			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
   650  		}
   651  	} else {
   652  		// additional headers might be injected so instead of testing equality, test that all the
   653  		// expected headers keys have the expected header values.
   654  		for key := range testHeaderMetadata {
   655  			if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) {
   656  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key))
   657  			}
   658  		}
   659  	}
   660  }
   661  
   662  func checkOutPayload(t *testing.T, d *gotData, e *expectedData) {
   663  	var (
   664  		ok bool
   665  		st *stats.OutPayload
   666  	)
   667  	if st, ok = d.s.(*stats.OutPayload); !ok {
   668  		t.Fatalf("got %T, want OutPayload", d.s)
   669  	}
   670  	if d.ctx == nil {
   671  		t.Fatalf("d.ctx = nil, want <non-nil>")
   672  	}
   673  
   674  	var idx *int
   675  	var payloads []proto.Message
   676  	if d.client {
   677  		idx = &e.reqIdx
   678  		payloads = e.requests
   679  	} else {
   680  		idx = &e.respIdx
   681  		payloads = e.responses
   682  	}
   683  
   684  	expectedPayload := payloads[*idx]
   685  	if !proto.Equal(st.Payload.(proto.Message), expectedPayload) {
   686  		t.Fatalf("st.Payload = %v, want %v", st.Payload, expectedPayload)
   687  	}
   688  	*idx++
   689  	if st.Length != proto.Size(expectedPayload) {
   690  		t.Fatalf("st.Length = %v, want %v", st.Length, proto.Size(expectedPayload))
   691  	}
   692  
   693  	// Below are sanity checks that Length, CompressedLength and SentTime are populated.
   694  	// TODO: check values of WireLength and SentTime.
   695  	if st.Length > 0 && st.WireLength == 0 {
   696  		t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
   697  			st.WireLength)
   698  	}
   699  	if st.SentTime.IsZero() {
   700  		t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime)
   701  	}
   702  }
   703  
   704  func checkOutTrailer(t *testing.T, d *gotData, _ *expectedData) {
   705  	var (
   706  		ok bool
   707  		st *stats.OutTrailer
   708  	)
   709  	if st, ok = d.s.(*stats.OutTrailer); !ok {
   710  		t.Fatalf("got %T, want OutTrailer", d.s)
   711  	}
   712  	if d.ctx == nil {
   713  		t.Fatalf("d.ctx = nil, want <non-nil>")
   714  	}
   715  	if st.Client {
   716  		t.Fatalf("st IsClient = true, want false")
   717  	}
   718  	if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   719  		t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   720  	}
   721  }
   722  
   723  func checkEnd(t *testing.T, d *gotData, e *expectedData) {
   724  	var (
   725  		ok bool
   726  		st *stats.End
   727  	)
   728  	if st, ok = d.s.(*stats.End); !ok {
   729  		t.Fatalf("got %T, want End", d.s)
   730  	}
   731  	if d.ctx == nil {
   732  		t.Fatalf("d.ctx = nil, want <non-nil>")
   733  	}
   734  	if st.BeginTime.IsZero() {
   735  		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
   736  	}
   737  	if st.EndTime.IsZero() {
   738  		t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime)
   739  	}
   740  
   741  	actual, ok := status.FromError(st.Error)
   742  	if !ok {
   743  		t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error)
   744  	}
   745  
   746  	expectedStatus, _ := status.FromError(e.err)
   747  	if actual.Code() != expectedStatus.Code() || actual.Message() != expectedStatus.Message() {
   748  		t.Fatalf("st.Error = %v, want %v", st.Error, e.err)
   749  	}
   750  
   751  	if st.Client {
   752  		if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   753  			t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   754  		}
   755  	} else {
   756  		if st.Trailer != nil {
   757  			t.Fatalf("st.Trailer = %v, want nil", st.Trailer)
   758  		}
   759  	}
   760  }
   761  
   762  func checkConnBegin(t *testing.T, d *gotData) {
   763  	var (
   764  		ok bool
   765  		st *stats.ConnBegin
   766  	)
   767  	if st, ok = d.s.(*stats.ConnBegin); !ok {
   768  		t.Fatalf("got %T, want ConnBegin", d.s)
   769  	}
   770  	if d.ctx == nil {
   771  		t.Fatalf("d.ctx = nil, want <non-nil>")
   772  	}
   773  	st.IsClient() // TODO remove this.
   774  }
   775  
   776  func checkConnEnd(t *testing.T, d *gotData) {
   777  	var (
   778  		ok bool
   779  		st *stats.ConnEnd
   780  	)
   781  	if st, ok = d.s.(*stats.ConnEnd); !ok {
   782  		t.Fatalf("got %T, want ConnEnd", d.s)
   783  	}
   784  	if d.ctx == nil {
   785  		t.Fatalf("d.ctx = nil, want <non-nil>")
   786  	}
   787  	st.IsClient() // TODO remove this.
   788  }
   789  
   790  type statshandler struct {
   791  	mu      sync.Mutex
   792  	gotRPC  []*gotData
   793  	gotConn []*gotData
   794  }
   795  
   796  func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
   797  	return context.WithValue(ctx, connCtxKey{}, info)
   798  }
   799  
   800  func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
   801  	return context.WithValue(ctx, rpcCtxKey{}, info)
   802  }
   803  
   804  func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) {
   805  	h.mu.Lock()
   806  	defer h.mu.Unlock()
   807  	h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s})
   808  }
   809  
   810  func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
   811  	h.mu.Lock()
   812  	defer h.mu.Unlock()
   813  	h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s})
   814  }
   815  
   816  func checkConnStats(t *testing.T, got []*gotData) {
   817  	if len(got) <= 0 || len(got)%2 != 0 {
   818  		for i, g := range got {
   819  			t.Errorf(" - %v, %T = %+v, ctx: %v", i, g.s, g.s, g.ctx)
   820  		}
   821  		t.Fatalf("got %v stats, want even positive number", len(got))
   822  	}
   823  	// The first conn stats must be a ConnBegin.
   824  	checkConnBegin(t, got[0])
   825  	// The last conn stats must be a ConnEnd.
   826  	checkConnEnd(t, got[len(got)-1])
   827  }
   828  
   829  func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
   830  	if len(got) != len(checkFuncs) {
   831  		for i, g := range got {
   832  			t.Errorf(" - %v, %T", i, g.s)
   833  		}
   834  		t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
   835  	}
   836  
   837  	for i, f := range checkFuncs {
   838  		f(t, got[i], expect)
   839  	}
   840  }
   841  
   842  func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
   843  	h := &statshandler{}
   844  	te := newTest(t, tc, nil, []stats.Handler{h})
   845  	te.startServer(&testServer{})
   846  	defer te.tearDown()
   847  
   848  	var (
   849  		reqs   []proto.Message
   850  		resps  []proto.Message
   851  		err    error
   852  		method string
   853  
   854  		isClientStream bool
   855  		isServerStream bool
   856  
   857  		req  proto.Message
   858  		resp proto.Message
   859  		e    error
   860  	)
   861  
   862  	switch cc.callType {
   863  	case unaryRPC:
   864  		method = "/grpc.testing.TestService/UnaryCall"
   865  		req, resp, e = te.doUnaryCall(cc)
   866  		reqs = []proto.Message{req}
   867  		resps = []proto.Message{resp}
   868  		err = e
   869  	case clientStreamRPC:
   870  		method = "/grpc.testing.TestService/StreamingInputCall"
   871  		reqs, resp, e = te.doClientStreamCall(cc)
   872  		resps = []proto.Message{resp}
   873  		err = e
   874  		isClientStream = true
   875  	case serverStreamRPC:
   876  		method = "/grpc.testing.TestService/StreamingOutputCall"
   877  		req, resps, e = te.doServerStreamCall(cc)
   878  		reqs = []proto.Message{req}
   879  		err = e
   880  		isServerStream = true
   881  	case fullDuplexStreamRPC:
   882  		method = "/grpc.testing.TestService/FullDuplexCall"
   883  		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
   884  		isClientStream = true
   885  		isServerStream = true
   886  	}
   887  	if cc.success != (err == nil) {
   888  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
   889  	}
   890  	te.cc.Close()
   891  	te.srv.GracefulStop() // Wait for the server to stop.
   892  
   893  	for {
   894  		h.mu.Lock()
   895  		if len(h.gotRPC) >= len(checkFuncs) {
   896  			h.mu.Unlock()
   897  			break
   898  		}
   899  		h.mu.Unlock()
   900  		time.Sleep(10 * time.Millisecond)
   901  	}
   902  
   903  	for {
   904  		h.mu.Lock()
   905  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
   906  			h.mu.Unlock()
   907  			break
   908  		}
   909  		h.mu.Unlock()
   910  		time.Sleep(10 * time.Millisecond)
   911  	}
   912  
   913  	expect := &expectedData{
   914  		serverAddr:     te.srvAddr,
   915  		compression:    tc.compress,
   916  		method:         method,
   917  		requests:       reqs,
   918  		responses:      resps,
   919  		err:            err,
   920  		isClientStream: isClientStream,
   921  		isServerStream: isServerStream,
   922  	}
   923  
   924  	h.mu.Lock()
   925  	checkConnStats(t, h.gotConn)
   926  	h.mu.Unlock()
   927  	checkServerStats(t, h.gotRPC, expect, checkFuncs)
   928  }
   929  
   930  func (s) TestServerStatsUnaryRPC(t *testing.T) {
   931  	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   932  		checkInHeader,
   933  		checkBegin,
   934  		checkInPayload,
   935  		checkOutHeader,
   936  		checkOutPayload,
   937  		checkOutTrailer,
   938  		checkEnd,
   939  	})
   940  }
   941  
   942  func (s) TestServerStatsUnaryRPCError(t *testing.T) {
   943  	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   944  		checkInHeader,
   945  		checkBegin,
   946  		checkInPayload,
   947  		checkOutHeader,
   948  		checkOutTrailer,
   949  		checkEnd,
   950  	})
   951  }
   952  
   953  func (s) TestServerStatsClientStreamRPC(t *testing.T) {
   954  	count := 5
   955  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   956  		checkInHeader,
   957  		checkBegin,
   958  		checkOutHeader,
   959  	}
   960  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   961  		checkInPayload,
   962  	}
   963  	for i := 0; i < count; i++ {
   964  		checkFuncs = append(checkFuncs, ioPayFuncs...)
   965  	}
   966  	checkFuncs = append(checkFuncs,
   967  		checkOutPayload,
   968  		checkOutTrailer,
   969  		checkEnd,
   970  	)
   971  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs)
   972  }
   973  
   974  func (s) TestServerStatsClientStreamRPCError(t *testing.T) {
   975  	count := 1
   976  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   977  		checkInHeader,
   978  		checkBegin,
   979  		checkOutHeader,
   980  		checkInPayload,
   981  		checkOutTrailer,
   982  		checkEnd,
   983  	})
   984  }
   985  
   986  func (s) TestServerStatsServerStreamRPC(t *testing.T) {
   987  	count := 5
   988  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   989  		checkInHeader,
   990  		checkBegin,
   991  		checkInPayload,
   992  		checkOutHeader,
   993  	}
   994  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   995  		checkOutPayload,
   996  	}
   997  	for i := 0; i < count; i++ {
   998  		checkFuncs = append(checkFuncs, ioPayFuncs...)
   999  	}
  1000  	checkFuncs = append(checkFuncs,
  1001  		checkOutTrailer,
  1002  		checkEnd,
  1003  	)
  1004  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs)
  1005  }
  1006  
  1007  func (s) TestServerStatsServerStreamRPCError(t *testing.T) {
  1008  	count := 5
  1009  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
  1010  		checkInHeader,
  1011  		checkBegin,
  1012  		checkInPayload,
  1013  		checkOutHeader,
  1014  		checkOutTrailer,
  1015  		checkEnd,
  1016  	})
  1017  }
  1018  
  1019  func (s) TestServerStatsFullDuplexRPC(t *testing.T) {
  1020  	count := 5
  1021  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
  1022  		checkInHeader,
  1023  		checkBegin,
  1024  		checkOutHeader,
  1025  	}
  1026  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
  1027  		checkInPayload,
  1028  		checkOutPayload,
  1029  	}
  1030  	for i := 0; i < count; i++ {
  1031  		checkFuncs = append(checkFuncs, ioPayFuncs...)
  1032  	}
  1033  	checkFuncs = append(checkFuncs,
  1034  		checkOutTrailer,
  1035  		checkEnd,
  1036  	)
  1037  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs)
  1038  }
  1039  
  1040  func (s) TestServerStatsFullDuplexRPCError(t *testing.T) {
  1041  	count := 5
  1042  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
  1043  		checkInHeader,
  1044  		checkBegin,
  1045  		checkOutHeader,
  1046  		checkInPayload,
  1047  		checkOutTrailer,
  1048  		checkEnd,
  1049  	})
  1050  }
  1051  
// checkFuncWithCount pairs a stats check with the number of recorded events of
// the corresponding type that a client-side test expects. checkClientStats sums
// c across all entries to get the expected total event count, and decrements c
// each time it applies f to a matching event. Tests construct values with
// positional literals such as {checkBegin, 1}, so field order matters.
type checkFuncWithCount struct {
	// f validates a single recorded event against the expected data.
	f func(t *testing.T, d *gotData, e *expectedData)
	c int // expected count
}
  1056  
  1057  func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) {
  1058  	var expectLen int
  1059  	for _, v := range checkFuncs {
  1060  		expectLen += v.c
  1061  	}
  1062  	if len(got) != expectLen {
  1063  		for i, g := range got {
  1064  			t.Errorf(" - %v, %T", i, g.s)
  1065  		}
  1066  		t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
  1067  	}
  1068  
  1069  	var tagInfoInCtx *stats.RPCTagInfo
  1070  	for i := 0; i < len(got); i++ {
  1071  		if _, ok := got[i].s.(stats.RPCStats); ok {
  1072  			tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo)
  1073  			if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew {
  1074  				t.Fatalf("got context containing different tagInfo with stats %T", got[i].s)
  1075  			}
  1076  			tagInfoInCtx = tagInfoInCtxNew
  1077  		}
  1078  	}
  1079  
  1080  	for _, s := range got {
  1081  		switch s.s.(type) {
  1082  		case *stats.Begin:
  1083  			if checkFuncs[begin].c <= 0 {
  1084  				t.Fatalf("unexpected stats: %T", s.s)
  1085  			}
  1086  			checkFuncs[begin].f(t, s, expect)
  1087  			checkFuncs[begin].c--
  1088  		case *stats.OutHeader:
  1089  			if checkFuncs[outHeader].c <= 0 {
  1090  				t.Fatalf("unexpected stats: %T", s.s)
  1091  			}
  1092  			checkFuncs[outHeader].f(t, s, expect)
  1093  			checkFuncs[outHeader].c--
  1094  		case *stats.OutPayload:
  1095  			if checkFuncs[outPayload].c <= 0 {
  1096  				t.Fatalf("unexpected stats: %T", s.s)
  1097  			}
  1098  			checkFuncs[outPayload].f(t, s, expect)
  1099  			checkFuncs[outPayload].c--
  1100  		case *stats.InHeader:
  1101  			if checkFuncs[inHeader].c <= 0 {
  1102  				t.Fatalf("unexpected stats: %T", s.s)
  1103  			}
  1104  			checkFuncs[inHeader].f(t, s, expect)
  1105  			checkFuncs[inHeader].c--
  1106  		case *stats.InPayload:
  1107  			if checkFuncs[inPayload].c <= 0 {
  1108  				t.Fatalf("unexpected stats: %T", s.s)
  1109  			}
  1110  			checkFuncs[inPayload].f(t, s, expect)
  1111  			checkFuncs[inPayload].c--
  1112  		case *stats.InTrailer:
  1113  			if checkFuncs[inTrailer].c <= 0 {
  1114  				t.Fatalf("unexpected stats: %T", s.s)
  1115  			}
  1116  			checkFuncs[inTrailer].f(t, s, expect)
  1117  			checkFuncs[inTrailer].c--
  1118  		case *stats.End:
  1119  			if checkFuncs[end].c <= 0 {
  1120  				t.Fatalf("unexpected stats: %T", s.s)
  1121  			}
  1122  			checkFuncs[end].f(t, s, expect)
  1123  			checkFuncs[end].c--
  1124  		case *stats.ConnBegin:
  1125  			if checkFuncs[connBegin].c <= 0 {
  1126  				t.Fatalf("unexpected stats: %T", s.s)
  1127  			}
  1128  			checkFuncs[connBegin].f(t, s, expect)
  1129  			checkFuncs[connBegin].c--
  1130  		case *stats.ConnEnd:
  1131  			if checkFuncs[connEnd].c <= 0 {
  1132  				t.Fatalf("unexpected stats: %T", s.s)
  1133  			}
  1134  			checkFuncs[connEnd].f(t, s, expect)
  1135  			checkFuncs[connEnd].c--
  1136  		default:
  1137  			t.Fatalf("unexpected stats: %T", s.s)
  1138  		}
  1139  	}
  1140  }
  1141  
  1142  func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
  1143  	h := &statshandler{}
  1144  	te := newTest(t, tc, []stats.Handler{h}, nil)
  1145  	te.startServer(&testServer{})
  1146  	defer te.tearDown()
  1147  
  1148  	var (
  1149  		reqs   []proto.Message
  1150  		resps  []proto.Message
  1151  		method string
  1152  		err    error
  1153  
  1154  		isClientStream bool
  1155  		isServerStream bool
  1156  
  1157  		req  proto.Message
  1158  		resp proto.Message
  1159  		e    error
  1160  	)
  1161  	switch cc.callType {
  1162  	case unaryRPC:
  1163  		method = "/grpc.testing.TestService/UnaryCall"
  1164  		req, resp, e = te.doUnaryCall(cc)
  1165  		reqs = []proto.Message{req}
  1166  		resps = []proto.Message{resp}
  1167  		err = e
  1168  	case clientStreamRPC:
  1169  		method = "/grpc.testing.TestService/StreamingInputCall"
  1170  		reqs, resp, e = te.doClientStreamCall(cc)
  1171  		resps = []proto.Message{resp}
  1172  		err = e
  1173  		isClientStream = true
  1174  	case serverStreamRPC:
  1175  		method = "/grpc.testing.TestService/StreamingOutputCall"
  1176  		req, resps, e = te.doServerStreamCall(cc)
  1177  		reqs = []proto.Message{req}
  1178  		err = e
  1179  		isServerStream = true
  1180  	case fullDuplexStreamRPC:
  1181  		method = "/grpc.testing.TestService/FullDuplexCall"
  1182  		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
  1183  		isClientStream = true
  1184  		isServerStream = true
  1185  	}
  1186  	if cc.success != (err == nil) {
  1187  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
  1188  	}
  1189  	te.cc.Close()
  1190  	te.srv.GracefulStop() // Wait for the server to stop.
  1191  
  1192  	lenRPCStats := 0
  1193  	for _, v := range checkFuncs {
  1194  		lenRPCStats += v.c
  1195  	}
  1196  	for {
  1197  		h.mu.Lock()
  1198  		if len(h.gotRPC) >= lenRPCStats {
  1199  			h.mu.Unlock()
  1200  			break
  1201  		}
  1202  		h.mu.Unlock()
  1203  		time.Sleep(10 * time.Millisecond)
  1204  	}
  1205  
  1206  	for {
  1207  		h.mu.Lock()
  1208  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
  1209  			h.mu.Unlock()
  1210  			break
  1211  		}
  1212  		h.mu.Unlock()
  1213  		time.Sleep(10 * time.Millisecond)
  1214  	}
  1215  
  1216  	expect := &expectedData{
  1217  		serverAddr:     te.srvAddr,
  1218  		compression:    tc.compress,
  1219  		method:         method,
  1220  		requests:       reqs,
  1221  		responses:      resps,
  1222  		failfast:       cc.failfast,
  1223  		err:            err,
  1224  		isClientStream: isClientStream,
  1225  		isServerStream: isServerStream,
  1226  	}
  1227  
  1228  	h.mu.Lock()
  1229  	checkConnStats(t, h.gotConn)
  1230  	h.mu.Unlock()
  1231  	checkClientStats(t, h.gotRPC, expect, checkFuncs)
  1232  }
  1233  
  1234  func (s) TestClientStatsUnaryRPC(t *testing.T) {
  1235  	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
  1236  		begin:      {checkBegin, 1},
  1237  		outHeader:  {checkOutHeader, 1},
  1238  		outPayload: {checkOutPayload, 1},
  1239  		inHeader:   {checkInHeader, 1},
  1240  		inPayload:  {checkInPayload, 1},
  1241  		inTrailer:  {checkInTrailer, 1},
  1242  		end:        {checkEnd, 1},
  1243  	})
  1244  }
  1245  
  1246  func (s) TestClientStatsUnaryRPCError(t *testing.T) {
  1247  	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
  1248  		begin:      {checkBegin, 1},
  1249  		outHeader:  {checkOutHeader, 1},
  1250  		outPayload: {checkOutPayload, 1},
  1251  		inHeader:   {checkInHeader, 1},
  1252  		inTrailer:  {checkInTrailer, 1},
  1253  		end:        {checkEnd, 1},
  1254  	})
  1255  }
  1256  
  1257  func (s) TestClientStatsClientStreamRPC(t *testing.T) {
  1258  	count := 5
  1259  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
  1260  		begin:      {checkBegin, 1},
  1261  		outHeader:  {checkOutHeader, 1},
  1262  		inHeader:   {checkInHeader, 1},
  1263  		outPayload: {checkOutPayload, count},
  1264  		inTrailer:  {checkInTrailer, 1},
  1265  		inPayload:  {checkInPayload, 1},
  1266  		end:        {checkEnd, 1},
  1267  	})
  1268  }
  1269  
  1270  func (s) TestClientStatsClientStreamRPCError(t *testing.T) {
  1271  	count := 1
  1272  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
  1273  		begin:      {checkBegin, 1},
  1274  		outHeader:  {checkOutHeader, 1},
  1275  		inHeader:   {checkInHeader, 1},
  1276  		outPayload: {checkOutPayload, 1},
  1277  		inTrailer:  {checkInTrailer, 1},
  1278  		end:        {checkEnd, 1},
  1279  	})
  1280  }
  1281  
  1282  func (s) TestClientStatsServerStreamRPC(t *testing.T) {
  1283  	count := 5
  1284  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
  1285  		begin:      {checkBegin, 1},
  1286  		outHeader:  {checkOutHeader, 1},
  1287  		outPayload: {checkOutPayload, 1},
  1288  		inHeader:   {checkInHeader, 1},
  1289  		inPayload:  {checkInPayload, count},
  1290  		inTrailer:  {checkInTrailer, 1},
  1291  		end:        {checkEnd, 1},
  1292  	})
  1293  }
  1294  
  1295  func (s) TestClientStatsServerStreamRPCError(t *testing.T) {
  1296  	count := 5
  1297  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
  1298  		begin:      {checkBegin, 1},
  1299  		outHeader:  {checkOutHeader, 1},
  1300  		outPayload: {checkOutPayload, 1},
  1301  		inHeader:   {checkInHeader, 1},
  1302  		inTrailer:  {checkInTrailer, 1},
  1303  		end:        {checkEnd, 1},
  1304  	})
  1305  }
  1306  
  1307  func (s) TestClientStatsFullDuplexRPC(t *testing.T) {
  1308  	count := 5
  1309  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
  1310  		begin:      {checkBegin, 1},
  1311  		outHeader:  {checkOutHeader, 1},
  1312  		outPayload: {checkOutPayload, count},
  1313  		inHeader:   {checkInHeader, 1},
  1314  		inPayload:  {checkInPayload, count},
  1315  		inTrailer:  {checkInTrailer, 1},
  1316  		end:        {checkEnd, 1},
  1317  	})
  1318  }
  1319  
  1320  func (s) TestClientStatsFullDuplexRPCError(t *testing.T) {
  1321  	count := 5
  1322  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
  1323  		begin:      {checkBegin, 1},
  1324  		outHeader:  {checkOutHeader, 1},
  1325  		outPayload: {checkOutPayload, 1},
  1326  		inHeader:   {checkInHeader, 1},
  1327  		inTrailer:  {checkInTrailer, 1},
  1328  		end:        {checkEnd, 1},
  1329  	})
  1330  }
  1331  
  1332  func (s) TestTags(t *testing.T) {
  1333  	b := []byte{5, 2, 4, 3, 1}
  1334  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1335  	defer cancel()
  1336  	ctx := stats.SetTags(tCtx, b)
  1337  	if tg := getOutgoingStats(ctx, "grpc-tags-bin"); !reflect.DeepEqual(tg, b) {
  1338  		t.Errorf("getOutgoingStats(%v, grpc-tags-bin) = %v; want %v", ctx, tg, b)
  1339  	}
  1340  	if tg := stats.Tags(ctx); tg != nil {
  1341  		t.Errorf("Tags(%v) = %v; want nil", ctx, tg)
  1342  	}
  1343  
  1344  	ctx = setIncomingStats(tCtx, "grpc-tags-bin", b)
  1345  	if tg := stats.Tags(ctx); !reflect.DeepEqual(tg, b) {
  1346  		t.Errorf("Tags(%v) = %v; want %v", ctx, tg, b)
  1347  	}
  1348  	if tg := getOutgoingStats(ctx, "grpc-tags-bin"); tg != nil {
  1349  		t.Errorf("getOutgoingStats(%v, grpc-tags-bin) = %v; want nil", ctx, tg)
  1350  	}
  1351  }
  1352  
  1353  func (s) TestTrace(t *testing.T) {
  1354  	b := []byte{5, 2, 4, 3, 1}
  1355  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1356  	defer cancel()
  1357  	ctx := stats.SetTrace(tCtx, b)
  1358  	if tr := getOutgoingStats(ctx, "grpc-trace-bin"); !reflect.DeepEqual(tr, b) {
  1359  		t.Errorf("getOutgoingStats(%v, grpc-trace-bin) = %v; want %v", ctx, tr, b)
  1360  	}
  1361  	if tr := stats.Trace(ctx); tr != nil {
  1362  		t.Errorf("Trace(%v) = %v; want nil", ctx, tr)
  1363  	}
  1364  
  1365  	ctx = setIncomingStats(tCtx, "grpc-trace-bin", b)
  1366  	if tr := stats.Trace(ctx); !reflect.DeepEqual(tr, b) {
  1367  		t.Errorf("Trace(%v) = %v; want %v", ctx, tr, b)
  1368  	}
  1369  	if tr := getOutgoingStats(ctx, "grpc-trace-bin"); tr != nil {
  1370  		t.Errorf("getOutgoingStats(%v, grpc-trace-bin) = %v; want nil", ctx, tr)
  1371  	}
  1372  }
  1373  
  1374  func (s) TestMultipleClientStatsHandler(t *testing.T) {
  1375  	h := &statshandler{}
  1376  	tc := &testConfig{compress: ""}
  1377  	te := newTest(t, tc, []stats.Handler{h, h}, nil)
  1378  	te.startServer(&testServer{})
  1379  	defer te.tearDown()
  1380  
  1381  	cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC}
  1382  	_, _, err := te.doUnaryCall(cc)
  1383  	if cc.success != (err == nil) {
  1384  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
  1385  	}
  1386  	te.cc.Close()
  1387  	te.srv.GracefulStop() // Wait for the server to stop.
  1388  
  1389  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1390  		h.mu.Lock()
  1391  		if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok && len(h.gotRPC) == 12 {
  1392  			h.mu.Unlock()
  1393  			break
  1394  		}
  1395  		h.mu.Unlock()
  1396  		time.Sleep(10 * time.Millisecond)
  1397  	}
  1398  
  1399  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1400  		h.mu.Lock()
  1401  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok && len(h.gotConn) == 4 {
  1402  			h.mu.Unlock()
  1403  			break
  1404  		}
  1405  		h.mu.Unlock()
  1406  		time.Sleep(10 * time.Millisecond)
  1407  	}
  1408  
  1409  	// Each RPC generates 6 stats events on the client-side, times 2 StatsHandler
  1410  	if len(h.gotRPC) != 12 {
  1411  		t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12)
  1412  	}
  1413  
  1414  	// Each connection generates 4 conn events on the client-side, times 2 StatsHandler
  1415  	if len(h.gotConn) != 4 {
  1416  		t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4)
  1417  	}
  1418  }
  1419  
  1420  func (s) TestMultipleServerStatsHandler(t *testing.T) {
  1421  	h := &statshandler{}
  1422  	tc := &testConfig{compress: ""}
  1423  	te := newTest(t, tc, nil, []stats.Handler{h, h})
  1424  	te.startServer(&testServer{})
  1425  	defer te.tearDown()
  1426  
  1427  	cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC}
  1428  	_, _, err := te.doUnaryCall(cc)
  1429  	if cc.success != (err == nil) {
  1430  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
  1431  	}
  1432  	te.cc.Close()
  1433  	te.srv.GracefulStop() // Wait for the server to stop.
  1434  
  1435  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1436  		h.mu.Lock()
  1437  		if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok {
  1438  			h.mu.Unlock()
  1439  			break
  1440  		}
  1441  		h.mu.Unlock()
  1442  		time.Sleep(10 * time.Millisecond)
  1443  	}
  1444  
  1445  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1446  		h.mu.Lock()
  1447  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
  1448  			h.mu.Unlock()
  1449  			break
  1450  		}
  1451  		h.mu.Unlock()
  1452  		time.Sleep(10 * time.Millisecond)
  1453  	}
  1454  
  1455  	// Each RPC generates 6 stats events on the server-side, times 2 StatsHandler
  1456  	if len(h.gotRPC) != 12 {
  1457  		t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12)
  1458  	}
  1459  
  1460  	// Each connection generates 4 conn events on the server-side, times 2 StatsHandler
  1461  	if len(h.gotConn) != 4 {
  1462  		t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4)
  1463  	}
  1464  }
  1465  
  1466  // TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler
  1467  // gets access to a Server on the server side, and thus the method that the
  1468  // server owns which specifies whether a method is made or not. The test sets up
  1469  // a server with a unary call and full duplex call configured, and makes an RPC.
  1470  // Within the stats handler, asking the server whether unary or duplex method
  1471  // names are registered should return true, and any other query should return
  1472  // false.
  1473  func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) {
  1474  	wg := sync.WaitGroup{}
  1475  	wg.Add(1)
  1476  	stubStatsHandler := &testutils.StubStatsHandler{
  1477  		TagRPCF: func(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
  1478  			// OpenTelemetry instrumentation needs the passed in Server to determine if
  1479  			// methods are registered in different handle calls in to record metrics.
  1480  			// This tag RPC call context gets passed into every handle call, so can
  1481  			// assert once here, since it maps to all the handle RPC calls that come
  1482  			// after. These internal calls will be how the OpenTelemetry instrumentation
  1483  			// component accesses this server and the subsequent helper on the server.
  1484  			server := internal.ServerFromContext.(func(context.Context) *grpc.Server)(ctx)
  1485  			if server == nil {
  1486  				t.Errorf("stats handler received ctx has no server present")
  1487  			}
  1488  			isRegisteredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)
  1489  			// /s/m and s/m are valid.
  1490  			if !isRegisteredMethod(server, "/grpc.testing.TestService/UnaryCall") {
  1491  				t.Errorf("UnaryCall should be a registered method according to server")
  1492  			}
  1493  			if !isRegisteredMethod(server, "grpc.testing.TestService/FullDuplexCall") {
  1494  				t.Errorf("FullDuplexCall should be a registered method according to server")
  1495  			}
  1496  			if isRegisteredMethod(server, "/grpc.testing.TestService/DoesNotExistCall") {
  1497  				t.Errorf("DoesNotExistCall should not be a registered method according to server")
  1498  			}
  1499  			if isRegisteredMethod(server, "/unknownService/UnaryCall") {
  1500  				t.Errorf("/unknownService/UnaryCall should not be a registered method according to server")
  1501  			}
  1502  			wg.Done()
  1503  			return ctx
  1504  		},
  1505  	}
  1506  	ss := &stubserver.StubServer{
  1507  		UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
  1508  			return &testpb.SimpleResponse{}, nil
  1509  		},
  1510  	}
  1511  	if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(stubStatsHandler)}); err != nil {
  1512  		t.Fatalf("Error starting endpoint server: %v", err)
  1513  	}
  1514  	defer ss.Stop()
  1515  
  1516  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1517  	defer cancel()
  1518  	if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil {
  1519  		t.Fatalf("Unexpected error from UnaryCall: %v", err)
  1520  	}
  1521  	wg.Wait()
  1522  }