github.com/shuguocloud/go-zero@v1.3.0/zrpc/internal/clientinterceptors/tracinginterceptor_test.go (about)

     1  package clientinterceptors
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"sync"
     8  	"sync/atomic"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/shuguocloud/go-zero/core/trace"
    13  	"google.golang.org/grpc"
    14  	"google.golang.org/grpc/codes"
    15  	"google.golang.org/grpc/metadata"
    16  	"google.golang.org/grpc/status"
    17  )
    18  
    19  func TestOpenTracingInterceptor(t *testing.T) {
    20  	trace.StartAgent(trace.Config{
    21  		Name:     "go-zero-test",
    22  		Endpoint: "http://localhost:14268/api/traces",
    23  		Batcher:  "jaeger",
    24  		Sampler:  1.0,
    25  	})
    26  
    27  	cc := new(grpc.ClientConn)
    28  	ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{})
    29  	err := UnaryTracingInterceptor(ctx, "/ListUser", nil, nil, cc,
    30  		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
    31  			opts ...grpc.CallOption) error {
    32  			return nil
    33  		})
    34  	assert.Nil(t, err)
    35  }
    36  
    37  func TestUnaryTracingInterceptor(t *testing.T) {
    38  	var run int32
    39  	var wg sync.WaitGroup
    40  	wg.Add(1)
    41  	cc := new(grpc.ClientConn)
    42  	err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
    43  		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
    44  			opts ...grpc.CallOption) error {
    45  			defer wg.Done()
    46  			atomic.AddInt32(&run, 1)
    47  			return nil
    48  		})
    49  	wg.Wait()
    50  	assert.Nil(t, err)
    51  	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
    52  }
    53  
    54  func TestUnaryTracingInterceptor_WithError(t *testing.T) {
    55  	var run int32
    56  	var wg sync.WaitGroup
    57  	wg.Add(1)
    58  	cc := new(grpc.ClientConn)
    59  	err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
    60  		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
    61  			opts ...grpc.CallOption) error {
    62  			defer wg.Done()
    63  			atomic.AddInt32(&run, 1)
    64  			return errors.New("dummy")
    65  		})
    66  	wg.Wait()
    67  	assert.NotNil(t, err)
    68  	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
    69  }
    70  
    71  func TestStreamTracingInterceptor(t *testing.T) {
    72  	var run int32
    73  	var wg sync.WaitGroup
    74  	wg.Add(1)
    75  	cc := new(grpc.ClientConn)
    76  	_, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
    77  		func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
    78  			opts ...grpc.CallOption) (grpc.ClientStream, error) {
    79  			defer wg.Done()
    80  			atomic.AddInt32(&run, 1)
    81  			return nil, nil
    82  		})
    83  	wg.Wait()
    84  	assert.Nil(t, err)
    85  	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
    86  }
    87  
    88  func TestStreamTracingInterceptor_FinishWithNormalError(t *testing.T) {
    89  	var wg sync.WaitGroup
    90  	wg.Add(1)
    91  	cc := new(grpc.ClientConn)
    92  	ctx, cancel := context.WithCancel(context.Background())
    93  	stream, err := StreamTracingInterceptor(ctx, nil, cc, "/foo",
    94  		func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
    95  			opts ...grpc.CallOption) (grpc.ClientStream, error) {
    96  			defer wg.Done()
    97  			return nil, nil
    98  		})
    99  	wg.Wait()
   100  	assert.Nil(t, err)
   101  
   102  	cancel()
   103  	cs := stream.(*clientStream)
   104  	<-cs.eventsDone
   105  }
   106  
   107  func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
   108  	tests := []struct {
   109  		name  string
   110  		event streamEventType
   111  		err   error
   112  	}{
   113  		{
   114  			name:  "receive event",
   115  			event: receiveEndEvent,
   116  			err:   status.Error(codes.DataLoss, "dummy"),
   117  		},
   118  		{
   119  			name:  "error event",
   120  			event: errorEvent,
   121  			err:   status.Error(codes.DataLoss, "dummy"),
   122  		},
   123  	}
   124  
   125  	for _, test := range tests {
   126  		test := test
   127  		t.Run(test.name, func(t *testing.T) {
   128  			t.Parallel()
   129  
   130  			var wg sync.WaitGroup
   131  			wg.Add(1)
   132  			cc := new(grpc.ClientConn)
   133  			stream, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
   134  				func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
   135  					opts ...grpc.CallOption) (grpc.ClientStream, error) {
   136  					defer wg.Done()
   137  					return &mockedClientStream{
   138  						err: errors.New("dummy"),
   139  					}, nil
   140  				})
   141  			wg.Wait()
   142  			assert.Nil(t, err)
   143  
   144  			cs := stream.(*clientStream)
   145  			cs.sendStreamEvent(test.event, status.Error(codes.DataLoss, "dummy"))
   146  			<-cs.eventsDone
   147  			cs.sendStreamEvent(test.event, test.err)
   148  			assert.NotNil(t, cs.CloseSend())
   149  		})
   150  	}
   151  }
   152  
   153  func TestStreamTracingInterceptor_WithError(t *testing.T) {
   154  	tests := []struct {
   155  		name string
   156  		err  error
   157  	}{
   158  		{
   159  			name: "normal error",
   160  			err:  errors.New("dummy"),
   161  		},
   162  		{
   163  			name: "grpc error",
   164  			err:  status.Error(codes.DataLoss, "dummy"),
   165  		},
   166  	}
   167  
   168  	for _, test := range tests {
   169  		test := test
   170  		t.Run(test.name, func(t *testing.T) {
   171  			t.Parallel()
   172  
   173  			var run int32
   174  			var wg sync.WaitGroup
   175  			wg.Add(1)
   176  			cc := new(grpc.ClientConn)
   177  			_, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
   178  				func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
   179  					opts ...grpc.CallOption) (grpc.ClientStream, error) {
   180  					defer wg.Done()
   181  					atomic.AddInt32(&run, 1)
   182  					return new(mockedClientStream), test.err
   183  				})
   184  			wg.Wait()
   185  			assert.NotNil(t, err)
   186  			assert.Equal(t, int32(1), atomic.LoadInt32(&run))
   187  		})
   188  	}
   189  }
   190  
   191  func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) {
   192  	var run int32
   193  	var wg sync.WaitGroup
   194  	wg.Add(1)
   195  	cc := new(grpc.ClientConn)
   196  	err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc,
   197  		func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
   198  			opts ...grpc.CallOption) error {
   199  			defer wg.Done()
   200  			atomic.AddInt32(&run, 1)
   201  			return nil
   202  		})
   203  	wg.Wait()
   204  	assert.Nil(t, err)
   205  	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
   206  }
   207  
   208  func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
   209  	var run int32
   210  	var wg sync.WaitGroup
   211  	wg.Add(1)
   212  	cc := new(grpc.ClientConn)
   213  	_, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo",
   214  		func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
   215  			opts ...grpc.CallOption) (grpc.ClientStream, error) {
   216  			defer wg.Done()
   217  			atomic.AddInt32(&run, 1)
   218  			return nil, nil
   219  		})
   220  	wg.Wait()
   221  	assert.Nil(t, err)
   222  	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
   223  }
   224  
   225  func TestClientStream_RecvMsg(t *testing.T) {
   226  	tests := []struct {
   227  		name          string
   228  		serverStreams bool
   229  		err           error
   230  	}{
   231  		{
   232  			name: "nil error",
   233  		},
   234  		{
   235  			name: "EOF",
   236  			err:  io.EOF,
   237  		},
   238  		{
   239  			name: "dummy error",
   240  			err:  errors.New("dummy"),
   241  		},
   242  		{
   243  			name:          "server streams",
   244  			serverStreams: true,
   245  		},
   246  	}
   247  
   248  	for _, test := range tests {
   249  		test := test
   250  		t.Run(test.name, func(t *testing.T) {
   251  			t.Parallel()
   252  			desc := new(grpc.StreamDesc)
   253  			desc.ServerStreams = test.serverStreams
   254  			stream := wrapClientStream(context.Background(), &mockedClientStream{
   255  				md:  nil,
   256  				err: test.err,
   257  			}, desc)
   258  			assert.Equal(t, test.err, stream.RecvMsg(nil))
   259  		})
   260  	}
   261  }
   262  
   263  func TestClientStream_Header(t *testing.T) {
   264  	tests := []struct {
   265  		name string
   266  		err  error
   267  	}{
   268  		{
   269  			name: "nil error",
   270  		},
   271  		{
   272  			name: "with error",
   273  			err:  errors.New("dummy"),
   274  		},
   275  	}
   276  
   277  	for _, test := range tests {
   278  		test := test
   279  		t.Run(test.name, func(t *testing.T) {
   280  			t.Parallel()
   281  			desc := new(grpc.StreamDesc)
   282  			stream := wrapClientStream(context.Background(), &mockedClientStream{
   283  				md:  metadata.MD{},
   284  				err: test.err,
   285  			}, desc)
   286  			_, err := stream.Header()
   287  			assert.Equal(t, test.err, err)
   288  		})
   289  	}
   290  }
   291  
   292  func TestClientStream_SendMsg(t *testing.T) {
   293  	tests := []struct {
   294  		name string
   295  		err  error
   296  	}{
   297  		{
   298  			name: "nil error",
   299  		},
   300  		{
   301  			name: "with error",
   302  			err:  errors.New("dummy"),
   303  		},
   304  	}
   305  
   306  	for _, test := range tests {
   307  		test := test
   308  		t.Run(test.name, func(t *testing.T) {
   309  			t.Parallel()
   310  			desc := new(grpc.StreamDesc)
   311  			stream := wrapClientStream(context.Background(), &mockedClientStream{
   312  				md:  metadata.MD{},
   313  				err: test.err,
   314  			}, desc)
   315  			assert.Equal(t, test.err, stream.SendMsg(nil))
   316  		})
   317  	}
   318  }
   319  
   320  type mockedClientStream struct {
   321  	md  metadata.MD
   322  	err error
   323  }
   324  
   325  func (m *mockedClientStream) Header() (metadata.MD, error) {
   326  	return m.md, m.err
   327  }
   328  
   329  func (m *mockedClientStream) Trailer() metadata.MD {
   330  	panic("implement me")
   331  }
   332  
   333  func (m *mockedClientStream) CloseSend() error {
   334  	return m.err
   335  }
   336  
   337  func (m *mockedClientStream) Context() context.Context {
   338  	return context.Background()
   339  }
   340  
   341  func (m *mockedClientStream) SendMsg(v interface{}) error {
   342  	return m.err
   343  }
   344  
   345  func (m *mockedClientStream) RecvMsg(v interface{}) error {
   346  	return m.err
   347  }