github.com/shuguocloud/go-zero@v1.3.0/zrpc/internal/serverinterceptors/tracinginterceptor_test.go (about)

     1  package serverinterceptors
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"sync"
     8  	"sync/atomic"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/shuguocloud/go-zero/core/trace"
    13  	"google.golang.org/grpc"
    14  	"google.golang.org/grpc/codes"
    15  	"google.golang.org/grpc/metadata"
    16  	"google.golang.org/grpc/status"
    17  )
    18  
    19  func TestUnaryOpenTracingInterceptor_Disable(t *testing.T) {
    20  	_, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
    21  		FullMethod: "/",
    22  	}, func(ctx context.Context, req interface{}) (interface{}, error) {
    23  		return nil, nil
    24  	})
    25  	assert.Nil(t, err)
    26  }
    27  
    28  func TestUnaryOpenTracingInterceptor_Enabled(t *testing.T) {
    29  	trace.StartAgent(trace.Config{
    30  		Name:     "go-zero-test",
    31  		Endpoint: "http://localhost:14268/api/traces",
    32  		Batcher:  "jaeger",
    33  		Sampler:  1.0,
    34  	})
    35  	_, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
    36  		FullMethod: "/package.TestService.GetUser",
    37  	}, func(ctx context.Context, req interface{}) (interface{}, error) {
    38  		return nil, nil
    39  	})
    40  	assert.Nil(t, err)
    41  }
    42  
    43  func TestUnaryTracingInterceptor(t *testing.T) {
    44  	var run int32
    45  	var wg sync.WaitGroup
    46  	wg.Add(1)
    47  	_, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
    48  		FullMethod: "/",
    49  	}, func(ctx context.Context, req interface{}) (interface{}, error) {
    50  		defer wg.Done()
    51  		atomic.AddInt32(&run, 1)
    52  		return nil, nil
    53  	})
    54  	wg.Wait()
    55  	assert.Nil(t, err)
    56  	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
    57  }
    58  
    59  func TestUnaryTracingInterceptor_WithError(t *testing.T) {
    60  	tests := []struct {
    61  		name string
    62  		err  error
    63  	}{
    64  		{
    65  			name: "normal error",
    66  			err:  errors.New("dummy"),
    67  		},
    68  		{
    69  			name: "grpc error",
    70  			err:  status.Error(codes.DataLoss, "dummy"),
    71  		},
    72  	}
    73  
    74  	for _, test := range tests {
    75  		test := test
    76  		t.Run(test.name, func(t *testing.T) {
    77  			t.Parallel()
    78  
    79  			var wg sync.WaitGroup
    80  			wg.Add(1)
    81  			var md metadata.MD
    82  			ctx := metadata.NewIncomingContext(context.Background(), md)
    83  			_, err := UnaryTracingInterceptor(ctx, nil, &grpc.UnaryServerInfo{
    84  				FullMethod: "/",
    85  			}, func(ctx context.Context, req interface{}) (interface{}, error) {
    86  				defer wg.Done()
    87  				return nil, test.err
    88  			})
    89  			wg.Wait()
    90  			assert.Equal(t, test.err, err)
    91  		})
    92  	}
    93  }
    94  
    95  func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
    96  	var run int32
    97  	var wg sync.WaitGroup
    98  	wg.Add(1)
    99  	var md metadata.MD
   100  	ctx := metadata.NewIncomingContext(context.Background(), md)
   101  	stream := mockedServerStream{ctx: ctx}
   102  	err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
   103  		FullMethod: "/foo",
   104  	}, func(srv interface{}, stream grpc.ServerStream) error {
   105  		defer wg.Done()
   106  		atomic.AddInt32(&run, 1)
   107  		return nil
   108  	})
   109  	wg.Wait()
   110  	assert.Nil(t, err)
   111  	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
   112  }
   113  
   114  func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
   115  	tests := []struct {
   116  		name string
   117  		err  error
   118  	}{
   119  		{
   120  			name: "receive event",
   121  			err:  status.Error(codes.DataLoss, "dummy"),
   122  		},
   123  		{
   124  			name: "error event",
   125  			err:  status.Error(codes.DataLoss, "dummy"),
   126  		},
   127  	}
   128  
   129  	for _, test := range tests {
   130  		test := test
   131  		t.Run(test.name, func(t *testing.T) {
   132  			t.Parallel()
   133  
   134  			var wg sync.WaitGroup
   135  			wg.Add(1)
   136  			var md metadata.MD
   137  			ctx := metadata.NewIncomingContext(context.Background(), md)
   138  			stream := mockedServerStream{ctx: ctx}
   139  			err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
   140  				FullMethod: "/foo",
   141  			}, func(srv interface{}, stream grpc.ServerStream) error {
   142  				defer wg.Done()
   143  				return test.err
   144  			})
   145  			wg.Wait()
   146  			assert.Equal(t, test.err, err)
   147  		})
   148  	}
   149  }
   150  
   151  func TestStreamTracingInterceptor_WithError(t *testing.T) {
   152  	tests := []struct {
   153  		name string
   154  		err  error
   155  	}{
   156  		{
   157  			name: "normal error",
   158  			err:  errors.New("dummy"),
   159  		},
   160  		{
   161  			name: "grpc error",
   162  			err:  status.Error(codes.DataLoss, "dummy"),
   163  		},
   164  	}
   165  
   166  	for _, test := range tests {
   167  		test := test
   168  		t.Run(test.name, func(t *testing.T) {
   169  			t.Parallel()
   170  
   171  			var wg sync.WaitGroup
   172  			wg.Add(1)
   173  			var md metadata.MD
   174  			ctx := metadata.NewIncomingContext(context.Background(), md)
   175  			stream := mockedServerStream{ctx: ctx}
   176  			err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
   177  				FullMethod: "/foo",
   178  			}, func(srv interface{}, stream grpc.ServerStream) error {
   179  				defer wg.Done()
   180  				return test.err
   181  			})
   182  			wg.Wait()
   183  			assert.Equal(t, test.err, err)
   184  		})
   185  	}
   186  }
   187  
   188  func TestClientStream_RecvMsg(t *testing.T) {
   189  	tests := []struct {
   190  		name string
   191  		err  error
   192  	}{
   193  		{
   194  			name: "nil error",
   195  		},
   196  		{
   197  			name: "EOF",
   198  			err:  io.EOF,
   199  		},
   200  		{
   201  			name: "dummy error",
   202  			err:  errors.New("dummy"),
   203  		},
   204  	}
   205  
   206  	for _, test := range tests {
   207  		test := test
   208  		t.Run(test.name, func(t *testing.T) {
   209  			t.Parallel()
   210  			stream := wrapServerStream(context.Background(), &mockedServerStream{
   211  				ctx: context.Background(),
   212  				err: test.err,
   213  			})
   214  			assert.Equal(t, test.err, stream.RecvMsg(nil))
   215  		})
   216  	}
   217  }
   218  
   219  func TestServerStream_SendMsg(t *testing.T) {
   220  	tests := []struct {
   221  		name string
   222  		err  error
   223  	}{
   224  		{
   225  			name: "nil error",
   226  		},
   227  		{
   228  			name: "with error",
   229  			err:  errors.New("dummy"),
   230  		},
   231  	}
   232  
   233  	for _, test := range tests {
   234  		test := test
   235  		t.Run(test.name, func(t *testing.T) {
   236  			t.Parallel()
   237  			stream := wrapServerStream(context.Background(), &mockedServerStream{
   238  				ctx: context.Background(),
   239  				err: test.err,
   240  			})
   241  			assert.Equal(t, test.err, stream.SendMsg(nil))
   242  		})
   243  	}
   244  }
   245  
// mockedServerStream is a test stub that the tests in this file pass where a
// grpc.ServerStream is expected. Only Context, SendMsg and RecvMsg are usable;
// the header and trailer methods panic.
type mockedServerStream struct {
	// ctx is the value returned by Context; when nil, Context falls back to
	// context.Background().
	ctx context.Context
	// err is returned unchanged by both SendMsg and RecvMsg.
	err error
}
   250  
   251  func (m *mockedServerStream) SetHeader(md metadata.MD) error {
   252  	panic("implement me")
   253  }
   254  
   255  func (m *mockedServerStream) SendHeader(md metadata.MD) error {
   256  	panic("implement me")
   257  }
   258  
   259  func (m *mockedServerStream) SetTrailer(md metadata.MD) {
   260  	panic("implement me")
   261  }
   262  
   263  func (m *mockedServerStream) Context() context.Context {
   264  	if m.ctx == nil {
   265  		return context.Background()
   266  	}
   267  
   268  	return m.ctx
   269  }
   270  
   271  func (m *mockedServerStream) SendMsg(v interface{}) error {
   272  	return m.err
   273  }
   274  
   275  func (m *mockedServerStream) RecvMsg(v interface{}) error {
   276  	return m.err
   277  }