github.com/msales/pkg/v3@v3.24.0/grpcx/middleware/context_test.go (about) 1 package middleware_test 2 3 import ( 4 "context" 5 "errors" 6 "testing" 7 "time" 8 9 "github.com/msales/pkg/v3/grpcx/middleware" 10 "github.com/msales/pkg/v3/log" 11 "github.com/msales/pkg/v3/stats" 12 "github.com/stretchr/testify/assert" 13 "google.golang.org/grpc" 14 ) 15 16 var testErr = errors.New("test: error") 17 18 func TestWithUnaryServerLogger(t *testing.T) { 19 interceptor := middleware.WithUnaryServerLogger(log.Null) 20 21 res, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { 22 l, ok := log.FromContext(ctx) 23 24 assert.Equal(t, l, log.Null) 25 assert.True(t, ok) 26 27 return "test", testErr 28 }) 29 30 assert.Equal(t, "test", res) 31 assert.Equal(t, testErr, err) 32 } 33 34 func TestWithStreamServerLogger(t *testing.T) { 35 interceptor := middleware.WithStreamServerLogger(log.Null) 36 stream := &serverStreamMock{ctx: context.Background()} 37 38 err := interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error { 39 l, ok := log.FromContext(stream.Context()) 40 41 assert.Equal(t, l, log.Null) 42 assert.True(t, ok) 43 44 return testErr 45 }) 46 47 assert.Equal(t, testErr, err) 48 } 49 50 func TestWithUnaryServerStats(t *testing.T) { 51 interceptor := middleware.WithUnaryServerStats(stats.Null) 52 53 res, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { 54 s, ok := stats.FromContext(ctx) 55 56 assert.Equal(t, s, stats.Null) 57 assert.True(t, ok) 58 59 return "test", testErr 60 }) 61 62 assert.Equal(t, "test", res) 63 assert.Equal(t, testErr, err) 64 } 65 66 func TestWithStreamServerStats(t *testing.T) { 67 interceptor := middleware.WithStreamServerStats(stats.Null) 68 stream := &serverStreamMock{ctx: context.Background()} 69 70 err := interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error { 71 s, ok := stats.FromContext(stream.Context()) 72 73 assert.Equal(t, s, stats.Null) 74 assert.True(t, ok) 75 76 return testErr 77 }) 78 79 assert.Equal(t, testErr, err) 80 } 81 82 func TestWithUnaryClientContextTimeout(t *testing.T) { 83 ctx := context.Background() 84 85 interceptor := middleware.WithUnaryClientContextTimeout(1 * time.Hour) 86 err := interceptor(ctx, "method", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 87 _, ok := ctx.Deadline() 88 89 assert.True(t, ok) 90 91 return testErr 92 }) 93 94 assert.Equal(t, testErr, err) 95 } 96 97 func TestWithUnaryClientContextTimeout_DeadlineExceeded(t *testing.T) { 98 ctx := context.Background() 99 100 interceptor := middleware.WithUnaryClientContextTimeout(1 * time.Nanosecond) 101 _ = interceptor(ctx, "method", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 102 _, ok := ctx.Deadline() 103 104 assert.True(t, ok) 105 time.Sleep(5 * time.Nanosecond) 106 107 assert.Error(t, ctx.Err()) 108 assert.EqualError(t, ctx.Err(), context.DeadlineExceeded.Error()) 109 110 return nil 111 }) 112 } 113 114 func TestWithUnaryClientLogger(t *testing.T) { 115 interceptor := middleware.WithUnaryClientLogger(log.Null) 116 117 err := interceptor(context.Background(), "method", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 118 l, ok := log.FromContext(ctx) 119 120 assert.Equal(t, l, log.Null) 121 assert.True(t, ok) 122 123 return testErr 124 }) 125 126 assert.Equal(t, testErr, err) 127 } 128 129 func TestWithUnaryClientStats(t *testing.T) { 130 interceptor := middleware.WithUnaryClientStats(stats.Null) 131 132 err := interceptor(context.Background(), "method", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 133 s, ok := stats.FromContext(ctx) 134 135 assert.Equal(t, s, stats.Null) 136 assert.True(t, ok) 137 138 return testErr 139 }) 140 141 assert.Equal(t, testErr, err) 142 }