github.com/shuguocloud/go-zero@v1.3.0/zrpc/internal/serverinterceptors/tracinginterceptor_test.go (about) 1 package serverinterceptors 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "sync" 8 "sync/atomic" 9 "testing" 10 11 "github.com/stretchr/testify/assert" 12 "github.com/shuguocloud/go-zero/core/trace" 13 "google.golang.org/grpc" 14 "google.golang.org/grpc/codes" 15 "google.golang.org/grpc/metadata" 16 "google.golang.org/grpc/status" 17 ) 18 19 func TestUnaryOpenTracingInterceptor_Disable(t *testing.T) { 20 _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{ 21 FullMethod: "/", 22 }, func(ctx context.Context, req interface{}) (interface{}, error) { 23 return nil, nil 24 }) 25 assert.Nil(t, err) 26 } 27 28 func TestUnaryOpenTracingInterceptor_Enabled(t *testing.T) { 29 trace.StartAgent(trace.Config{ 30 Name: "go-zero-test", 31 Endpoint: "http://localhost:14268/api/traces", 32 Batcher: "jaeger", 33 Sampler: 1.0, 34 }) 35 _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{ 36 FullMethod: "/package.TestService.GetUser", 37 }, func(ctx context.Context, req interface{}) (interface{}, error) { 38 return nil, nil 39 }) 40 assert.Nil(t, err) 41 } 42 43 func TestUnaryTracingInterceptor(t *testing.T) { 44 var run int32 45 var wg sync.WaitGroup 46 wg.Add(1) 47 _, err := UnaryTracingInterceptor(context.Background(), nil, &grpc.UnaryServerInfo{ 48 FullMethod: "/", 49 }, func(ctx context.Context, req interface{}) (interface{}, error) { 50 defer wg.Done() 51 atomic.AddInt32(&run, 1) 52 return nil, nil 53 }) 54 wg.Wait() 55 assert.Nil(t, err) 56 assert.Equal(t, int32(1), atomic.LoadInt32(&run)) 57 } 58 59 func TestUnaryTracingInterceptor_WithError(t *testing.T) { 60 tests := []struct { 61 name string 62 err error 63 }{ 64 { 65 name: "normal error", 66 err: errors.New("dummy"), 67 }, 68 { 69 name: "grpc error", 70 err: status.Error(codes.DataLoss, "dummy"), 71 }, 72 } 73 74 for _, test := range tests { 75 test := test 76 t.Run(test.name, func(t *testing.T) { 77 t.Parallel() 78 79 var wg sync.WaitGroup 80 wg.Add(1) 81 var md metadata.MD 82 ctx := metadata.NewIncomingContext(context.Background(), md) 83 _, err := UnaryTracingInterceptor(ctx, nil, &grpc.UnaryServerInfo{ 84 FullMethod: "/", 85 }, func(ctx context.Context, req interface{}) (interface{}, error) { 86 defer wg.Done() 87 return nil, test.err 88 }) 89 wg.Wait() 90 assert.Equal(t, test.err, err) 91 }) 92 } 93 } 94 95 func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) { 96 var run int32 97 var wg sync.WaitGroup 98 wg.Add(1) 99 var md metadata.MD 100 ctx := metadata.NewIncomingContext(context.Background(), md) 101 stream := mockedServerStream{ctx: ctx} 102 err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{ 103 FullMethod: "/foo", 104 }, func(srv interface{}, stream grpc.ServerStream) error { 105 defer wg.Done() 106 atomic.AddInt32(&run, 1) 107 return nil 108 }) 109 wg.Wait() 110 assert.Nil(t, err) 111 assert.Equal(t, int32(1), atomic.LoadInt32(&run)) 112 } 113 114 func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) { 115 tests := []struct { 116 name string 117 err error 118 }{ 119 { 120 name: "receive event", 121 err: status.Error(codes.DataLoss, "dummy"), 122 }, 123 { 124 name: "error event", 125 err: status.Error(codes.DataLoss, "dummy"), 126 }, 127 } 128 129 for _, test := range tests { 130 test := test 131 t.Run(test.name, func(t *testing.T) { 132 t.Parallel() 133 134 var wg sync.WaitGroup 135 wg.Add(1) 136 var md metadata.MD 137 ctx := metadata.NewIncomingContext(context.Background(), md) 138 stream := mockedServerStream{ctx: ctx} 139 err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{ 140 FullMethod: "/foo", 141 }, func(srv interface{}, stream grpc.ServerStream) error { 142 defer wg.Done() 143 return test.err 144 }) 145 wg.Wait() 146 assert.Equal(t, test.err, err) 147 }) 148 } 149 } 150 151 func TestStreamTracingInterceptor_WithError(t *testing.T) { 152 tests := []struct { 153 name string 154 err error 155 }{ 156 { 157 name: "normal error", 158 err: errors.New("dummy"), 159 }, 160 { 161 name: "grpc error", 162 err: status.Error(codes.DataLoss, "dummy"), 163 }, 164 } 165 166 for _, test := range tests { 167 test := test 168 t.Run(test.name, func(t *testing.T) { 169 t.Parallel() 170 171 var wg sync.WaitGroup 172 wg.Add(1) 173 var md metadata.MD 174 ctx := metadata.NewIncomingContext(context.Background(), md) 175 stream := mockedServerStream{ctx: ctx} 176 err := StreamTracingInterceptor(nil, &stream, &grpc.StreamServerInfo{ 177 FullMethod: "/foo", 178 }, func(srv interface{}, stream grpc.ServerStream) error { 179 defer wg.Done() 180 return test.err 181 }) 182 wg.Wait() 183 assert.Equal(t, test.err, err) 184 }) 185 } 186 } 187 188 func TestClientStream_RecvMsg(t *testing.T) { 189 tests := []struct { 190 name string 191 err error 192 }{ 193 { 194 name: "nil error", 195 }, 196 { 197 name: "EOF", 198 err: io.EOF, 199 }, 200 { 201 name: "dummy error", 202 err: errors.New("dummy"), 203 }, 204 } 205 206 for _, test := range tests { 207 test := test 208 t.Run(test.name, func(t *testing.T) { 209 t.Parallel() 210 stream := wrapServerStream(context.Background(), &mockedServerStream{ 211 ctx: context.Background(), 212 err: test.err, 213 }) 214 assert.Equal(t, test.err, stream.RecvMsg(nil)) 215 }) 216 } 217 } 218 219 func TestServerStream_SendMsg(t *testing.T) { 220 tests := []struct { 221 name string 222 err error 223 }{ 224 { 225 name: "nil error", 226 }, 227 { 228 name: "with error", 229 err: errors.New("dummy"), 230 }, 231 } 232 233 for _, test := range tests { 234 test := test 235 t.Run(test.name, func(t *testing.T) { 236 t.Parallel() 237 stream := wrapServerStream(context.Background(), &mockedServerStream{ 238 ctx: context.Background(), 239 err: test.err, 240 }) 241 assert.Equal(t, test.err, stream.SendMsg(nil)) 242 }) 243 } 244 } 245 246 type mockedServerStream struct { 247 ctx context.Context 248 err error 249 } 250 251 func (m *mockedServerStream) SetHeader(md metadata.MD) error { 252 panic("implement me") 253 } 254 255 func (m *mockedServerStream) SendHeader(md metadata.MD) error { 256 panic("implement me") 257 } 258 259 func (m *mockedServerStream) SetTrailer(md metadata.MD) { 260 panic("implement me") 261 } 262 263 func (m *mockedServerStream) Context() context.Context { 264 if m.ctx == nil { 265 return context.Background() 266 } 267 268 return m.ctx 269 } 270 271 func (m *mockedServerStream) SendMsg(v interface{}) error { 272 return m.err 273 } 274 275 func (m *mockedServerStream) RecvMsg(v interface{}) error { 276 return m.err 277 }