google.golang.org/grpc@v1.72.2/stats/stats_test.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package stats_test
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"io"
    25  	"net"
    26  	"reflect"
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/google/go-cmp/cmp"
    32  	"google.golang.org/grpc"
    33  	"google.golang.org/grpc/credentials/insecure"
    34  	"google.golang.org/grpc/internal"
    35  	"google.golang.org/grpc/internal/grpctest"
    36  	"google.golang.org/grpc/internal/stubserver"
    37  	"google.golang.org/grpc/internal/testutils"
    38  	"google.golang.org/grpc/metadata"
    39  	"google.golang.org/grpc/stats"
    40  	"google.golang.org/grpc/status"
    41  	"google.golang.org/protobuf/proto"
    42  	"google.golang.org/protobuf/testing/protocmp"
    43  
    44  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    45  	testpb "google.golang.org/grpc/interop/grpc_testing"
    46  )
    47  
    48  const defaultTestTimeout = 10 * time.Second
    49  
    50  type s struct {
    51  	grpctest.Tester
    52  }
    53  
    54  func Test(t *testing.T) {
    55  	grpctest.RunSubTests(t, s{})
    56  }
    57  
    58  func init() {
    59  	grpc.EnableTracing = false
    60  }
    61  
    62  type connCtxKey struct{}
    63  type rpcCtxKey struct{}
    64  
    65  var (
    66  	// For headers sent to server:
    67  	testMetadata = metadata.MD{
    68  		"key1":       []string{"value1"},
    69  		"key2":       []string{"value2"},
    70  		"user-agent": []string{fmt.Sprintf("test/0.0.1 grpc-go/%s", grpc.Version)},
    71  	}
    72  	// For headers sent from server:
    73  	testHeaderMetadata = metadata.MD{
    74  		"hkey1": []string{"headerValue1"},
    75  		"hkey2": []string{"headerValue2"},
    76  	}
    77  	// For trailers sent from server:
    78  	testTrailerMetadata = metadata.MD{
    79  		"tkey1": []string{"trailerValue1"},
    80  		"tkey2": []string{"trailerValue2"},
    81  	}
    82  	// The id for which the service handler should return error.
    83  	errorID int32 = 32202
    84  )
    85  
    86  func idToPayload(id int32) *testpb.Payload {
    87  	return &testpb.Payload{Body: []byte{byte(id), byte(id >> 8), byte(id >> 16), byte(id >> 24)}}
    88  }
    89  
    90  func payloadToID(p *testpb.Payload) int32 {
    91  	if p == nil || len(p.Body) != 4 {
    92  		panic("invalid payload")
    93  	}
    94  	return int32(p.Body[0]) + int32(p.Body[1])<<8 + int32(p.Body[2])<<16 + int32(p.Body[3])<<24
    95  }
    96  
    97  func setIncomingStats(ctx context.Context, mdKey string, b []byte) context.Context {
    98  	md, ok := metadata.FromIncomingContext(ctx)
    99  	if !ok {
   100  		md = metadata.MD{}
   101  	}
   102  	md.Set(mdKey, string(b))
   103  	return metadata.NewIncomingContext(ctx, md)
   104  }
   105  
   106  func getOutgoingStats(ctx context.Context, mdKey string) []byte {
   107  	md, ok := metadata.FromOutgoingContext(ctx)
   108  	if !ok {
   109  		return nil
   110  	}
   111  	tagValues := md.Get(mdKey)
   112  	if len(tagValues) == 0 {
   113  		return nil
   114  	}
   115  	return []byte(tagValues[len(tagValues)-1])
   116  }
   117  
   118  type testServer struct {
   119  	testgrpc.UnimplementedTestServiceServer
   120  }
   121  
   122  func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   123  	if err := grpc.SendHeader(ctx, testHeaderMetadata); err != nil {
   124  		return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", testHeaderMetadata, err)
   125  	}
   126  	if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
   127  		return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
   128  	}
   129  
   130  	if id := payloadToID(in.Payload); id == errorID {
   131  		return nil, fmt.Errorf("got error id: %v", id)
   132  	}
   133  
   134  	return &testpb.SimpleResponse{Payload: in.Payload}, nil
   135  }
   136  
   137  func (s *testServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallServer) error {
   138  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   139  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   140  	}
   141  	stream.SetTrailer(testTrailerMetadata)
   142  	for {
   143  		in, err := stream.Recv()
   144  		if err == io.EOF {
   145  			// read done.
   146  			return nil
   147  		}
   148  		if err != nil {
   149  			return err
   150  		}
   151  
   152  		if id := payloadToID(in.Payload); id == errorID {
   153  			return fmt.Errorf("got error id: %v", id)
   154  		}
   155  
   156  		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
   157  			return err
   158  		}
   159  	}
   160  }
   161  
   162  func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error {
   163  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   164  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   165  	}
   166  	stream.SetTrailer(testTrailerMetadata)
   167  	for {
   168  		in, err := stream.Recv()
   169  		if err == io.EOF {
   170  			// read done.
   171  			return stream.SendAndClose(&testpb.StreamingInputCallResponse{AggregatedPayloadSize: 0})
   172  		}
   173  		if err != nil {
   174  			return err
   175  		}
   176  
   177  		if id := payloadToID(in.Payload); id == errorID {
   178  			return fmt.Errorf("got error id: %v", id)
   179  		}
   180  	}
   181  }
   182  
   183  func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error {
   184  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   185  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   186  	}
   187  	stream.SetTrailer(testTrailerMetadata)
   188  
   189  	if id := payloadToID(in.Payload); id == errorID {
   190  		return fmt.Errorf("got error id: %v", id)
   191  	}
   192  
   193  	for i := 0; i < 5; i++ {
   194  		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
   195  			return err
   196  		}
   197  	}
   198  	return nil
   199  }
   200  
   201  // test is an end-to-end test. It should be created with the newTest
   202  // func, modified as needed, and then started with its startServer method.
   203  // It should be cleaned up with the tearDown method.
   204  type test struct {
   205  	t                   *testing.T
   206  	compress            string
   207  	clientStatsHandlers []stats.Handler
   208  	serverStatsHandlers []stats.Handler
   209  
   210  	testServer testgrpc.TestServiceServer // nil means none
   211  	// srv and srvAddr are set once startServer is called.
   212  	srv     *grpc.Server
   213  	srvAddr string
   214  
   215  	cc *grpc.ClientConn // nil until requested via clientConn
   216  }
   217  
   218  func (te *test) tearDown() {
   219  	if te.cc != nil {
   220  		te.cc.Close()
   221  		te.cc = nil
   222  	}
   223  	te.srv.Stop()
   224  }
   225  
   226  type testConfig struct {
   227  	compress string
   228  }
   229  
   230  // newTest returns a new test using the provided testing.T and
   231  // environment.  It is returned with default values. Tests should
   232  // modify it before calling its startServer and clientConn methods.
   233  func newTest(t *testing.T, tc *testConfig, chs []stats.Handler, shs []stats.Handler) *test {
   234  	te := &test{
   235  		t:                   t,
   236  		compress:            tc.compress,
   237  		clientStatsHandlers: chs,
   238  		serverStatsHandlers: shs,
   239  	}
   240  	return te
   241  }
   242  
   243  // startServer starts a gRPC server listening. Callers should defer a
   244  // call to te.tearDown to clean up.
   245  func (te *test) startServer(ts testgrpc.TestServiceServer) {
   246  	te.testServer = ts
   247  	lis, err := net.Listen("tcp", "localhost:0")
   248  	if err != nil {
   249  		te.t.Fatalf("Failed to listen: %v", err)
   250  	}
   251  	var opts []grpc.ServerOption
   252  	if te.compress == "gzip" {
   253  		opts = append(opts,
   254  			grpc.RPCCompressor(grpc.NewGZIPCompressor()),
   255  			grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
   256  		)
   257  	}
   258  	for _, sh := range te.serverStatsHandlers {
   259  		opts = append(opts, grpc.StatsHandler(sh))
   260  	}
   261  	s := grpc.NewServer(opts...)
   262  	te.srv = s
   263  	if te.testServer != nil {
   264  		testgrpc.RegisterTestServiceServer(s, te.testServer)
   265  	}
   266  
   267  	go s.Serve(lis)
   268  	te.srvAddr = lis.Addr().String()
   269  }
   270  
   271  func (te *test) clientConn() *grpc.ClientConn {
   272  	if te.cc != nil {
   273  		return te.cc
   274  	}
   275  	opts := []grpc.DialOption{
   276  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   277  		grpc.WithBlock(),
   278  		grpc.WithUserAgent("test/0.0.1"),
   279  	}
   280  	if te.compress == "gzip" {
   281  		opts = append(opts,
   282  			grpc.WithCompressor(grpc.NewGZIPCompressor()),
   283  			grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
   284  		)
   285  	}
   286  	for _, sh := range te.clientStatsHandlers {
   287  		opts = append(opts, grpc.WithStatsHandler(sh))
   288  	}
   289  
   290  	var err error
   291  	te.cc, err = grpc.Dial(te.srvAddr, opts...)
   292  	if err != nil {
   293  		te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err)
   294  	}
   295  	return te.cc
   296  }
   297  
   298  type rpcType int
   299  
   300  const (
   301  	unaryRPC rpcType = iota
   302  	clientStreamRPC
   303  	serverStreamRPC
   304  	fullDuplexStreamRPC
   305  )
   306  
   307  type rpcConfig struct {
   308  	count    int  // Number of requests and responses for streaming RPCs.
   309  	success  bool // Whether the RPC should succeed or return error.
   310  	failfast bool
   311  	callType rpcType // Type of RPC.
   312  }
   313  
   314  func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
   315  	var (
   316  		resp *testpb.SimpleResponse
   317  		req  *testpb.SimpleRequest
   318  		err  error
   319  	)
   320  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   321  	if c.success {
   322  		req = &testpb.SimpleRequest{Payload: idToPayload(errorID + 1)}
   323  	} else {
   324  		req = &testpb.SimpleRequest{Payload: idToPayload(errorID)}
   325  	}
   326  
   327  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   328  	defer cancel()
   329  	resp, err = tc.UnaryCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast))
   330  	return req, resp, err
   331  }
   332  
   333  func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]proto.Message, []proto.Message, error) {
   334  	var (
   335  		reqs  []proto.Message
   336  		resps []proto.Message
   337  		err   error
   338  	)
   339  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   340  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   341  	defer cancel()
   342  	stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast))
   343  	if err != nil {
   344  		return reqs, resps, err
   345  	}
   346  	var startID int32
   347  	if !c.success {
   348  		startID = errorID
   349  	}
   350  	for i := 0; i < c.count; i++ {
   351  		req := &testpb.StreamingOutputCallRequest{
   352  			Payload: idToPayload(int32(i) + startID),
   353  		}
   354  		reqs = append(reqs, req)
   355  		if err = stream.Send(req); err != nil {
   356  			return reqs, resps, err
   357  		}
   358  		var resp *testpb.StreamingOutputCallResponse
   359  		if resp, err = stream.Recv(); err != nil {
   360  			return reqs, resps, err
   361  		}
   362  		resps = append(resps, resp)
   363  	}
   364  	if err = stream.CloseSend(); err != nil && err != io.EOF {
   365  		return reqs, resps, err
   366  	}
   367  	if _, err = stream.Recv(); err != io.EOF {
   368  		return reqs, resps, err
   369  	}
   370  
   371  	return reqs, resps, nil
   372  }
   373  
   374  func (te *test) doClientStreamCall(c *rpcConfig) ([]proto.Message, *testpb.StreamingInputCallResponse, error) {
   375  	var (
   376  		reqs []proto.Message
   377  		resp *testpb.StreamingInputCallResponse
   378  		err  error
   379  	)
   380  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   381  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   382  	defer cancel()
   383  	stream, err := tc.StreamingInputCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast))
   384  	if err != nil {
   385  		return reqs, resp, err
   386  	}
   387  	var startID int32
   388  	if !c.success {
   389  		startID = errorID
   390  	}
   391  	for i := 0; i < c.count; i++ {
   392  		req := &testpb.StreamingInputCallRequest{
   393  			Payload: idToPayload(int32(i) + startID),
   394  		}
   395  		reqs = append(reqs, req)
   396  		if err = stream.Send(req); err != nil {
   397  			return reqs, resp, err
   398  		}
   399  	}
   400  	resp, err = stream.CloseAndRecv()
   401  	return reqs, resp, err
   402  }
   403  
   404  func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.StreamingOutputCallRequest, []proto.Message, error) {
   405  	var (
   406  		req   *testpb.StreamingOutputCallRequest
   407  		resps []proto.Message
   408  		err   error
   409  	)
   410  
   411  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   412  
   413  	var startID int32
   414  	if !c.success {
   415  		startID = errorID
   416  	}
   417  	req = &testpb.StreamingOutputCallRequest{Payload: idToPayload(startID)}
   418  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   419  	defer cancel()
   420  	stream, err := tc.StreamingOutputCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast))
   421  	if err != nil {
   422  		return req, resps, err
   423  	}
   424  	for {
   425  		var resp *testpb.StreamingOutputCallResponse
   426  		resp, err := stream.Recv()
   427  		if err == io.EOF {
   428  			return req, resps, nil
   429  		} else if err != nil {
   430  			return req, resps, err
   431  		}
   432  		resps = append(resps, resp)
   433  	}
   434  }
   435  
   436  type expectedData struct {
   437  	method         string
   438  	isClientStream bool
   439  	isServerStream bool
   440  	serverAddr     string
   441  	compression    string
   442  	reqIdx         int
   443  	requests       []proto.Message
   444  	respIdx        int
   445  	responses      []proto.Message
   446  	err            error
   447  	failfast       bool
   448  }
   449  
   450  type gotData struct {
   451  	ctx    context.Context
   452  	client bool
   453  	s      any // This could be RPCStats or ConnStats.
   454  }
   455  
   456  const (
   457  	begin int = iota
   458  	end
   459  	inPayload
   460  	inHeader
   461  	inTrailer
   462  	outPayload
   463  	outHeader
   464  	// TODO: test outTrailer ?
   465  	connBegin
   466  	connEnd
   467  )
   468  
   469  func checkBegin(t *testing.T, d *gotData, e *expectedData) {
   470  	var (
   471  		ok bool
   472  		st *stats.Begin
   473  	)
   474  	if st, ok = d.s.(*stats.Begin); !ok {
   475  		t.Fatalf("got %T, want Begin", d.s)
   476  	}
   477  	if d.ctx == nil {
   478  		t.Fatalf("d.ctx = nil, want <non-nil>")
   479  	}
   480  	if st.BeginTime.IsZero() {
   481  		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
   482  	}
   483  	if d.client {
   484  		if st.FailFast != e.failfast {
   485  			t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast)
   486  		}
   487  	}
   488  	if st.IsClientStream != e.isClientStream {
   489  		t.Fatalf("st.IsClientStream = %v, want %v", st.IsClientStream, e.isClientStream)
   490  	}
   491  	if st.IsServerStream != e.isServerStream {
   492  		t.Fatalf("st.IsServerStream = %v, want %v", st.IsServerStream, e.isServerStream)
   493  	}
   494  }
   495  
   496  func checkInHeader(t *testing.T, d *gotData, e *expectedData) {
   497  	var (
   498  		ok bool
   499  		st *stats.InHeader
   500  	)
   501  	if st, ok = d.s.(*stats.InHeader); !ok {
   502  		t.Fatalf("got %T, want InHeader", d.s)
   503  	}
   504  	if d.ctx == nil {
   505  		t.Fatalf("d.ctx = nil, want <non-nil>")
   506  	}
   507  	if st.Compression != e.compression {
   508  		t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
   509  	}
   510  	if d.client {
   511  		// additional headers might be injected so instead of testing equality, test that all the
   512  		// expected headers keys have the expected header values.
   513  		for key := range testHeaderMetadata {
   514  			if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) {
   515  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key))
   516  			}
   517  		}
   518  	} else {
   519  		if st.FullMethod != e.method {
   520  			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
   521  		}
   522  		if st.LocalAddr.String() != e.serverAddr {
   523  			t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr)
   524  		}
   525  		// additional headers might be injected so instead of testing equality, test that all the
   526  		// expected headers keys have the expected header values.
   527  		for key := range testMetadata {
   528  			if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) {
   529  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key))
   530  			}
   531  		}
   532  
   533  		if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok {
   534  			if connInfo.RemoteAddr != st.RemoteAddr {
   535  				t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr)
   536  			}
   537  			if connInfo.LocalAddr != st.LocalAddr {
   538  				t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr)
   539  			}
   540  		} else {
   541  			t.Fatalf("got context %v, want one with connCtxKey", d.ctx)
   542  		}
   543  		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
   544  			if rpcInfo.FullMethodName != st.FullMethod {
   545  				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
   546  			}
   547  		} else {
   548  			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
   549  		}
   550  	}
   551  }
   552  
   553  func checkInPayload(t *testing.T, d *gotData, e *expectedData) {
   554  	var (
   555  		ok bool
   556  		st *stats.InPayload
   557  	)
   558  	if st, ok = d.s.(*stats.InPayload); !ok {
   559  		t.Fatalf("got %T, want InPayload", d.s)
   560  	}
   561  	if d.ctx == nil {
   562  		t.Fatalf("d.ctx = nil, want <non-nil>")
   563  	}
   564  
   565  	var idx *int
   566  	var payloads []proto.Message
   567  	if d.client {
   568  		idx = &e.respIdx
   569  		payloads = e.responses
   570  	} else {
   571  		idx = &e.reqIdx
   572  		payloads = e.requests
   573  	}
   574  
   575  	wantPayload := payloads[*idx]
   576  	if diff := cmp.Diff(wantPayload, st.Payload.(proto.Message), protocmp.Transform()); diff != "" {
   577  		t.Fatalf("unexpected difference in st.Payload (-want +got):\n%s", diff)
   578  	}
   579  	*idx++
   580  	if st.Length != proto.Size(wantPayload) {
   581  		t.Fatalf("st.Length = %v, want %v", st.Length, proto.Size(wantPayload))
   582  	}
   583  
   584  	// Below are sanity checks that WireLength and RecvTime are populated.
   585  	// TODO: check values of WireLength and RecvTime.
   586  	if st.Length > 0 && st.CompressedLength == 0 {
   587  		t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
   588  			st.CompressedLength)
   589  	}
   590  	if st.RecvTime.IsZero() {
   591  		t.Fatalf("st.ReceivedTime = %v, want <non-zero>", st.RecvTime)
   592  	}
   593  }
   594  
   595  func checkInTrailer(t *testing.T, d *gotData, e *expectedData) {
   596  	var (
   597  		ok bool
   598  		st *stats.InTrailer
   599  	)
   600  	if st, ok = d.s.(*stats.InTrailer); !ok {
   601  		t.Fatalf("got %T, want InTrailer", d.s)
   602  	}
   603  	if d.ctx == nil {
   604  		t.Fatalf("d.ctx = nil, want <non-nil>")
   605  	}
   606  	if !st.Client {
   607  		t.Fatalf("st IsClient = false, want true")
   608  	}
   609  	if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   610  		t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   611  	}
   612  }
   613  
   614  func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
   615  	var (
   616  		ok bool
   617  		st *stats.OutHeader
   618  	)
   619  	if st, ok = d.s.(*stats.OutHeader); !ok {
   620  		t.Fatalf("got %T, want OutHeader", d.s)
   621  	}
   622  	if d.ctx == nil {
   623  		t.Fatalf("d.ctx = nil, want <non-nil>")
   624  	}
   625  	if st.Compression != e.compression {
   626  		t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
   627  	}
   628  	if d.client {
   629  		if st.FullMethod != e.method {
   630  			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
   631  		}
   632  		if st.RemoteAddr.String() != e.serverAddr {
   633  			t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr)
   634  		}
   635  		// additional headers might be injected so instead of testing equality, test that all the
   636  		// expected headers keys have the expected header values.
   637  		for key := range testMetadata {
   638  			if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) {
   639  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key))
   640  			}
   641  		}
   642  
   643  		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
   644  			if rpcInfo.FullMethodName != st.FullMethod {
   645  				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
   646  			}
   647  		} else {
   648  			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
   649  		}
   650  	} else {
   651  		// additional headers might be injected so instead of testing equality, test that all the
   652  		// expected headers keys have the expected header values.
   653  		for key := range testHeaderMetadata {
   654  			if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) {
   655  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key))
   656  			}
   657  		}
   658  	}
   659  }
   660  
   661  func checkOutPayload(t *testing.T, d *gotData, e *expectedData) {
   662  	var (
   663  		ok bool
   664  		st *stats.OutPayload
   665  	)
   666  	if st, ok = d.s.(*stats.OutPayload); !ok {
   667  		t.Fatalf("got %T, want OutPayload", d.s)
   668  	}
   669  	if d.ctx == nil {
   670  		t.Fatalf("d.ctx = nil, want <non-nil>")
   671  	}
   672  
   673  	var idx *int
   674  	var payloads []proto.Message
   675  	if d.client {
   676  		idx = &e.reqIdx
   677  		payloads = e.requests
   678  	} else {
   679  		idx = &e.respIdx
   680  		payloads = e.responses
   681  	}
   682  
   683  	expectedPayload := payloads[*idx]
   684  	if !proto.Equal(st.Payload.(proto.Message), expectedPayload) {
   685  		t.Fatalf("st.Payload = %v, want %v", st.Payload, expectedPayload)
   686  	}
   687  	*idx++
   688  	if st.Length != proto.Size(expectedPayload) {
   689  		t.Fatalf("st.Length = %v, want %v", st.Length, proto.Size(expectedPayload))
   690  	}
   691  
   692  	// Below are sanity checks that Length, CompressedLength and SentTime are populated.
   693  	// TODO: check values of WireLength and SentTime.
   694  	if st.Length > 0 && st.WireLength == 0 {
   695  		t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
   696  			st.WireLength)
   697  	}
   698  	if st.SentTime.IsZero() {
   699  		t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime)
   700  	}
   701  }
   702  
   703  func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) {
   704  	var (
   705  		ok bool
   706  		st *stats.OutTrailer
   707  	)
   708  	if st, ok = d.s.(*stats.OutTrailer); !ok {
   709  		t.Fatalf("got %T, want OutTrailer", d.s)
   710  	}
   711  	if d.ctx == nil {
   712  		t.Fatalf("d.ctx = nil, want <non-nil>")
   713  	}
   714  	if st.Client {
   715  		t.Fatalf("st IsClient = true, want false")
   716  	}
   717  	if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   718  		t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   719  	}
   720  }
   721  
   722  func checkEnd(t *testing.T, d *gotData, e *expectedData) {
   723  	var (
   724  		ok bool
   725  		st *stats.End
   726  	)
   727  	if st, ok = d.s.(*stats.End); !ok {
   728  		t.Fatalf("got %T, want End", d.s)
   729  	}
   730  	if d.ctx == nil {
   731  		t.Fatalf("d.ctx = nil, want <non-nil>")
   732  	}
   733  	if st.BeginTime.IsZero() {
   734  		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
   735  	}
   736  	if st.EndTime.IsZero() {
   737  		t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime)
   738  	}
   739  
   740  	actual, ok := status.FromError(st.Error)
   741  	if !ok {
   742  		t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error)
   743  	}
   744  
   745  	expectedStatus, _ := status.FromError(e.err)
   746  	if actual.Code() != expectedStatus.Code() || actual.Message() != expectedStatus.Message() {
   747  		t.Fatalf("st.Error = %v, want %v", st.Error, e.err)
   748  	}
   749  
   750  	if st.Client {
   751  		if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   752  			t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   753  		}
   754  	} else {
   755  		if st.Trailer != nil {
   756  			t.Fatalf("st.Trailer = %v, want nil", st.Trailer)
   757  		}
   758  	}
   759  }
   760  
   761  func checkConnBegin(t *testing.T, d *gotData) {
   762  	var (
   763  		ok bool
   764  		st *stats.ConnBegin
   765  	)
   766  	if st, ok = d.s.(*stats.ConnBegin); !ok {
   767  		t.Fatalf("got %T, want ConnBegin", d.s)
   768  	}
   769  	if d.ctx == nil {
   770  		t.Fatalf("d.ctx = nil, want <non-nil>")
   771  	}
   772  	st.IsClient() // TODO remove this.
   773  }
   774  
   775  func checkConnEnd(t *testing.T, d *gotData) {
   776  	var (
   777  		ok bool
   778  		st *stats.ConnEnd
   779  	)
   780  	if st, ok = d.s.(*stats.ConnEnd); !ok {
   781  		t.Fatalf("got %T, want ConnEnd", d.s)
   782  	}
   783  	if d.ctx == nil {
   784  		t.Fatalf("d.ctx = nil, want <non-nil>")
   785  	}
   786  	st.IsClient() // TODO remove this.
   787  }
   788  
   789  type statshandler struct {
   790  	mu      sync.Mutex
   791  	gotRPC  []*gotData
   792  	gotConn []*gotData
   793  }
   794  
   795  func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
   796  	return context.WithValue(ctx, connCtxKey{}, info)
   797  }
   798  
   799  func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
   800  	return context.WithValue(ctx, rpcCtxKey{}, info)
   801  }
   802  
   803  func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) {
   804  	h.mu.Lock()
   805  	defer h.mu.Unlock()
   806  	h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s})
   807  }
   808  
   809  func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
   810  	h.mu.Lock()
   811  	defer h.mu.Unlock()
   812  	h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s})
   813  }
   814  
   815  func checkConnStats(t *testing.T, got []*gotData) {
   816  	if len(got) <= 0 || len(got)%2 != 0 {
   817  		for i, g := range got {
   818  			t.Errorf(" - %v, %T = %+v, ctx: %v", i, g.s, g.s, g.ctx)
   819  		}
   820  		t.Fatalf("got %v stats, want even positive number", len(got))
   821  	}
   822  	// The first conn stats must be a ConnBegin.
   823  	checkConnBegin(t, got[0])
   824  	// The last conn stats must be a ConnEnd.
   825  	checkConnEnd(t, got[len(got)-1])
   826  }
   827  
   828  func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
   829  	if len(got) != len(checkFuncs) {
   830  		for i, g := range got {
   831  			t.Errorf(" - %v, %T", i, g.s)
   832  		}
   833  		t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
   834  	}
   835  
   836  	for i, f := range checkFuncs {
   837  		f(t, got[i], expect)
   838  	}
   839  }
   840  
   841  func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
   842  	h := &statshandler{}
   843  	te := newTest(t, tc, nil, []stats.Handler{h})
   844  	te.startServer(&testServer{})
   845  	defer te.tearDown()
   846  
   847  	var (
   848  		reqs   []proto.Message
   849  		resps  []proto.Message
   850  		err    error
   851  		method string
   852  
   853  		isClientStream bool
   854  		isServerStream bool
   855  
   856  		req  proto.Message
   857  		resp proto.Message
   858  		e    error
   859  	)
   860  
   861  	switch cc.callType {
   862  	case unaryRPC:
   863  		method = "/grpc.testing.TestService/UnaryCall"
   864  		req, resp, e = te.doUnaryCall(cc)
   865  		reqs = []proto.Message{req}
   866  		resps = []proto.Message{resp}
   867  		err = e
   868  	case clientStreamRPC:
   869  		method = "/grpc.testing.TestService/StreamingInputCall"
   870  		reqs, resp, e = te.doClientStreamCall(cc)
   871  		resps = []proto.Message{resp}
   872  		err = e
   873  		isClientStream = true
   874  	case serverStreamRPC:
   875  		method = "/grpc.testing.TestService/StreamingOutputCall"
   876  		req, resps, e = te.doServerStreamCall(cc)
   877  		reqs = []proto.Message{req}
   878  		err = e
   879  		isServerStream = true
   880  	case fullDuplexStreamRPC:
   881  		method = "/grpc.testing.TestService/FullDuplexCall"
   882  		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
   883  		isClientStream = true
   884  		isServerStream = true
   885  	}
   886  	if cc.success != (err == nil) {
   887  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
   888  	}
   889  	te.cc.Close()
   890  	te.srv.GracefulStop() // Wait for the server to stop.
   891  
   892  	for {
   893  		h.mu.Lock()
   894  		if len(h.gotRPC) >= len(checkFuncs) {
   895  			h.mu.Unlock()
   896  			break
   897  		}
   898  		h.mu.Unlock()
   899  		time.Sleep(10 * time.Millisecond)
   900  	}
   901  
   902  	for {
   903  		h.mu.Lock()
   904  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
   905  			h.mu.Unlock()
   906  			break
   907  		}
   908  		h.mu.Unlock()
   909  		time.Sleep(10 * time.Millisecond)
   910  	}
   911  
   912  	expect := &expectedData{
   913  		serverAddr:     te.srvAddr,
   914  		compression:    tc.compress,
   915  		method:         method,
   916  		requests:       reqs,
   917  		responses:      resps,
   918  		err:            err,
   919  		isClientStream: isClientStream,
   920  		isServerStream: isServerStream,
   921  	}
   922  
   923  	h.mu.Lock()
   924  	checkConnStats(t, h.gotConn)
   925  	h.mu.Unlock()
   926  	checkServerStats(t, h.gotRPC, expect, checkFuncs)
   927  }
   928  
   929  func (s) TestServerStatsUnaryRPC(t *testing.T) {
   930  	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   931  		checkInHeader,
   932  		checkBegin,
   933  		checkInPayload,
   934  		checkOutHeader,
   935  		checkOutPayload,
   936  		checkOutTrailer,
   937  		checkEnd,
   938  	})
   939  }
   940  
   941  func (s) TestServerStatsUnaryRPCError(t *testing.T) {
   942  	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   943  		checkInHeader,
   944  		checkBegin,
   945  		checkInPayload,
   946  		checkOutHeader,
   947  		checkOutTrailer,
   948  		checkEnd,
   949  	})
   950  }
   951  
   952  func (s) TestServerStatsClientStreamRPC(t *testing.T) {
   953  	count := 5
   954  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   955  		checkInHeader,
   956  		checkBegin,
   957  		checkOutHeader,
   958  	}
   959  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   960  		checkInPayload,
   961  	}
   962  	for i := 0; i < count; i++ {
   963  		checkFuncs = append(checkFuncs, ioPayFuncs...)
   964  	}
   965  	checkFuncs = append(checkFuncs,
   966  		checkOutPayload,
   967  		checkOutTrailer,
   968  		checkEnd,
   969  	)
   970  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs)
   971  }
   972  
   973  func (s) TestServerStatsClientStreamRPCError(t *testing.T) {
   974  	count := 1
   975  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   976  		checkInHeader,
   977  		checkBegin,
   978  		checkOutHeader,
   979  		checkInPayload,
   980  		checkOutTrailer,
   981  		checkEnd,
   982  	})
   983  }
   984  
   985  func (s) TestServerStatsServerStreamRPC(t *testing.T) {
   986  	count := 5
   987  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   988  		checkInHeader,
   989  		checkBegin,
   990  		checkInPayload,
   991  		checkOutHeader,
   992  	}
   993  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   994  		checkOutPayload,
   995  	}
   996  	for i := 0; i < count; i++ {
   997  		checkFuncs = append(checkFuncs, ioPayFuncs...)
   998  	}
   999  	checkFuncs = append(checkFuncs,
  1000  		checkOutTrailer,
  1001  		checkEnd,
  1002  	)
  1003  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs)
  1004  }
  1005  
  1006  func (s) TestServerStatsServerStreamRPCError(t *testing.T) {
  1007  	count := 5
  1008  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
  1009  		checkInHeader,
  1010  		checkBegin,
  1011  		checkInPayload,
  1012  		checkOutHeader,
  1013  		checkOutTrailer,
  1014  		checkEnd,
  1015  	})
  1016  }
  1017  
  1018  func (s) TestServerStatsFullDuplexRPC(t *testing.T) {
  1019  	count := 5
  1020  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
  1021  		checkInHeader,
  1022  		checkBegin,
  1023  		checkOutHeader,
  1024  	}
  1025  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
  1026  		checkInPayload,
  1027  		checkOutPayload,
  1028  	}
  1029  	for i := 0; i < count; i++ {
  1030  		checkFuncs = append(checkFuncs, ioPayFuncs...)
  1031  	}
  1032  	checkFuncs = append(checkFuncs,
  1033  		checkOutTrailer,
  1034  		checkEnd,
  1035  	)
  1036  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs)
  1037  }
  1038  
  1039  func (s) TestServerStatsFullDuplexRPCError(t *testing.T) {
  1040  	count := 5
  1041  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
  1042  		checkInHeader,
  1043  		checkBegin,
  1044  		checkOutHeader,
  1045  		checkInPayload,
  1046  		checkOutTrailer,
  1047  		checkEnd,
  1048  	})
  1049  }
  1050  
// checkFuncWithCount pairs a stats check function with the number of stats of
// that kind a client-side test expects. checkClientStats sums the counts to
// get the expected total and decrements c each time it runs f on a stat.
type checkFuncWithCount struct {
	// f validates one recorded stat d against the expected data e.
	f func(t *testing.T, d *gotData, e *expectedData)
	c int // expected count
}
  1055  
  1056  func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) {
  1057  	var expectLen int
  1058  	for _, v := range checkFuncs {
  1059  		expectLen += v.c
  1060  	}
  1061  	if len(got) != expectLen {
  1062  		for i, g := range got {
  1063  			t.Errorf(" - %v, %T", i, g.s)
  1064  		}
  1065  		t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
  1066  	}
  1067  
  1068  	var tagInfoInCtx *stats.RPCTagInfo
  1069  	for i := 0; i < len(got); i++ {
  1070  		if _, ok := got[i].s.(stats.RPCStats); ok {
  1071  			tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo)
  1072  			if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew {
  1073  				t.Fatalf("got context containing different tagInfo with stats %T", got[i].s)
  1074  			}
  1075  			tagInfoInCtx = tagInfoInCtxNew
  1076  		}
  1077  	}
  1078  
  1079  	for _, s := range got {
  1080  		switch s.s.(type) {
  1081  		case *stats.Begin:
  1082  			if checkFuncs[begin].c <= 0 {
  1083  				t.Fatalf("unexpected stats: %T", s.s)
  1084  			}
  1085  			checkFuncs[begin].f(t, s, expect)
  1086  			checkFuncs[begin].c--
  1087  		case *stats.OutHeader:
  1088  			if checkFuncs[outHeader].c <= 0 {
  1089  				t.Fatalf("unexpected stats: %T", s.s)
  1090  			}
  1091  			checkFuncs[outHeader].f(t, s, expect)
  1092  			checkFuncs[outHeader].c--
  1093  		case *stats.OutPayload:
  1094  			if checkFuncs[outPayload].c <= 0 {
  1095  				t.Fatalf("unexpected stats: %T", s.s)
  1096  			}
  1097  			checkFuncs[outPayload].f(t, s, expect)
  1098  			checkFuncs[outPayload].c--
  1099  		case *stats.InHeader:
  1100  			if checkFuncs[inHeader].c <= 0 {
  1101  				t.Fatalf("unexpected stats: %T", s.s)
  1102  			}
  1103  			checkFuncs[inHeader].f(t, s, expect)
  1104  			checkFuncs[inHeader].c--
  1105  		case *stats.InPayload:
  1106  			if checkFuncs[inPayload].c <= 0 {
  1107  				t.Fatalf("unexpected stats: %T", s.s)
  1108  			}
  1109  			checkFuncs[inPayload].f(t, s, expect)
  1110  			checkFuncs[inPayload].c--
  1111  		case *stats.InTrailer:
  1112  			if checkFuncs[inTrailer].c <= 0 {
  1113  				t.Fatalf("unexpected stats: %T", s.s)
  1114  			}
  1115  			checkFuncs[inTrailer].f(t, s, expect)
  1116  			checkFuncs[inTrailer].c--
  1117  		case *stats.End:
  1118  			if checkFuncs[end].c <= 0 {
  1119  				t.Fatalf("unexpected stats: %T", s.s)
  1120  			}
  1121  			checkFuncs[end].f(t, s, expect)
  1122  			checkFuncs[end].c--
  1123  		case *stats.ConnBegin:
  1124  			if checkFuncs[connBegin].c <= 0 {
  1125  				t.Fatalf("unexpected stats: %T", s.s)
  1126  			}
  1127  			checkFuncs[connBegin].f(t, s, expect)
  1128  			checkFuncs[connBegin].c--
  1129  		case *stats.ConnEnd:
  1130  			if checkFuncs[connEnd].c <= 0 {
  1131  				t.Fatalf("unexpected stats: %T", s.s)
  1132  			}
  1133  			checkFuncs[connEnd].f(t, s, expect)
  1134  			checkFuncs[connEnd].c--
  1135  		default:
  1136  			t.Fatalf("unexpected stats: %T", s.s)
  1137  		}
  1138  	}
  1139  }
  1140  
  1141  func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
  1142  	h := &statshandler{}
  1143  	te := newTest(t, tc, []stats.Handler{h}, nil)
  1144  	te.startServer(&testServer{})
  1145  	defer te.tearDown()
  1146  
  1147  	var (
  1148  		reqs   []proto.Message
  1149  		resps  []proto.Message
  1150  		method string
  1151  		err    error
  1152  
  1153  		isClientStream bool
  1154  		isServerStream bool
  1155  
  1156  		req  proto.Message
  1157  		resp proto.Message
  1158  		e    error
  1159  	)
  1160  	switch cc.callType {
  1161  	case unaryRPC:
  1162  		method = "/grpc.testing.TestService/UnaryCall"
  1163  		req, resp, e = te.doUnaryCall(cc)
  1164  		reqs = []proto.Message{req}
  1165  		resps = []proto.Message{resp}
  1166  		err = e
  1167  	case clientStreamRPC:
  1168  		method = "/grpc.testing.TestService/StreamingInputCall"
  1169  		reqs, resp, e = te.doClientStreamCall(cc)
  1170  		resps = []proto.Message{resp}
  1171  		err = e
  1172  		isClientStream = true
  1173  	case serverStreamRPC:
  1174  		method = "/grpc.testing.TestService/StreamingOutputCall"
  1175  		req, resps, e = te.doServerStreamCall(cc)
  1176  		reqs = []proto.Message{req}
  1177  		err = e
  1178  		isServerStream = true
  1179  	case fullDuplexStreamRPC:
  1180  		method = "/grpc.testing.TestService/FullDuplexCall"
  1181  		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
  1182  		isClientStream = true
  1183  		isServerStream = true
  1184  	}
  1185  	if cc.success != (err == nil) {
  1186  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
  1187  	}
  1188  	te.cc.Close()
  1189  	te.srv.GracefulStop() // Wait for the server to stop.
  1190  
  1191  	lenRPCStats := 0
  1192  	for _, v := range checkFuncs {
  1193  		lenRPCStats += v.c
  1194  	}
  1195  	for {
  1196  		h.mu.Lock()
  1197  		if len(h.gotRPC) >= lenRPCStats {
  1198  			h.mu.Unlock()
  1199  			break
  1200  		}
  1201  		h.mu.Unlock()
  1202  		time.Sleep(10 * time.Millisecond)
  1203  	}
  1204  
  1205  	for {
  1206  		h.mu.Lock()
  1207  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
  1208  			h.mu.Unlock()
  1209  			break
  1210  		}
  1211  		h.mu.Unlock()
  1212  		time.Sleep(10 * time.Millisecond)
  1213  	}
  1214  
  1215  	expect := &expectedData{
  1216  		serverAddr:     te.srvAddr,
  1217  		compression:    tc.compress,
  1218  		method:         method,
  1219  		requests:       reqs,
  1220  		responses:      resps,
  1221  		failfast:       cc.failfast,
  1222  		err:            err,
  1223  		isClientStream: isClientStream,
  1224  		isServerStream: isServerStream,
  1225  	}
  1226  
  1227  	h.mu.Lock()
  1228  	checkConnStats(t, h.gotConn)
  1229  	h.mu.Unlock()
  1230  	checkClientStats(t, h.gotRPC, expect, checkFuncs)
  1231  }
  1232  
  1233  func (s) TestClientStatsUnaryRPC(t *testing.T) {
  1234  	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
  1235  		begin:      {checkBegin, 1},
  1236  		outHeader:  {checkOutHeader, 1},
  1237  		outPayload: {checkOutPayload, 1},
  1238  		inHeader:   {checkInHeader, 1},
  1239  		inPayload:  {checkInPayload, 1},
  1240  		inTrailer:  {checkInTrailer, 1},
  1241  		end:        {checkEnd, 1},
  1242  	})
  1243  }
  1244  
  1245  func (s) TestClientStatsUnaryRPCError(t *testing.T) {
  1246  	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
  1247  		begin:      {checkBegin, 1},
  1248  		outHeader:  {checkOutHeader, 1},
  1249  		outPayload: {checkOutPayload, 1},
  1250  		inHeader:   {checkInHeader, 1},
  1251  		inTrailer:  {checkInTrailer, 1},
  1252  		end:        {checkEnd, 1},
  1253  	})
  1254  }
  1255  
  1256  func (s) TestClientStatsClientStreamRPC(t *testing.T) {
  1257  	count := 5
  1258  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
  1259  		begin:      {checkBegin, 1},
  1260  		outHeader:  {checkOutHeader, 1},
  1261  		inHeader:   {checkInHeader, 1},
  1262  		outPayload: {checkOutPayload, count},
  1263  		inTrailer:  {checkInTrailer, 1},
  1264  		inPayload:  {checkInPayload, 1},
  1265  		end:        {checkEnd, 1},
  1266  	})
  1267  }
  1268  
  1269  func (s) TestClientStatsClientStreamRPCError(t *testing.T) {
  1270  	count := 1
  1271  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
  1272  		begin:      {checkBegin, 1},
  1273  		outHeader:  {checkOutHeader, 1},
  1274  		inHeader:   {checkInHeader, 1},
  1275  		outPayload: {checkOutPayload, 1},
  1276  		inTrailer:  {checkInTrailer, 1},
  1277  		end:        {checkEnd, 1},
  1278  	})
  1279  }
  1280  
  1281  func (s) TestClientStatsServerStreamRPC(t *testing.T) {
  1282  	count := 5
  1283  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
  1284  		begin:      {checkBegin, 1},
  1285  		outHeader:  {checkOutHeader, 1},
  1286  		outPayload: {checkOutPayload, 1},
  1287  		inHeader:   {checkInHeader, 1},
  1288  		inPayload:  {checkInPayload, count},
  1289  		inTrailer:  {checkInTrailer, 1},
  1290  		end:        {checkEnd, 1},
  1291  	})
  1292  }
  1293  
  1294  func (s) TestClientStatsServerStreamRPCError(t *testing.T) {
  1295  	count := 5
  1296  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
  1297  		begin:      {checkBegin, 1},
  1298  		outHeader:  {checkOutHeader, 1},
  1299  		outPayload: {checkOutPayload, 1},
  1300  		inHeader:   {checkInHeader, 1},
  1301  		inTrailer:  {checkInTrailer, 1},
  1302  		end:        {checkEnd, 1},
  1303  	})
  1304  }
  1305  
  1306  func (s) TestClientStatsFullDuplexRPC(t *testing.T) {
  1307  	count := 5
  1308  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
  1309  		begin:      {checkBegin, 1},
  1310  		outHeader:  {checkOutHeader, 1},
  1311  		outPayload: {checkOutPayload, count},
  1312  		inHeader:   {checkInHeader, 1},
  1313  		inPayload:  {checkInPayload, count},
  1314  		inTrailer:  {checkInTrailer, 1},
  1315  		end:        {checkEnd, 1},
  1316  	})
  1317  }
  1318  
  1319  func (s) TestClientStatsFullDuplexRPCError(t *testing.T) {
  1320  	count := 5
  1321  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
  1322  		begin:      {checkBegin, 1},
  1323  		outHeader:  {checkOutHeader, 1},
  1324  		outPayload: {checkOutPayload, 1},
  1325  		inHeader:   {checkInHeader, 1},
  1326  		inTrailer:  {checkInTrailer, 1},
  1327  		end:        {checkEnd, 1},
  1328  	})
  1329  }
  1330  
  1331  func (s) TestTags(t *testing.T) {
  1332  	b := []byte{5, 2, 4, 3, 1}
  1333  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1334  	defer cancel()
  1335  	ctx := stats.SetTags(tCtx, b)
  1336  	if tg := getOutgoingStats(ctx, "grpc-tags-bin"); !reflect.DeepEqual(tg, b) {
  1337  		t.Errorf("getOutgoingStats(%v, grpc-tags-bin) = %v; want %v", ctx, tg, b)
  1338  	}
  1339  	if tg := stats.Tags(ctx); tg != nil {
  1340  		t.Errorf("Tags(%v) = %v; want nil", ctx, tg)
  1341  	}
  1342  
  1343  	ctx = setIncomingStats(tCtx, "grpc-tags-bin", b)
  1344  	if tg := stats.Tags(ctx); !reflect.DeepEqual(tg, b) {
  1345  		t.Errorf("Tags(%v) = %v; want %v", ctx, tg, b)
  1346  	}
  1347  	if tg := getOutgoingStats(ctx, "grpc-tags-bin"); tg != nil {
  1348  		t.Errorf("getOutgoingStats(%v, grpc-tags-bin) = %v; want nil", ctx, tg)
  1349  	}
  1350  }
  1351  
  1352  func (s) TestTrace(t *testing.T) {
  1353  	b := []byte{5, 2, 4, 3, 1}
  1354  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1355  	defer cancel()
  1356  	ctx := stats.SetTrace(tCtx, b)
  1357  	if tr := getOutgoingStats(ctx, "grpc-trace-bin"); !reflect.DeepEqual(tr, b) {
  1358  		t.Errorf("getOutgoingStats(%v, grpc-trace-bin) = %v; want %v", ctx, tr, b)
  1359  	}
  1360  	if tr := stats.Trace(ctx); tr != nil {
  1361  		t.Errorf("Trace(%v) = %v; want nil", ctx, tr)
  1362  	}
  1363  
  1364  	ctx = setIncomingStats(tCtx, "grpc-trace-bin", b)
  1365  	if tr := stats.Trace(ctx); !reflect.DeepEqual(tr, b) {
  1366  		t.Errorf("Trace(%v) = %v; want %v", ctx, tr, b)
  1367  	}
  1368  	if tr := getOutgoingStats(ctx, "grpc-trace-bin"); tr != nil {
  1369  		t.Errorf("getOutgoingStats(%v, grpc-trace-bin) = %v; want nil", ctx, tr)
  1370  	}
  1371  }
  1372  
  1373  func (s) TestMultipleClientStatsHandler(t *testing.T) {
  1374  	h := &statshandler{}
  1375  	tc := &testConfig{compress: ""}
  1376  	te := newTest(t, tc, []stats.Handler{h, h}, nil)
  1377  	te.startServer(&testServer{})
  1378  	defer te.tearDown()
  1379  
  1380  	cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC}
  1381  	_, _, err := te.doUnaryCall(cc)
  1382  	if cc.success != (err == nil) {
  1383  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
  1384  	}
  1385  	te.cc.Close()
  1386  	te.srv.GracefulStop() // Wait for the server to stop.
  1387  
  1388  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1389  		h.mu.Lock()
  1390  		if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok && len(h.gotRPC) == 12 {
  1391  			h.mu.Unlock()
  1392  			break
  1393  		}
  1394  		h.mu.Unlock()
  1395  		time.Sleep(10 * time.Millisecond)
  1396  	}
  1397  
  1398  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1399  		h.mu.Lock()
  1400  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok && len(h.gotConn) == 4 {
  1401  			h.mu.Unlock()
  1402  			break
  1403  		}
  1404  		h.mu.Unlock()
  1405  		time.Sleep(10 * time.Millisecond)
  1406  	}
  1407  
  1408  	// Each RPC generates 6 stats events on the client-side, times 2 StatsHandler
  1409  	if len(h.gotRPC) != 12 {
  1410  		t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12)
  1411  	}
  1412  
  1413  	// Each connection generates 4 conn events on the client-side, times 2 StatsHandler
  1414  	if len(h.gotConn) != 4 {
  1415  		t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4)
  1416  	}
  1417  }
  1418  
  1419  func (s) TestMultipleServerStatsHandler(t *testing.T) {
  1420  	h := &statshandler{}
  1421  	tc := &testConfig{compress: ""}
  1422  	te := newTest(t, tc, nil, []stats.Handler{h, h})
  1423  	te.startServer(&testServer{})
  1424  	defer te.tearDown()
  1425  
  1426  	cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC}
  1427  	_, _, err := te.doUnaryCall(cc)
  1428  	if cc.success != (err == nil) {
  1429  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
  1430  	}
  1431  	te.cc.Close()
  1432  	te.srv.GracefulStop() // Wait for the server to stop.
  1433  
  1434  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1435  		h.mu.Lock()
  1436  		if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok {
  1437  			h.mu.Unlock()
  1438  			break
  1439  		}
  1440  		h.mu.Unlock()
  1441  		time.Sleep(10 * time.Millisecond)
  1442  	}
  1443  
  1444  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1445  		h.mu.Lock()
  1446  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
  1447  			h.mu.Unlock()
  1448  			break
  1449  		}
  1450  		h.mu.Unlock()
  1451  		time.Sleep(10 * time.Millisecond)
  1452  	}
  1453  
  1454  	// Each RPC generates 6 stats events on the server-side, times 2 StatsHandler
  1455  	if len(h.gotRPC) != 12 {
  1456  		t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12)
  1457  	}
  1458  
  1459  	// Each connection generates 4 conn events on the server-side, times 2 StatsHandler
  1460  	if len(h.gotConn) != 4 {
  1461  		t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4)
  1462  	}
  1463  }
  1464  
  1465  // TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler
  1466  // gets access to a Server on the server side, and thus the method that the
  1467  // server owns which specifies whether a method is made or not. The test sets up
  1468  // a server with a unary call and full duplex call configured, and makes an RPC.
  1469  // Within the stats handler, asking the server whether unary or duplex method
  1470  // names are registered should return true, and any other query should return
  1471  // false.
  1472  func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) {
  1473  	wg := sync.WaitGroup{}
  1474  	wg.Add(1)
  1475  	stubStatsHandler := &testutils.StubStatsHandler{
  1476  		TagRPCF: func(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
  1477  			// OpenTelemetry instrumentation needs the passed in Server to determine if
  1478  			// methods are registered in different handle calls in to record metrics.
  1479  			// This tag RPC call context gets passed into every handle call, so can
  1480  			// assert once here, since it maps to all the handle RPC calls that come
  1481  			// after. These internal calls will be how the OpenTelemetry instrumentation
  1482  			// component accesses this server and the subsequent helper on the server.
  1483  			server := internal.ServerFromContext.(func(context.Context) *grpc.Server)(ctx)
  1484  			if server == nil {
  1485  				t.Errorf("stats handler received ctx has no server present")
  1486  			}
  1487  			isRegisteredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)
  1488  			// /s/m and s/m are valid.
  1489  			if !isRegisteredMethod(server, "/grpc.testing.TestService/UnaryCall") {
  1490  				t.Errorf("UnaryCall should be a registered method according to server")
  1491  			}
  1492  			if !isRegisteredMethod(server, "grpc.testing.TestService/FullDuplexCall") {
  1493  				t.Errorf("FullDuplexCall should be a registered method according to server")
  1494  			}
  1495  			if isRegisteredMethod(server, "/grpc.testing.TestService/DoesNotExistCall") {
  1496  				t.Errorf("DoesNotExistCall should not be a registered method according to server")
  1497  			}
  1498  			if isRegisteredMethod(server, "/unknownService/UnaryCall") {
  1499  				t.Errorf("/unknownService/UnaryCall should not be a registered method according to server")
  1500  			}
  1501  			wg.Done()
  1502  			return ctx
  1503  		},
  1504  	}
  1505  	ss := &stubserver.StubServer{
  1506  		UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
  1507  			return &testpb.SimpleResponse{}, nil
  1508  		},
  1509  	}
  1510  	if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(stubStatsHandler)}); err != nil {
  1511  		t.Fatalf("Error starting endpoint server: %v", err)
  1512  	}
  1513  	defer ss.Stop()
  1514  
  1515  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1516  	defer cancel()
  1517  	if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil {
  1518  		t.Fatalf("Unexpected error from UnaryCall: %v", err)
  1519  	}
  1520  	wg.Wait()
  1521  }