github.com/lingyao2333/mo-zero@v1.4.1/zrpc/internal/serverinterceptors/tracinginterceptor_test.go (about) 1 package serverinterceptors 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "sync" 8 "sync/atomic" 9 "testing" 10 11 "github.com/lingyao2333/mo-zero/core/trace" 12 "github.com/stretchr/testify/assert" 13 "google.golang.org/grpc" 14 "google.golang.org/grpc/codes" 15 "google.golang.org/grpc/metadata" 16 "google.golang.org/grpc/status" 17 ) 18 19 func TestUnaryOpenTracingInterceptor_Disable(t *testing.T) { 20 _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{ 21 FullMethod: "/", 22 }, func(ctx context.Context, req interface{}) (interface{}, error) { 23 return nil, nil 24 }) 25 assert.Nil(t, err) 26 } 27 28 func TestUnaryOpenTracingInterceptor_Enabled(t *testing.T) { 29 trace.StartAgent(trace.Config{ 30 Name: "go-zero-test", 31 Endpoint: "http://localhost:14268/api/traces", 32 Batcher: "jaeger", 33 Sampler: 1.0, 34 }) 35 defer trace.StopAgent() 36 37 _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{ 38 FullMethod: "/package.TestService.GetUser", 39 }, func(ctx context.Context, req interface{}) (interface{}, error) { 40 return nil, nil 41 }) 42 assert.Nil(t, err) 43 } 44 45 func TestUnaryTracingInterceptor(t *testing.T) { 46 var run int32 47 var wg sync.WaitGroup 48 wg.Add(1) 49 _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{ 50 FullMethod: "/", 51 }, func(ctx context.Context, req interface{}) (interface{}, error) { 52 defer wg.Done() 53 atomic.AddInt32(&run, 1) 54 return nil, nil 55 }) 56 wg.Wait() 57 assert.Nil(t, err) 58 assert.Equal(t, int32(1), atomic.LoadInt32(&run)) 59 } 60 61 func TestUnaryTracingInterceptor_WithError(t *testing.T) { 62 tests := []struct { 63 name string 64 err error 65 }{ 66 { 67 name: "normal error", 68 err: errors.New("dummy"), 69 }, 70 { 71 name: "grpc error", 72 err: status.Error(codes.DataLoss, "dummy"), 73 }, 74 } 75 76 for _, test := range tests { 77 test := test 78 t.Run(test.name, func(t *testing.T) { 79 t.Parallel() 80 81 var wg sync.WaitGroup 82 wg.Add(1) 83 var md metadata.MD 84 ctx := metadata.NewIncomingContext(context.Background(), md) 85 _, err := UnaryTracingInterceptor(ctx, nil, &grpc.UnaryServerInfo{ 86 FullMethod: "/", 87 }, func(ctx context.Context, req interface{}) (interface{}, error) { 88 defer wg.Done() 89 return nil, test.err 90 }) 91 wg.Wait() 92 assert.Equal(t, test.err, err) 93 }) 94 } 95 } 96 97 func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) { 98 var run int32 99 var wg sync.WaitGroup 100 wg.Add(1) 101 var md metadata.MD 102 ctx := metadata.NewIncomingContext(context.Background(), md) 103 stream := mockedServerStream{ctx: ctx} 104 err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{ 105 FullMethod: "/foo", 106 }, func(svr interface{}, stream grpc.ServerStream) error { 107 defer wg.Done() 108 atomic.AddInt32(&run, 1) 109 return nil 110 }) 111 wg.Wait() 112 assert.Nil(t, err) 113 assert.Equal(t, int32(1), atomic.LoadInt32(&run)) 114 } 115 116 func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) { 117 tests := []struct { 118 name string 119 err error 120 }{ 121 { 122 name: "receive event", 123 err: status.Error(codes.DataLoss, "dummy"), 124 }, 125 { 126 name: "error event", 127 err: status.Error(codes.DataLoss, "dummy"), 128 }, 129 } 130 131 for _, test := range tests { 132 test := test 133 t.Run(test.name, func(t *testing.T) { 134 t.Parallel() 135 136 var wg sync.WaitGroup 137 wg.Add(1) 138 var md metadata.MD 139 ctx := metadata.NewIncomingContext(context.Background(), md) 140 stream := mockedServerStream{ctx: ctx} 141 err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{ 142 FullMethod: "/foo", 143 }, func(svr interface{}, stream grpc.ServerStream) error { 144 defer wg.Done() 145 return test.err 146 }) 147 wg.Wait() 148 assert.Equal(t, test.err, err) 149 }) 150 } 151 } 152 153 func TestStreamTracingInterceptor_WithError(t *testing.T) { 154 tests := []struct { 155 name string 156 err error 157 }{ 158 { 159 name: "normal error", 160 err: errors.New("dummy"), 161 }, 162 { 163 name: "grpc error", 164 err: status.Error(codes.DataLoss, "dummy"), 165 }, 166 } 167 168 for _, test := range tests { 169 test := test 170 t.Run(test.name, func(t *testing.T) { 171 t.Parallel() 172 173 var wg sync.WaitGroup 174 wg.Add(1) 175 var md metadata.MD 176 ctx := metadata.NewIncomingContext(context.Background(), md) 177 stream := mockedServerStream{ctx: ctx} 178 err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{ 179 FullMethod: "/foo", 180 }, func(svr interface{}, stream grpc.ServerStream) error { 181 defer wg.Done() 182 return test.err 183 }) 184 wg.Wait() 185 assert.Equal(t, test.err, err) 186 }) 187 } 188 } 189 190 func TestClientStream_RecvMsg(t *testing.T) { 191 tests := []struct { 192 name string 193 err error 194 }{ 195 { 196 name: "nil error", 197 }, 198 { 199 name: "EOF", 200 err: io.EOF, 201 }, 202 { 203 name: "dummy error", 204 err: errors.New("dummy"), 205 }, 206 } 207 208 for _, test := range tests { 209 test := test 210 t.Run(test.name, func(t *testing.T) { 211 t.Parallel() 212 stream := wrapServerStream(context.Background(), &mockedServerStream{ 213 ctx: context.Background(), 214 err: test.err, 215 }) 216 assert.Equal(t, test.err, stream.RecvMsg(nil)) 217 }) 218 } 219 } 220 221 func TestServerStream_SendMsg(t *testing.T) { 222 tests := []struct { 223 name string 224 err error 225 }{ 226 { 227 name: "nil error", 228 }, 229 { 230 name: "with error", 231 err: errors.New("dummy"), 232 }, 233 } 234 235 for _, test := range tests { 236 test := test 237 t.Run(test.name, func(t *testing.T) { 238 t.Parallel() 239 stream := wrapServerStream(context.Background(), &mockedServerStream{ 240 ctx: context.Background(), 241 err: test.err, 242 }) 243 assert.Equal(t, test.err, stream.SendMsg(nil)) 244 }) 245 } 246 } 247 248 type mockedServerStream struct { 249 ctx context.Context 250 err error 251 } 252 253 func (m *mockedServerStream) SetHeader(md metadata.MD) error { 254 panic("implement me") 255 } 256 257 func (m *mockedServerStream) SendHeader(md metadata.MD) error { 258 panic("implement me") 259 } 260 261 func (m *mockedServerStream) SetTrailer(md metadata.MD) { 262 panic("implement me") 263 } 264 265 func (m *mockedServerStream) Context() context.Context { 266 if m.ctx == nil { 267 return context.Background() 268 } 269 270 return m.ctx 271 } 272 273 func (m *mockedServerStream) SendMsg(v interface{}) error { 274 return m.err 275 } 276 277 func (m *mockedServerStream) RecvMsg(v interface{}) error { 278 return m.err 279 }