github.com/lingyao2333/mo-zero@v1.4.1/zrpc/internal/serverinterceptors/tracinginterceptor_test.go (about)

     1  package serverinterceptors
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"sync"
     8  	"sync/atomic"
     9  	"testing"
    10  
    11  	"github.com/lingyao2333/mo-zero/core/trace"
    12  	"github.com/stretchr/testify/assert"
    13  	"google.golang.org/grpc"
    14  	"google.golang.org/grpc/codes"
    15  	"google.golang.org/grpc/metadata"
    16  	"google.golang.org/grpc/status"
    17  )
    18  
    19  func TestUnaryOpenTracingInterceptor_Disable(t *testing.T) {
    20  	_, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
    21  		FullMethod: "/",
    22  	}, func(ctx context.Context, req interface{}) (interface{}, error) {
    23  		return nil, nil
    24  	})
    25  	assert.Nil(t, err)
    26  }
    27  
    28  func TestUnaryOpenTracingInterceptor_Enabled(t *testing.T) {
    29  	trace.StartAgent(trace.Config{
    30  		Name:     "go-zero-test",
    31  		Endpoint: "http://localhost:14268/api/traces",
    32  		Batcher:  "jaeger",
    33  		Sampler:  1.0,
    34  	})
    35  	defer trace.StopAgent()
    36  
    37  	_, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
    38  		FullMethod: "/package.TestService.GetUser",
    39  	}, func(ctx context.Context, req interface{}) (interface{}, error) {
    40  		return nil, nil
    41  	})
    42  	assert.Nil(t, err)
    43  }
    44  
    45  func TestUnaryTracingInterceptor(t *testing.T) {
    46  	var run int32
    47  	var wg sync.WaitGroup
    48  	wg.Add(1)
    49  	_, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{
    50  		FullMethod: "/",
    51  	}, func(ctx context.Context, req interface{}) (interface{}, error) {
    52  		defer wg.Done()
    53  		atomic.AddInt32(&run, 1)
    54  		return nil, nil
    55  	})
    56  	wg.Wait()
    57  	assert.Nil(t, err)
    58  	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
    59  }
    60  
    61  func TestUnaryTracingInterceptor_WithError(t *testing.T) {
    62  	tests := []struct {
    63  		name string
    64  		err  error
    65  	}{
    66  		{
    67  			name: "normal error",
    68  			err:  errors.New("dummy"),
    69  		},
    70  		{
    71  			name: "grpc error",
    72  			err:  status.Error(codes.DataLoss, "dummy"),
    73  		},
    74  	}
    75  
    76  	for _, test := range tests {
    77  		test := test
    78  		t.Run(test.name, func(t *testing.T) {
    79  			t.Parallel()
    80  
    81  			var wg sync.WaitGroup
    82  			wg.Add(1)
    83  			var md metadata.MD
    84  			ctx := metadata.NewIncomingContext(context.Background(), md)
    85  			_, err := UnaryTracingInterceptor(ctx, nil, &grpc.UnaryServerInfo{
    86  				FullMethod: "/",
    87  			}, func(ctx context.Context, req interface{}) (interface{}, error) {
    88  				defer wg.Done()
    89  				return nil, test.err
    90  			})
    91  			wg.Wait()
    92  			assert.Equal(t, test.err, err)
    93  		})
    94  	}
    95  }
    96  
    97  func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) {
    98  	var run int32
    99  	var wg sync.WaitGroup
   100  	wg.Add(1)
   101  	var md metadata.MD
   102  	ctx := metadata.NewIncomingContext(context.Background(), md)
   103  	stream := mockedServerStream{ctx: ctx}
   104  	err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
   105  		FullMethod: "/foo",
   106  	}, func(svr interface{}, stream grpc.ServerStream) error {
   107  		defer wg.Done()
   108  		atomic.AddInt32(&run, 1)
   109  		return nil
   110  	})
   111  	wg.Wait()
   112  	assert.Nil(t, err)
   113  	assert.Equal(t, int32(1), atomic.LoadInt32(&run))
   114  }
   115  
   116  func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) {
   117  	tests := []struct {
   118  		name string
   119  		err  error
   120  	}{
   121  		{
   122  			name: "receive event",
   123  			err:  status.Error(codes.DataLoss, "dummy"),
   124  		},
   125  		{
   126  			name: "error event",
   127  			err:  status.Error(codes.DataLoss, "dummy"),
   128  		},
   129  	}
   130  
   131  	for _, test := range tests {
   132  		test := test
   133  		t.Run(test.name, func(t *testing.T) {
   134  			t.Parallel()
   135  
   136  			var wg sync.WaitGroup
   137  			wg.Add(1)
   138  			var md metadata.MD
   139  			ctx := metadata.NewIncomingContext(context.Background(), md)
   140  			stream := mockedServerStream{ctx: ctx}
   141  			err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
   142  				FullMethod: "/foo",
   143  			}, func(svr interface{}, stream grpc.ServerStream) error {
   144  				defer wg.Done()
   145  				return test.err
   146  			})
   147  			wg.Wait()
   148  			assert.Equal(t, test.err, err)
   149  		})
   150  	}
   151  }
   152  
   153  func TestStreamTracingInterceptor_WithError(t *testing.T) {
   154  	tests := []struct {
   155  		name string
   156  		err  error
   157  	}{
   158  		{
   159  			name: "normal error",
   160  			err:  errors.New("dummy"),
   161  		},
   162  		{
   163  			name: "grpc error",
   164  			err:  status.Error(codes.DataLoss, "dummy"),
   165  		},
   166  	}
   167  
   168  	for _, test := range tests {
   169  		test := test
   170  		t.Run(test.name, func(t *testing.T) {
   171  			t.Parallel()
   172  
   173  			var wg sync.WaitGroup
   174  			wg.Add(1)
   175  			var md metadata.MD
   176  			ctx := metadata.NewIncomingContext(context.Background(), md)
   177  			stream := mockedServerStream{ctx: ctx}
   178  			err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{
   179  				FullMethod: "/foo",
   180  			}, func(svr interface{}, stream grpc.ServerStream) error {
   181  				defer wg.Done()
   182  				return test.err
   183  			})
   184  			wg.Wait()
   185  			assert.Equal(t, test.err, err)
   186  		})
   187  	}
   188  }
   189  
   190  func TestClientStream_RecvMsg(t *testing.T) {
   191  	tests := []struct {
   192  		name string
   193  		err  error
   194  	}{
   195  		{
   196  			name: "nil error",
   197  		},
   198  		{
   199  			name: "EOF",
   200  			err:  io.EOF,
   201  		},
   202  		{
   203  			name: "dummy error",
   204  			err:  errors.New("dummy"),
   205  		},
   206  	}
   207  
   208  	for _, test := range tests {
   209  		test := test
   210  		t.Run(test.name, func(t *testing.T) {
   211  			t.Parallel()
   212  			stream := wrapServerStream(context.Background(), &mockedServerStream{
   213  				ctx: context.Background(),
   214  				err: test.err,
   215  			})
   216  			assert.Equal(t, test.err, stream.RecvMsg(nil))
   217  		})
   218  	}
   219  }
   220  
   221  func TestServerStream_SendMsg(t *testing.T) {
   222  	tests := []struct {
   223  		name string
   224  		err  error
   225  	}{
   226  		{
   227  			name: "nil error",
   228  		},
   229  		{
   230  			name: "with error",
   231  			err:  errors.New("dummy"),
   232  		},
   233  	}
   234  
   235  	for _, test := range tests {
   236  		test := test
   237  		t.Run(test.name, func(t *testing.T) {
   238  			t.Parallel()
   239  			stream := wrapServerStream(context.Background(), &mockedServerStream{
   240  				ctx: context.Background(),
   241  				err: test.err,
   242  			})
   243  			assert.Equal(t, test.err, stream.SendMsg(nil))
   244  		})
   245  	}
   246  }
   247  
   248  type mockedServerStream struct {
   249  	ctx context.Context
   250  	err error
   251  }
   252  
   253  func (m *mockedServerStream) SetHeader(md metadata.MD) error {
   254  	panic("implement me")
   255  }
   256  
   257  func (m *mockedServerStream) SendHeader(md metadata.MD) error {
   258  	panic("implement me")
   259  }
   260  
   261  func (m *mockedServerStream) SetTrailer(md metadata.MD) {
   262  	panic("implement me")
   263  }
   264  
   265  func (m *mockedServerStream) Context() context.Context {
   266  	if m.ctx == nil {
   267  		return context.Background()
   268  	}
   269  
   270  	return m.ctx
   271  }
   272  
   273  func (m *mockedServerStream) SendMsg(v interface{}) error {
   274  	return m.err
   275  }
   276  
   277  func (m *mockedServerStream) RecvMsg(v interface{}) error {
   278  	return m.err
   279  }