gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/stats/stats_test.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package stats_test
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"io"
    25  	"net"
    26  	"reflect"
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	grpc "gitee.com/ks-custle/core-gm/grpc"
    32  	"gitee.com/ks-custle/core-gm/grpc/internal/grpctest"
    33  	"gitee.com/ks-custle/core-gm/grpc/metadata"
    34  	"gitee.com/ks-custle/core-gm/grpc/stats"
    35  	"gitee.com/ks-custle/core-gm/grpc/status"
    36  	"github.com/golang/protobuf/proto"
    37  
    38  	testgrpc "gitee.com/ks-custle/core-gm/grpc/interop/grpc_testing"
    39  	testpb "gitee.com/ks-custle/core-gm/grpc/interop/grpc_testing"
    40  )
    41  
// defaultTestTimeout is the deadline applied to the per-RPC contexts created
// by the doXxxCall helpers below.
const defaultTestTimeout = 10 * time.Second

// s is the receiver type for the test cases in this file. It embeds
// grpctest.Tester; presumably grpctest.RunSubTests runs each of its Test*
// methods as a subtest (TODO confirm against grpctest).
type s struct {
	grpctest.Tester
}

// Test is the go-test entry point; it hands the methods of s to
// grpctest.RunSubTests.
func Test(t *testing.T) {
	grpctest.RunSubTests(t, s{})
}

// init turns off grpc.EnableTracing for every test in this package.
func init() {
	grpc.EnableTracing = false
}

// connCtxKey and rpcCtxKey are the context keys under which statshandler's
// TagConn and TagRPC store the *stats.ConnTagInfo and *stats.RPCTagInfo, so
// the check functions can confirm that tag info reaches later stats events.
type connCtxKey struct{}
type rpcCtxKey struct{}
    58  
var (
	// testMetadata is attached to every outgoing client request and checked
	// on the server's InHeader and the client's OutHeader. The user-agent
	// entry corresponds to the "test/0.0.1" agent set in clientConn;
	// presumably grpc appends " grpc-go/<version>" to it (TODO confirm).
	testMetadata = metadata.MD{
		"key1":       []string{"value1"},
		"key2":       []string{"value2"},
		"user-agent": []string{fmt.Sprintf("test/0.0.1 grpc-go/%s", grpc.Version)},
	}
	// testHeaderMetadata is sent as response headers by every testServer
	// handler and checked on the client's InHeader / server's OutHeader.
	testHeaderMetadata = metadata.MD{
		"hkey1": []string{"headerValue1"},
		"hkey2": []string{"headerValue2"},
	}
	// testTrailerMetadata is set as the trailer by every testServer handler
	// and checked on InTrailer, OutTrailer and the client's End event.
	testTrailerMetadata = metadata.MD{
		"tkey1": []string{"trailerValue1"},
		"tkey2": []string{"trailerValue2"},
	}
	// errorID is the payload id for which the service handlers return an
	// error; the doXxxCall helpers use it to produce failing RPCs.
	errorID int32 = 32202
)
    79  
    80  func idToPayload(id int32) *testpb.Payload {
    81  	return &testpb.Payload{Body: []byte{byte(id), byte(id >> 8), byte(id >> 16), byte(id >> 24)}}
    82  }
    83  
    84  func payloadToID(p *testpb.Payload) int32 {
    85  	if p == nil || len(p.Body) != 4 {
    86  		panic("invalid payload")
    87  	}
    88  	return int32(p.Body[0]) + int32(p.Body[1])<<8 + int32(p.Body[2])<<16 + int32(p.Body[3])<<24
    89  }
    90  
// testServer implements the TestService handlers exercised by these tests.
// Any method not defined on testServer in this file falls back to the
// embedded UnimplementedTestServiceServer.
type testServer struct {
	testgrpc.UnimplementedTestServiceServer
}
    94  
    95  func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
    96  	if err := grpc.SendHeader(ctx, testHeaderMetadata); err != nil {
    97  		return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", testHeaderMetadata, err)
    98  	}
    99  	if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
   100  		return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
   101  	}
   102  
   103  	if id := payloadToID(in.Payload); id == errorID {
   104  		return nil, fmt.Errorf("got error id: %v", id)
   105  	}
   106  
   107  	return &testpb.SimpleResponse{Payload: in.Payload}, nil
   108  }
   109  
   110  func (s *testServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallServer) error {
   111  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   112  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   113  	}
   114  	stream.SetTrailer(testTrailerMetadata)
   115  	for {
   116  		in, err := stream.Recv()
   117  		if err == io.EOF {
   118  			// read done.
   119  			return nil
   120  		}
   121  		if err != nil {
   122  			return err
   123  		}
   124  
   125  		if id := payloadToID(in.Payload); id == errorID {
   126  			return fmt.Errorf("got error id: %v", id)
   127  		}
   128  
   129  		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
   130  			return err
   131  		}
   132  	}
   133  }
   134  
   135  func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error {
   136  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   137  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   138  	}
   139  	stream.SetTrailer(testTrailerMetadata)
   140  	for {
   141  		in, err := stream.Recv()
   142  		if err == io.EOF {
   143  			// read done.
   144  			return stream.SendAndClose(&testpb.StreamingInputCallResponse{AggregatedPayloadSize: 0})
   145  		}
   146  		if err != nil {
   147  			return err
   148  		}
   149  
   150  		if id := payloadToID(in.Payload); id == errorID {
   151  			return fmt.Errorf("got error id: %v", id)
   152  		}
   153  	}
   154  }
   155  
   156  func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error {
   157  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   158  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   159  	}
   160  	stream.SetTrailer(testTrailerMetadata)
   161  
   162  	if id := payloadToID(in.Payload); id == errorID {
   163  		return fmt.Errorf("got error id: %v", id)
   164  	}
   165  
   166  	for i := 0; i < 5; i++ {
   167  		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
   168  			return err
   169  		}
   170  	}
   171  	return nil
   172  }
   173  
// test is an end-to-end test fixture: a gRPC server plus a lazily dialed
// client connection, each optionally configured with gzip compression and a
// stats.Handler. Create it with newTest, modify fields as needed, start the
// server with startServer, and clean up with tearDown.
type test struct {
	t *testing.T
	// compress is "gzip" to enable gzip on both client and server; any other
	// value leaves compression off.
	compress string
	// clientStatsHandler, when non-nil, is installed on the client conn.
	clientStatsHandler stats.Handler
	// serverStatsHandler, when non-nil, is installed on the server.
	serverStatsHandler stats.Handler

	testServer testgrpc.TestServiceServer // nil means none
	// srv and srvAddr are set once startServer is called.
	srv     *grpc.Server
	srvAddr string

	cc *grpc.ClientConn // nil until requested via clientConn
}
   190  
   191  func (te *test) tearDown() {
   192  	if te.cc != nil {
   193  		te.cc.Close()
   194  		te.cc = nil
   195  	}
   196  	te.srv.Stop()
   197  }
   198  
// testConfig holds per-test options consumed by newTest.
type testConfig struct {
	// compress is copied into test.compress; only "gzip" enables compression
	// (see startServer and clientConn).
	compress string
}
   202  
   203  // newTest returns a new test using the provided testing.T and
   204  // environment.  It is returned with default values. Tests should
   205  // modify it before calling its startServer and clientConn methods.
   206  func newTest(t *testing.T, tc *testConfig, ch stats.Handler, sh stats.Handler) *test {
   207  	te := &test{
   208  		t:                  t,
   209  		compress:           tc.compress,
   210  		clientStatsHandler: ch,
   211  		serverStatsHandler: sh,
   212  	}
   213  	return te
   214  }
   215  
   216  // startServer starts a gRPC server listening. Callers should defer a
   217  // call to te.tearDown to clean up.
   218  func (te *test) startServer(ts testgrpc.TestServiceServer) {
   219  	te.testServer = ts
   220  	lis, err := net.Listen("tcp", "localhost:0")
   221  	if err != nil {
   222  		te.t.Fatalf("Failed to listen: %v", err)
   223  	}
   224  	var opts []grpc.ServerOption
   225  	if te.compress == "gzip" {
   226  		opts = append(opts,
   227  			grpc.RPCCompressor(grpc.NewGZIPCompressor()),
   228  			grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
   229  		)
   230  	}
   231  	if te.serverStatsHandler != nil {
   232  		opts = append(opts, grpc.StatsHandler(te.serverStatsHandler))
   233  	}
   234  	s := grpc.NewServer(opts...)
   235  	te.srv = s
   236  	if te.testServer != nil {
   237  		testgrpc.RegisterTestServiceServer(s, te.testServer)
   238  	}
   239  
   240  	go s.Serve(lis)
   241  	te.srvAddr = lis.Addr().String()
   242  }
   243  
   244  func (te *test) clientConn() *grpc.ClientConn {
   245  	if te.cc != nil {
   246  		return te.cc
   247  	}
   248  	opts := []grpc.DialOption{
   249  		grpc.WithInsecure(),
   250  		grpc.WithBlock(),
   251  		grpc.WithUserAgent("test/0.0.1"),
   252  	}
   253  	if te.compress == "gzip" {
   254  		opts = append(opts,
   255  			grpc.WithCompressor(grpc.NewGZIPCompressor()),
   256  			grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
   257  		)
   258  	}
   259  	if te.clientStatsHandler != nil {
   260  		opts = append(opts, grpc.WithStatsHandler(te.clientStatsHandler))
   261  	}
   262  
   263  	var err error
   264  	te.cc, err = grpc.Dial(te.srvAddr, opts...)
   265  	if err != nil {
   266  		te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err)
   267  	}
   268  	return te.cc
   269  }
   270  
// rpcType selects which TestService method a test exercises.
type rpcType int

const (
	unaryRPC            rpcType = iota // UnaryCall
	clientStreamRPC                    // StreamingInputCall
	serverStreamRPC                    // StreamingOutputCall
	fullDuplexStreamRPC                // FullDuplexCall
)
   279  
// rpcConfig describes the RPC a test issues.
type rpcConfig struct {
	count   int  // Number of requests and responses for streaming RPCs.
	success bool // Whether the RPC should succeed or return error.
	// failfast, when true, issues the RPC without WaitForReady (the helpers
	// pass grpc.WaitForReady(!failfast)); it is also the value expected in
	// the client's Begin.FailFast.
	failfast bool
	callType rpcType // Type of RPC.
}
   286  
   287  func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
   288  	var (
   289  		resp *testpb.SimpleResponse
   290  		req  *testpb.SimpleRequest
   291  		err  error
   292  	)
   293  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   294  	if c.success {
   295  		req = &testpb.SimpleRequest{Payload: idToPayload(errorID + 1)}
   296  	} else {
   297  		req = &testpb.SimpleRequest{Payload: idToPayload(errorID)}
   298  	}
   299  
   300  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   301  	defer cancel()
   302  	resp, err = tc.UnaryCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast))
   303  	return req, resp, err
   304  }
   305  
   306  func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]proto.Message, []proto.Message, error) {
   307  	var (
   308  		reqs  []proto.Message
   309  		resps []proto.Message
   310  		err   error
   311  	)
   312  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   313  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   314  	defer cancel()
   315  	stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast))
   316  	if err != nil {
   317  		return reqs, resps, err
   318  	}
   319  	var startID int32
   320  	if !c.success {
   321  		startID = errorID
   322  	}
   323  	for i := 0; i < c.count; i++ {
   324  		req := &testpb.StreamingOutputCallRequest{
   325  			Payload: idToPayload(int32(i) + startID),
   326  		}
   327  		reqs = append(reqs, req)
   328  		if err = stream.Send(req); err != nil {
   329  			return reqs, resps, err
   330  		}
   331  		var resp *testpb.StreamingOutputCallResponse
   332  		if resp, err = stream.Recv(); err != nil {
   333  			return reqs, resps, err
   334  		}
   335  		resps = append(resps, resp)
   336  	}
   337  	if err = stream.CloseSend(); err != nil && err != io.EOF {
   338  		return reqs, resps, err
   339  	}
   340  	if _, err = stream.Recv(); err != io.EOF {
   341  		return reqs, resps, err
   342  	}
   343  
   344  	return reqs, resps, nil
   345  }
   346  
   347  func (te *test) doClientStreamCall(c *rpcConfig) ([]proto.Message, *testpb.StreamingInputCallResponse, error) {
   348  	var (
   349  		reqs []proto.Message
   350  		resp *testpb.StreamingInputCallResponse
   351  		err  error
   352  	)
   353  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   354  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   355  	defer cancel()
   356  	stream, err := tc.StreamingInputCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast))
   357  	if err != nil {
   358  		return reqs, resp, err
   359  	}
   360  	var startID int32
   361  	if !c.success {
   362  		startID = errorID
   363  	}
   364  	for i := 0; i < c.count; i++ {
   365  		req := &testpb.StreamingInputCallRequest{
   366  			Payload: idToPayload(int32(i) + startID),
   367  		}
   368  		reqs = append(reqs, req)
   369  		if err = stream.Send(req); err != nil {
   370  			return reqs, resp, err
   371  		}
   372  	}
   373  	resp, err = stream.CloseAndRecv()
   374  	return reqs, resp, err
   375  }
   376  
   377  func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.StreamingOutputCallRequest, []proto.Message, error) {
   378  	var (
   379  		req   *testpb.StreamingOutputCallRequest
   380  		resps []proto.Message
   381  		err   error
   382  	)
   383  
   384  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   385  
   386  	var startID int32
   387  	if !c.success {
   388  		startID = errorID
   389  	}
   390  	req = &testpb.StreamingOutputCallRequest{Payload: idToPayload(startID)}
   391  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   392  	defer cancel()
   393  	stream, err := tc.StreamingOutputCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast))
   394  	if err != nil {
   395  		return req, resps, err
   396  	}
   397  	for {
   398  		var resp *testpb.StreamingOutputCallResponse
   399  		resp, err := stream.Recv()
   400  		if err == io.EOF {
   401  			return req, resps, nil
   402  		} else if err != nil {
   403  			return req, resps, err
   404  		}
   405  		resps = append(resps, resp)
   406  	}
   407  }
   408  
// expectedData is what the check functions compare recorded stats against.
// reqIdx and respIdx are cursors that the payload checks advance as they
// consume requests and responses in order.
type expectedData struct {
	method         string // full method name, e.g. "/grpc.testing.TestService/UnaryCall"
	isClientStream bool
	isServerStream bool
	serverAddr     string // test.srvAddr
	compression    string // testConfig.compress
	reqIdx         int
	requests       []proto.Message
	respIdx        int
	responses      []proto.Message
	err            error // error returned by the RPC helper; compared by status code and message in checkEnd
	failfast       bool  // compared with Begin.FailFast on the client
}
   422  
// gotData is one stats event recorded by statshandler.
type gotData struct {
	ctx    context.Context // context passed to HandleRPC / HandleConn
	client bool            // s.IsClient() at the time the event was recorded
	s      interface{}     // This could be RPCStats or ConnStats.
}
   428  
// Identifiers for the kinds of RPC and connection stats events
// (begin/end, in/out payloads, headers, trailers, conn begin/end).
const (
	begin int = iota
	end
	inPayload
	inHeader
	inTrailer
	outPayload
	outHeader
	// TODO: test outTrailer ?
	connBegin
	connEnd
)
   441  
   442  func checkBegin(t *testing.T, d *gotData, e *expectedData) {
   443  	var (
   444  		ok bool
   445  		st *stats.Begin
   446  	)
   447  	if st, ok = d.s.(*stats.Begin); !ok {
   448  		t.Fatalf("got %T, want Begin", d.s)
   449  	}
   450  	if d.ctx == nil {
   451  		t.Fatalf("d.ctx = nil, want <non-nil>")
   452  	}
   453  	if st.BeginTime.IsZero() {
   454  		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
   455  	}
   456  	if d.client {
   457  		if st.FailFast != e.failfast {
   458  			t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast)
   459  		}
   460  	}
   461  	if st.IsClientStream != e.isClientStream {
   462  		t.Fatalf("st.IsClientStream = %v, want %v", st.IsClientStream, e.isClientStream)
   463  	}
   464  	if st.IsServerStream != e.isServerStream {
   465  		t.Fatalf("st.IsServerStream = %v, want %v", st.IsServerStream, e.isServerStream)
   466  	}
   467  }
   468  
   469  func checkInHeader(t *testing.T, d *gotData, e *expectedData) {
   470  	var (
   471  		ok bool
   472  		st *stats.InHeader
   473  	)
   474  	if st, ok = d.s.(*stats.InHeader); !ok {
   475  		t.Fatalf("got %T, want InHeader", d.s)
   476  	}
   477  	if d.ctx == nil {
   478  		t.Fatalf("d.ctx = nil, want <non-nil>")
   479  	}
   480  	if st.Compression != e.compression {
   481  		t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
   482  	}
   483  	if d.client {
   484  		// additional headers might be injected so instead of testing equality, test that all the
   485  		// expected headers keys have the expected header values.
   486  		for key := range testHeaderMetadata {
   487  			if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) {
   488  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key))
   489  			}
   490  		}
   491  	} else {
   492  		if st.FullMethod != e.method {
   493  			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
   494  		}
   495  		if st.LocalAddr.String() != e.serverAddr {
   496  			t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr)
   497  		}
   498  		// additional headers might be injected so instead of testing equality, test that all the
   499  		// expected headers keys have the expected header values.
   500  		for key := range testMetadata {
   501  			if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) {
   502  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key))
   503  			}
   504  		}
   505  
   506  		if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok {
   507  			if connInfo.RemoteAddr != st.RemoteAddr {
   508  				t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr)
   509  			}
   510  			if connInfo.LocalAddr != st.LocalAddr {
   511  				t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr)
   512  			}
   513  		} else {
   514  			t.Fatalf("got context %v, want one with connCtxKey", d.ctx)
   515  		}
   516  		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
   517  			if rpcInfo.FullMethodName != st.FullMethod {
   518  				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
   519  			}
   520  		} else {
   521  			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
   522  		}
   523  	}
   524  }
   525  
   526  func checkInPayload(t *testing.T, d *gotData, e *expectedData) {
   527  	var (
   528  		ok bool
   529  		st *stats.InPayload
   530  	)
   531  	if st, ok = d.s.(*stats.InPayload); !ok {
   532  		t.Fatalf("got %T, want InPayload", d.s)
   533  	}
   534  	if d.ctx == nil {
   535  		t.Fatalf("d.ctx = nil, want <non-nil>")
   536  	}
   537  	if d.client {
   538  		b, err := proto.Marshal(e.responses[e.respIdx])
   539  		if err != nil {
   540  			t.Fatalf("failed to marshal message: %v", err)
   541  		}
   542  		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
   543  			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
   544  		}
   545  		e.respIdx++
   546  		if string(st.Data) != string(b) {
   547  			t.Fatalf("st.Data = %v, want %v", st.Data, b)
   548  		}
   549  		if st.Length != len(b) {
   550  			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
   551  		}
   552  	} else {
   553  		b, err := proto.Marshal(e.requests[e.reqIdx])
   554  		if err != nil {
   555  			t.Fatalf("failed to marshal message: %v", err)
   556  		}
   557  		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
   558  			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
   559  		}
   560  		e.reqIdx++
   561  		if string(st.Data) != string(b) {
   562  			t.Fatalf("st.Data = %v, want %v", st.Data, b)
   563  		}
   564  		if st.Length != len(b) {
   565  			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
   566  		}
   567  	}
   568  	// Below are sanity checks that WireLength and RecvTime are populated.
   569  	// TODO: check values of WireLength and RecvTime.
   570  	if len(st.Data) > 0 && st.WireLength == 0 {
   571  		t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
   572  			st.WireLength)
   573  	}
   574  	if st.RecvTime.IsZero() {
   575  		t.Fatalf("st.ReceivedTime = %v, want <non-zero>", st.RecvTime)
   576  	}
   577  }
   578  
   579  func checkInTrailer(t *testing.T, d *gotData, e *expectedData) {
   580  	var (
   581  		ok bool
   582  		st *stats.InTrailer
   583  	)
   584  	if st, ok = d.s.(*stats.InTrailer); !ok {
   585  		t.Fatalf("got %T, want InTrailer", d.s)
   586  	}
   587  	if d.ctx == nil {
   588  		t.Fatalf("d.ctx = nil, want <non-nil>")
   589  	}
   590  	if !st.Client {
   591  		t.Fatalf("st IsClient = false, want true")
   592  	}
   593  	if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   594  		t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   595  	}
   596  }
   597  
   598  func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
   599  	var (
   600  		ok bool
   601  		st *stats.OutHeader
   602  	)
   603  	if st, ok = d.s.(*stats.OutHeader); !ok {
   604  		t.Fatalf("got %T, want OutHeader", d.s)
   605  	}
   606  	if d.ctx == nil {
   607  		t.Fatalf("d.ctx = nil, want <non-nil>")
   608  	}
   609  	if st.Compression != e.compression {
   610  		t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
   611  	}
   612  	if d.client {
   613  		if st.FullMethod != e.method {
   614  			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
   615  		}
   616  		if st.RemoteAddr.String() != e.serverAddr {
   617  			t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr)
   618  		}
   619  		// additional headers might be injected so instead of testing equality, test that all the
   620  		// expected headers keys have the expected header values.
   621  		for key := range testMetadata {
   622  			if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) {
   623  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key))
   624  			}
   625  		}
   626  
   627  		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
   628  			if rpcInfo.FullMethodName != st.FullMethod {
   629  				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
   630  			}
   631  		} else {
   632  			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
   633  		}
   634  	} else {
   635  		// additional headers might be injected so instead of testing equality, test that all the
   636  		// expected headers keys have the expected header values.
   637  		for key := range testHeaderMetadata {
   638  			if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) {
   639  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key))
   640  			}
   641  		}
   642  	}
   643  }
   644  
   645  func checkOutPayload(t *testing.T, d *gotData, e *expectedData) {
   646  	var (
   647  		ok bool
   648  		st *stats.OutPayload
   649  	)
   650  	if st, ok = d.s.(*stats.OutPayload); !ok {
   651  		t.Fatalf("got %T, want OutPayload", d.s)
   652  	}
   653  	if d.ctx == nil {
   654  		t.Fatalf("d.ctx = nil, want <non-nil>")
   655  	}
   656  	if d.client {
   657  		b, err := proto.Marshal(e.requests[e.reqIdx])
   658  		if err != nil {
   659  			t.Fatalf("failed to marshal message: %v", err)
   660  		}
   661  		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
   662  			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
   663  		}
   664  		e.reqIdx++
   665  		if string(st.Data) != string(b) {
   666  			t.Fatalf("st.Data = %v, want %v", st.Data, b)
   667  		}
   668  		if st.Length != len(b) {
   669  			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
   670  		}
   671  	} else {
   672  		b, err := proto.Marshal(e.responses[e.respIdx])
   673  		if err != nil {
   674  			t.Fatalf("failed to marshal message: %v", err)
   675  		}
   676  		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
   677  			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
   678  		}
   679  		e.respIdx++
   680  		if string(st.Data) != string(b) {
   681  			t.Fatalf("st.Data = %v, want %v", st.Data, b)
   682  		}
   683  		if st.Length != len(b) {
   684  			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
   685  		}
   686  	}
   687  	// Below are sanity checks that WireLength and SentTime are populated.
   688  	// TODO: check values of WireLength and SentTime.
   689  	if len(st.Data) > 0 && st.WireLength == 0 {
   690  		t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
   691  			st.WireLength)
   692  	}
   693  	if st.SentTime.IsZero() {
   694  		t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime)
   695  	}
   696  }
   697  
   698  func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) {
   699  	var (
   700  		ok bool
   701  		st *stats.OutTrailer
   702  	)
   703  	if st, ok = d.s.(*stats.OutTrailer); !ok {
   704  		t.Fatalf("got %T, want OutTrailer", d.s)
   705  	}
   706  	if d.ctx == nil {
   707  		t.Fatalf("d.ctx = nil, want <non-nil>")
   708  	}
   709  	if st.Client {
   710  		t.Fatalf("st IsClient = true, want false")
   711  	}
   712  	if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   713  		t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   714  	}
   715  }
   716  
   717  func checkEnd(t *testing.T, d *gotData, e *expectedData) {
   718  	var (
   719  		ok bool
   720  		st *stats.End
   721  	)
   722  	if st, ok = d.s.(*stats.End); !ok {
   723  		t.Fatalf("got %T, want End", d.s)
   724  	}
   725  	if d.ctx == nil {
   726  		t.Fatalf("d.ctx = nil, want <non-nil>")
   727  	}
   728  	if st.BeginTime.IsZero() {
   729  		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
   730  	}
   731  	if st.EndTime.IsZero() {
   732  		t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime)
   733  	}
   734  
   735  	actual, ok := status.FromError(st.Error)
   736  	if !ok {
   737  		t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error)
   738  	}
   739  
   740  	expectedStatus, _ := status.FromError(e.err)
   741  	if actual.Code() != expectedStatus.Code() || actual.Message() != expectedStatus.Message() {
   742  		t.Fatalf("st.Error = %v, want %v", st.Error, e.err)
   743  	}
   744  
   745  	if st.Client {
   746  		if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   747  			t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   748  		}
   749  	} else {
   750  		if st.Trailer != nil {
   751  			t.Fatalf("st.Trailer = %v, want nil", st.Trailer)
   752  		}
   753  	}
   754  }
   755  
   756  func checkConnBegin(t *testing.T, d *gotData) {
   757  	var (
   758  		ok bool
   759  		st *stats.ConnBegin
   760  	)
   761  	if st, ok = d.s.(*stats.ConnBegin); !ok {
   762  		t.Fatalf("got %T, want ConnBegin", d.s)
   763  	}
   764  	if d.ctx == nil {
   765  		t.Fatalf("d.ctx = nil, want <non-nil>")
   766  	}
   767  	st.IsClient() // TODO remove this.
   768  }
   769  
   770  func checkConnEnd(t *testing.T, d *gotData) {
   771  	var (
   772  		ok bool
   773  		st *stats.ConnEnd
   774  	)
   775  	if st, ok = d.s.(*stats.ConnEnd); !ok {
   776  		t.Fatalf("got %T, want ConnEnd", d.s)
   777  	}
   778  	if d.ctx == nil {
   779  		t.Fatalf("d.ctx = nil, want <non-nil>")
   780  	}
   781  	st.IsClient() // TODO remove this.
   782  }
   783  
   784  type statshandler struct {
   785  	mu      sync.Mutex
   786  	gotRPC  []*gotData
   787  	gotConn []*gotData
   788  }
   789  
   790  func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
   791  	return context.WithValue(ctx, connCtxKey{}, info)
   792  }
   793  
   794  func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
   795  	return context.WithValue(ctx, rpcCtxKey{}, info)
   796  }
   797  
   798  func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) {
   799  	h.mu.Lock()
   800  	defer h.mu.Unlock()
   801  	h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s})
   802  }
   803  
   804  func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
   805  	h.mu.Lock()
   806  	defer h.mu.Unlock()
   807  	h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s})
   808  }
   809  
   810  func checkConnStats(t *testing.T, got []*gotData) {
   811  	if len(got) <= 0 || len(got)%2 != 0 {
   812  		for i, g := range got {
   813  			t.Errorf(" - %v, %T = %+v, ctx: %v", i, g.s, g.s, g.ctx)
   814  		}
   815  		t.Fatalf("got %v stats, want even positive number", len(got))
   816  	}
   817  	// The first conn stats must be a ConnBegin.
   818  	checkConnBegin(t, got[0])
   819  	// The last conn stats must be a ConnEnd.
   820  	checkConnEnd(t, got[len(got)-1])
   821  }
   822  
   823  func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
   824  	if len(got) != len(checkFuncs) {
   825  		for i, g := range got {
   826  			t.Errorf(" - %v, %T", i, g.s)
   827  		}
   828  		t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
   829  	}
   830  
   831  	var rpcctx context.Context
   832  	for i := 0; i < len(got); i++ {
   833  		if _, ok := got[i].s.(stats.RPCStats); ok {
   834  			if rpcctx != nil && got[i].ctx != rpcctx {
   835  				t.Fatalf("got different contexts with stats %T", got[i].s)
   836  			}
   837  			rpcctx = got[i].ctx
   838  		}
   839  	}
   840  
   841  	for i, f := range checkFuncs {
   842  		f(t, got[i], expect)
   843  	}
   844  }
   845  
   846  func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
   847  	h := &statshandler{}
   848  	te := newTest(t, tc, nil, h)
   849  	te.startServer(&testServer{})
   850  	defer te.tearDown()
   851  
   852  	var (
   853  		reqs   []proto.Message
   854  		resps  []proto.Message
   855  		err    error
   856  		method string
   857  
   858  		isClientStream bool
   859  		isServerStream bool
   860  
   861  		req  proto.Message
   862  		resp proto.Message
   863  		e    error
   864  	)
   865  
   866  	switch cc.callType {
   867  	case unaryRPC:
   868  		method = "/grpc.testing.TestService/UnaryCall"
   869  		req, resp, e = te.doUnaryCall(cc)
   870  		reqs = []proto.Message{req}
   871  		resps = []proto.Message{resp}
   872  		err = e
   873  	case clientStreamRPC:
   874  		method = "/grpc.testing.TestService/StreamingInputCall"
   875  		reqs, resp, e = te.doClientStreamCall(cc)
   876  		resps = []proto.Message{resp}
   877  		err = e
   878  		isClientStream = true
   879  	case serverStreamRPC:
   880  		method = "/grpc.testing.TestService/StreamingOutputCall"
   881  		req, resps, e = te.doServerStreamCall(cc)
   882  		reqs = []proto.Message{req}
   883  		err = e
   884  		isServerStream = true
   885  	case fullDuplexStreamRPC:
   886  		method = "/grpc.testing.TestService/FullDuplexCall"
   887  		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
   888  		isClientStream = true
   889  		isServerStream = true
   890  	}
   891  	if cc.success != (err == nil) {
   892  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
   893  	}
   894  	te.cc.Close()
   895  	te.srv.GracefulStop() // Wait for the server to stop.
   896  
   897  	for {
   898  		h.mu.Lock()
   899  		if len(h.gotRPC) >= len(checkFuncs) {
   900  			h.mu.Unlock()
   901  			break
   902  		}
   903  		h.mu.Unlock()
   904  		time.Sleep(10 * time.Millisecond)
   905  	}
   906  
   907  	for {
   908  		h.mu.Lock()
   909  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
   910  			h.mu.Unlock()
   911  			break
   912  		}
   913  		h.mu.Unlock()
   914  		time.Sleep(10 * time.Millisecond)
   915  	}
   916  
   917  	expect := &expectedData{
   918  		serverAddr:     te.srvAddr,
   919  		compression:    tc.compress,
   920  		method:         method,
   921  		requests:       reqs,
   922  		responses:      resps,
   923  		err:            err,
   924  		isClientStream: isClientStream,
   925  		isServerStream: isServerStream,
   926  	}
   927  
   928  	h.mu.Lock()
   929  	checkConnStats(t, h.gotConn)
   930  	h.mu.Unlock()
   931  	checkServerStats(t, h.gotRPC, expect, checkFuncs)
   932  }
   933  
   934  func (s) TestServerStatsUnaryRPC(t *testing.T) {
   935  	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   936  		checkInHeader,
   937  		checkBegin,
   938  		checkInPayload,
   939  		checkOutHeader,
   940  		checkOutPayload,
   941  		checkOutTrailer,
   942  		checkEnd,
   943  	})
   944  }
   945  
   946  func (s) TestServerStatsUnaryRPCError(t *testing.T) {
   947  	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   948  		checkInHeader,
   949  		checkBegin,
   950  		checkInPayload,
   951  		checkOutHeader,
   952  		checkOutTrailer,
   953  		checkEnd,
   954  	})
   955  }
   956  
   957  func (s) TestServerStatsClientStreamRPC(t *testing.T) {
   958  	count := 5
   959  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   960  		checkInHeader,
   961  		checkBegin,
   962  		checkOutHeader,
   963  	}
   964  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   965  		checkInPayload,
   966  	}
   967  	for i := 0; i < count; i++ {
   968  		checkFuncs = append(checkFuncs, ioPayFuncs...)
   969  	}
   970  	checkFuncs = append(checkFuncs,
   971  		checkOutPayload,
   972  		checkOutTrailer,
   973  		checkEnd,
   974  	)
   975  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs)
   976  }
   977  
   978  func (s) TestServerStatsClientStreamRPCError(t *testing.T) {
   979  	count := 1
   980  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   981  		checkInHeader,
   982  		checkBegin,
   983  		checkOutHeader,
   984  		checkInPayload,
   985  		checkOutTrailer,
   986  		checkEnd,
   987  	})
   988  }
   989  
   990  func (s) TestServerStatsServerStreamRPC(t *testing.T) {
   991  	count := 5
   992  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   993  		checkInHeader,
   994  		checkBegin,
   995  		checkInPayload,
   996  		checkOutHeader,
   997  	}
   998  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   999  		checkOutPayload,
  1000  	}
  1001  	for i := 0; i < count; i++ {
  1002  		checkFuncs = append(checkFuncs, ioPayFuncs...)
  1003  	}
  1004  	checkFuncs = append(checkFuncs,
  1005  		checkOutTrailer,
  1006  		checkEnd,
  1007  	)
  1008  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs)
  1009  }
  1010  
  1011  func (s) TestServerStatsServerStreamRPCError(t *testing.T) {
  1012  	count := 5
  1013  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
  1014  		checkInHeader,
  1015  		checkBegin,
  1016  		checkInPayload,
  1017  		checkOutHeader,
  1018  		checkOutTrailer,
  1019  		checkEnd,
  1020  	})
  1021  }
  1022  
  1023  func (s) TestServerStatsFullDuplexRPC(t *testing.T) {
  1024  	count := 5
  1025  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
  1026  		checkInHeader,
  1027  		checkBegin,
  1028  		checkOutHeader,
  1029  	}
  1030  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
  1031  		checkInPayload,
  1032  		checkOutPayload,
  1033  	}
  1034  	for i := 0; i < count; i++ {
  1035  		checkFuncs = append(checkFuncs, ioPayFuncs...)
  1036  	}
  1037  	checkFuncs = append(checkFuncs,
  1038  		checkOutTrailer,
  1039  		checkEnd,
  1040  	)
  1041  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs)
  1042  }
  1043  
  1044  func (s) TestServerStatsFullDuplexRPCError(t *testing.T) {
  1045  	count := 5
  1046  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
  1047  		checkInHeader,
  1048  		checkBegin,
  1049  		checkOutHeader,
  1050  		checkInPayload,
  1051  		checkOutTrailer,
  1052  		checkEnd,
  1053  	})
  1054  }
  1055  
// checkFuncWithCount pairs a stats check with how many stats of the
// corresponding kind are still expected. checkClientStats runs f for each
// matching stat and decrements c, treating c <= 0 as "no more expected".
type checkFuncWithCount struct {
	f func(t *testing.T, d *gotData, e *expectedData) // check run on each matching stat
	c int                                             // expected count
}
  1060  
  1061  func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) {
  1062  	var expectLen int
  1063  	for _, v := range checkFuncs {
  1064  		expectLen += v.c
  1065  	}
  1066  	if len(got) != expectLen {
  1067  		for i, g := range got {
  1068  			t.Errorf(" - %v, %T", i, g.s)
  1069  		}
  1070  		t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
  1071  	}
  1072  
  1073  	var tagInfoInCtx *stats.RPCTagInfo
  1074  	for i := 0; i < len(got); i++ {
  1075  		if _, ok := got[i].s.(stats.RPCStats); ok {
  1076  			tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo)
  1077  			if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew {
  1078  				t.Fatalf("got context containing different tagInfo with stats %T", got[i].s)
  1079  			}
  1080  			tagInfoInCtx = tagInfoInCtxNew
  1081  		}
  1082  	}
  1083  
  1084  	for _, s := range got {
  1085  		switch s.s.(type) {
  1086  		case *stats.Begin:
  1087  			if checkFuncs[begin].c <= 0 {
  1088  				t.Fatalf("unexpected stats: %T", s.s)
  1089  			}
  1090  			checkFuncs[begin].f(t, s, expect)
  1091  			checkFuncs[begin].c--
  1092  		case *stats.OutHeader:
  1093  			if checkFuncs[outHeader].c <= 0 {
  1094  				t.Fatalf("unexpected stats: %T", s.s)
  1095  			}
  1096  			checkFuncs[outHeader].f(t, s, expect)
  1097  			checkFuncs[outHeader].c--
  1098  		case *stats.OutPayload:
  1099  			if checkFuncs[outPayload].c <= 0 {
  1100  				t.Fatalf("unexpected stats: %T", s.s)
  1101  			}
  1102  			checkFuncs[outPayload].f(t, s, expect)
  1103  			checkFuncs[outPayload].c--
  1104  		case *stats.InHeader:
  1105  			if checkFuncs[inHeader].c <= 0 {
  1106  				t.Fatalf("unexpected stats: %T", s.s)
  1107  			}
  1108  			checkFuncs[inHeader].f(t, s, expect)
  1109  			checkFuncs[inHeader].c--
  1110  		case *stats.InPayload:
  1111  			if checkFuncs[inPayload].c <= 0 {
  1112  				t.Fatalf("unexpected stats: %T", s.s)
  1113  			}
  1114  			checkFuncs[inPayload].f(t, s, expect)
  1115  			checkFuncs[inPayload].c--
  1116  		case *stats.InTrailer:
  1117  			if checkFuncs[inTrailer].c <= 0 {
  1118  				t.Fatalf("unexpected stats: %T", s.s)
  1119  			}
  1120  			checkFuncs[inTrailer].f(t, s, expect)
  1121  			checkFuncs[inTrailer].c--
  1122  		case *stats.End:
  1123  			if checkFuncs[end].c <= 0 {
  1124  				t.Fatalf("unexpected stats: %T", s.s)
  1125  			}
  1126  			checkFuncs[end].f(t, s, expect)
  1127  			checkFuncs[end].c--
  1128  		case *stats.ConnBegin:
  1129  			if checkFuncs[connBegin].c <= 0 {
  1130  				t.Fatalf("unexpected stats: %T", s.s)
  1131  			}
  1132  			checkFuncs[connBegin].f(t, s, expect)
  1133  			checkFuncs[connBegin].c--
  1134  		case *stats.ConnEnd:
  1135  			if checkFuncs[connEnd].c <= 0 {
  1136  				t.Fatalf("unexpected stats: %T", s.s)
  1137  			}
  1138  			checkFuncs[connEnd].f(t, s, expect)
  1139  			checkFuncs[connEnd].c--
  1140  		default:
  1141  			t.Fatalf("unexpected stats: %T", s.s)
  1142  		}
  1143  	}
  1144  }
  1145  
  1146  func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
  1147  	h := &statshandler{}
  1148  	te := newTest(t, tc, h, nil)
  1149  	te.startServer(&testServer{})
  1150  	defer te.tearDown()
  1151  
  1152  	var (
  1153  		reqs   []proto.Message
  1154  		resps  []proto.Message
  1155  		method string
  1156  		err    error
  1157  
  1158  		isClientStream bool
  1159  		isServerStream bool
  1160  
  1161  		req  proto.Message
  1162  		resp proto.Message
  1163  		e    error
  1164  	)
  1165  	switch cc.callType {
  1166  	case unaryRPC:
  1167  		method = "/grpc.testing.TestService/UnaryCall"
  1168  		req, resp, e = te.doUnaryCall(cc)
  1169  		reqs = []proto.Message{req}
  1170  		resps = []proto.Message{resp}
  1171  		err = e
  1172  	case clientStreamRPC:
  1173  		method = "/grpc.testing.TestService/StreamingInputCall"
  1174  		reqs, resp, e = te.doClientStreamCall(cc)
  1175  		resps = []proto.Message{resp}
  1176  		err = e
  1177  		isClientStream = true
  1178  	case serverStreamRPC:
  1179  		method = "/grpc.testing.TestService/StreamingOutputCall"
  1180  		req, resps, e = te.doServerStreamCall(cc)
  1181  		reqs = []proto.Message{req}
  1182  		err = e
  1183  		isServerStream = true
  1184  	case fullDuplexStreamRPC:
  1185  		method = "/grpc.testing.TestService/FullDuplexCall"
  1186  		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
  1187  		isClientStream = true
  1188  		isServerStream = true
  1189  	}
  1190  	if cc.success != (err == nil) {
  1191  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
  1192  	}
  1193  	te.cc.Close()
  1194  	te.srv.GracefulStop() // Wait for the server to stop.
  1195  
  1196  	lenRPCStats := 0
  1197  	for _, v := range checkFuncs {
  1198  		lenRPCStats += v.c
  1199  	}
  1200  	for {
  1201  		h.mu.Lock()
  1202  		if len(h.gotRPC) >= lenRPCStats {
  1203  			h.mu.Unlock()
  1204  			break
  1205  		}
  1206  		h.mu.Unlock()
  1207  		time.Sleep(10 * time.Millisecond)
  1208  	}
  1209  
  1210  	for {
  1211  		h.mu.Lock()
  1212  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
  1213  			h.mu.Unlock()
  1214  			break
  1215  		}
  1216  		h.mu.Unlock()
  1217  		time.Sleep(10 * time.Millisecond)
  1218  	}
  1219  
  1220  	expect := &expectedData{
  1221  		serverAddr:     te.srvAddr,
  1222  		compression:    tc.compress,
  1223  		method:         method,
  1224  		requests:       reqs,
  1225  		responses:      resps,
  1226  		failfast:       cc.failfast,
  1227  		err:            err,
  1228  		isClientStream: isClientStream,
  1229  		isServerStream: isServerStream,
  1230  	}
  1231  
  1232  	h.mu.Lock()
  1233  	checkConnStats(t, h.gotConn)
  1234  	h.mu.Unlock()
  1235  	checkClientStats(t, h.gotRPC, expect, checkFuncs)
  1236  }
  1237  
  1238  func (s) TestClientStatsUnaryRPC(t *testing.T) {
  1239  	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
  1240  		begin:      {checkBegin, 1},
  1241  		outHeader:  {checkOutHeader, 1},
  1242  		outPayload: {checkOutPayload, 1},
  1243  		inHeader:   {checkInHeader, 1},
  1244  		inPayload:  {checkInPayload, 1},
  1245  		inTrailer:  {checkInTrailer, 1},
  1246  		end:        {checkEnd, 1},
  1247  	})
  1248  }
  1249  
  1250  func (s) TestClientStatsUnaryRPCError(t *testing.T) {
  1251  	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
  1252  		begin:      {checkBegin, 1},
  1253  		outHeader:  {checkOutHeader, 1},
  1254  		outPayload: {checkOutPayload, 1},
  1255  		inHeader:   {checkInHeader, 1},
  1256  		inTrailer:  {checkInTrailer, 1},
  1257  		end:        {checkEnd, 1},
  1258  	})
  1259  }
  1260  
  1261  func (s) TestClientStatsClientStreamRPC(t *testing.T) {
  1262  	count := 5
  1263  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
  1264  		begin:      {checkBegin, 1},
  1265  		outHeader:  {checkOutHeader, 1},
  1266  		inHeader:   {checkInHeader, 1},
  1267  		outPayload: {checkOutPayload, count},
  1268  		inTrailer:  {checkInTrailer, 1},
  1269  		inPayload:  {checkInPayload, 1},
  1270  		end:        {checkEnd, 1},
  1271  	})
  1272  }
  1273  
  1274  func (s) TestClientStatsClientStreamRPCError(t *testing.T) {
  1275  	count := 1
  1276  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
  1277  		begin:      {checkBegin, 1},
  1278  		outHeader:  {checkOutHeader, 1},
  1279  		inHeader:   {checkInHeader, 1},
  1280  		outPayload: {checkOutPayload, 1},
  1281  		inTrailer:  {checkInTrailer, 1},
  1282  		end:        {checkEnd, 1},
  1283  	})
  1284  }
  1285  
  1286  func (s) TestClientStatsServerStreamRPC(t *testing.T) {
  1287  	count := 5
  1288  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
  1289  		begin:      {checkBegin, 1},
  1290  		outHeader:  {checkOutHeader, 1},
  1291  		outPayload: {checkOutPayload, 1},
  1292  		inHeader:   {checkInHeader, 1},
  1293  		inPayload:  {checkInPayload, count},
  1294  		inTrailer:  {checkInTrailer, 1},
  1295  		end:        {checkEnd, 1},
  1296  	})
  1297  }
  1298  
  1299  func (s) TestClientStatsServerStreamRPCError(t *testing.T) {
  1300  	count := 5
  1301  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
  1302  		begin:      {checkBegin, 1},
  1303  		outHeader:  {checkOutHeader, 1},
  1304  		outPayload: {checkOutPayload, 1},
  1305  		inHeader:   {checkInHeader, 1},
  1306  		inTrailer:  {checkInTrailer, 1},
  1307  		end:        {checkEnd, 1},
  1308  	})
  1309  }
  1310  
  1311  func (s) TestClientStatsFullDuplexRPC(t *testing.T) {
  1312  	count := 5
  1313  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
  1314  		begin:      {checkBegin, 1},
  1315  		outHeader:  {checkOutHeader, 1},
  1316  		outPayload: {checkOutPayload, count},
  1317  		inHeader:   {checkInHeader, 1},
  1318  		inPayload:  {checkInPayload, count},
  1319  		inTrailer:  {checkInTrailer, 1},
  1320  		end:        {checkEnd, 1},
  1321  	})
  1322  }
  1323  
  1324  func (s) TestClientStatsFullDuplexRPCError(t *testing.T) {
  1325  	count := 5
  1326  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
  1327  		begin:      {checkBegin, 1},
  1328  		outHeader:  {checkOutHeader, 1},
  1329  		outPayload: {checkOutPayload, 1},
  1330  		inHeader:   {checkInHeader, 1},
  1331  		inTrailer:  {checkInTrailer, 1},
  1332  		end:        {checkEnd, 1},
  1333  	})
  1334  }
  1335  
  1336  func (s) TestTags(t *testing.T) {
  1337  	b := []byte{5, 2, 4, 3, 1}
  1338  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1339  	defer cancel()
  1340  	ctx := stats.SetTags(tCtx, b)
  1341  	if tg := stats.OutgoingTags(ctx); !reflect.DeepEqual(tg, b) {
  1342  		t.Errorf("OutgoingTags(%v) = %v; want %v", ctx, tg, b)
  1343  	}
  1344  	if tg := stats.Tags(ctx); tg != nil {
  1345  		t.Errorf("Tags(%v) = %v; want nil", ctx, tg)
  1346  	}
  1347  
  1348  	ctx = stats.SetIncomingTags(tCtx, b)
  1349  	if tg := stats.Tags(ctx); !reflect.DeepEqual(tg, b) {
  1350  		t.Errorf("Tags(%v) = %v; want %v", ctx, tg, b)
  1351  	}
  1352  	if tg := stats.OutgoingTags(ctx); tg != nil {
  1353  		t.Errorf("OutgoingTags(%v) = %v; want nil", ctx, tg)
  1354  	}
  1355  }
  1356  
  1357  func (s) TestTrace(t *testing.T) {
  1358  	b := []byte{5, 2, 4, 3, 1}
  1359  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1360  	defer cancel()
  1361  	ctx := stats.SetTrace(tCtx, b)
  1362  	if tr := stats.OutgoingTrace(ctx); !reflect.DeepEqual(tr, b) {
  1363  		t.Errorf("OutgoingTrace(%v) = %v; want %v", ctx, tr, b)
  1364  	}
  1365  	if tr := stats.Trace(ctx); tr != nil {
  1366  		t.Errorf("Trace(%v) = %v; want nil", ctx, tr)
  1367  	}
  1368  
  1369  	ctx = stats.SetIncomingTrace(tCtx, b)
  1370  	if tr := stats.Trace(ctx); !reflect.DeepEqual(tr, b) {
  1371  		t.Errorf("Trace(%v) = %v; want %v", ctx, tr, b)
  1372  	}
  1373  	if tr := stats.OutgoingTrace(ctx); tr != nil {
  1374  		t.Errorf("OutgoingTrace(%v) = %v; want nil", ctx, tr)
  1375  	}
  1376  }