github.com/msales/pkg/v3@v3.24.0/grpcx/middleware/context_test.go

     1  package middleware_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/msales/pkg/v3/grpcx/middleware"
    10  	"github.com/msales/pkg/v3/log"
    11  	"github.com/msales/pkg/v3/stats"
    12  	"github.com/stretchr/testify/assert"
    13  	"google.golang.org/grpc"
    14  )
    15  
    16  var testErr = errors.New("test: error")
    17  
    18  func TestWithUnaryServerLogger(t *testing.T) {
    19  	interceptor := middleware.WithUnaryServerLogger(log.Null)
    20  
    21  	res, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
    22  		l, ok := log.FromContext(ctx)
    23  
    24  		assert.Equal(t, l, log.Null)
    25  		assert.True(t, ok)
    26  
    27  		return "test", testErr
    28  	})
    29  
    30  	assert.Equal(t, "test", res)
    31  	assert.Equal(t, testErr, err)
    32  }
    33  
    34  func TestWithStreamServerLogger(t *testing.T) {
    35  	interceptor := middleware.WithStreamServerLogger(log.Null)
    36  	stream := &serverStreamMock{ctx: context.Background()}
    37  
    38  	err := interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
    39  		l, ok := log.FromContext(stream.Context())
    40  
    41  		assert.Equal(t, l, log.Null)
    42  		assert.True(t, ok)
    43  
    44  		return testErr
    45  	})
    46  
    47  	assert.Equal(t, testErr, err)
    48  }
    49  
    50  func TestWithUnaryServerStats(t *testing.T) {
    51  	interceptor := middleware.WithUnaryServerStats(stats.Null)
    52  
    53  	res, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
    54  		s, ok := stats.FromContext(ctx)
    55  
    56  		assert.Equal(t, s, stats.Null)
    57  		assert.True(t, ok)
    58  
    59  		return "test", testErr
    60  	})
    61  
    62  	assert.Equal(t, "test", res)
    63  	assert.Equal(t, testErr, err)
    64  }
    65  
    66  func TestWithStreamServerStats(t *testing.T) {
    67  	interceptor := middleware.WithStreamServerStats(stats.Null)
    68  	stream := &serverStreamMock{ctx: context.Background()}
    69  
    70  	err := interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
    71  		s, ok := stats.FromContext(stream.Context())
    72  
    73  		assert.Equal(t, s, stats.Null)
    74  		assert.True(t, ok)
    75  
    76  		return testErr
    77  	})
    78  
    79  	assert.Equal(t, testErr, err)
    80  }
    81  
    82  func TestWithUnaryClientContextTimeout(t *testing.T) {
    83  	ctx := context.Background()
    84  
    85  	interceptor := middleware.WithUnaryClientContextTimeout(1 * time.Hour)
    86  	err := interceptor(ctx, "method", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
    87  		_, ok := ctx.Deadline()
    88  
    89  		assert.True(t, ok)
    90  
    91  		return testErr
    92  	})
    93  
    94  	assert.Equal(t, testErr, err)
    95  }
    96  
    97  func TestWithUnaryClientContextTimeout_DeadlineExceeded(t *testing.T) {
    98  	ctx := context.Background()
    99  
   100  	interceptor := middleware.WithUnaryClientContextTimeout(1 * time.Nanosecond)
   101  	_ = interceptor(ctx, "method", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
   102  		_, ok := ctx.Deadline()
   103  
   104  		assert.True(t, ok)
   105  		time.Sleep(5 * time.Nanosecond)
   106  
   107  		assert.Error(t, ctx.Err())
   108  		assert.EqualError(t, ctx.Err(), context.DeadlineExceeded.Error())
   109  
   110  		return nil
   111  	})
   112  }
   113  
   114  func TestWithUnaryClientLogger(t *testing.T) {
   115  	interceptor := middleware.WithUnaryClientLogger(log.Null)
   116  
   117  	err := interceptor(context.Background(), "method", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
   118  		l, ok := log.FromContext(ctx)
   119  
   120  		assert.Equal(t, l, log.Null)
   121  		assert.True(t, ok)
   122  
   123  		return testErr
   124  	})
   125  
   126  	assert.Equal(t, testErr, err)
   127  }
   128  
   129  func TestWithUnaryClientStats(t *testing.T) {
   130  	interceptor := middleware.WithUnaryClientStats(stats.Null)
   131  
   132  	err := interceptor(context.Background(), "method", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
   133  		s, ok := stats.FromContext(ctx)
   134  
   135  		assert.Equal(t, s, stats.Null)
   136  		assert.True(t, ok)
   137  
   138  		return testErr
   139  	})
   140  
   141  	assert.Equal(t, testErr, err)
   142  }