gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/go-grpc-middleware/chain_test.go (about)

     1  // Copyright 2016 Michal Witkowski. All Rights Reserved.
     2  // See LICENSE for licensing terms.
     3  
     4  package grpc_middleware
     5  
     6  import (
     7  	"fmt"
     8  	"testing"
     9  
    10  	"gitee.com/ks-custle/core-gm/grpc"
    11  	"gitee.com/ks-custle/core-gm/grpc/metadata"
    12  	"gitee.com/ks-custle/core-gm/net/context"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  var (
    17  	someServiceName  = "SomeService.StreamMethod"
    18  	parentUnaryInfo  = &grpc.UnaryServerInfo{FullMethod: someServiceName}
    19  	parentStreamInfo = &grpc.StreamServerInfo{
    20  		FullMethod:     someServiceName,
    21  		IsServerStream: true,
    22  	}
    23  	someValue     = 1
    24  	parentContext = context.WithValue(context.TODO(), "parent", someValue)
    25  )
    26  
    27  func TestChainUnaryServer(t *testing.T) {
    28  	input := "input"
    29  	output := "output"
    30  
    31  	first := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    32  		requireContextValue(t, ctx, "parent", "first interceptor must know the parent context value")
    33  		require.Equal(t, parentUnaryInfo, info, "first interceptor must know the someUnaryServerInfo")
    34  		ctx = context.WithValue(ctx, "first", 1)
    35  		return handler(ctx, req)
    36  	}
    37  	second := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    38  		requireContextValue(t, ctx, "parent", "second interceptor must know the parent context value")
    39  		requireContextValue(t, ctx, "first", "second interceptor must know the first context value")
    40  		require.Equal(t, parentUnaryInfo, info, "second interceptor must know the someUnaryServerInfo")
    41  		ctx = context.WithValue(ctx, "second", 1)
    42  		return handler(ctx, req)
    43  	}
    44  	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
    45  		require.EqualValues(t, input, req, "handler must get the input")
    46  		requireContextValue(t, ctx, "parent", "handler must know the parent context value")
    47  		requireContextValue(t, ctx, "first", "handler must know the first context value")
    48  		requireContextValue(t, ctx, "second", "handler must know the second context value")
    49  		return output, nil
    50  	}
    51  
    52  	chain := ChainUnaryServer(first, second)
    53  	out, _ := chain(parentContext, input, parentUnaryInfo, handler)
    54  	require.EqualValues(t, output, out, "chain must return handler's output")
    55  }
    56  
    57  func TestChainStreamServer(t *testing.T) {
    58  	someService := &struct{}{}
    59  	recvMessage := "received"
    60  	sentMessage := "sent"
    61  	outputError := fmt.Errorf("some error")
    62  
    63  	first := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    64  		requireContextValue(t, stream.Context(), "parent", "first interceptor must know the parent context value")
    65  		require.Equal(t, parentStreamInfo, info, "first interceptor must know the parentStreamInfo")
    66  		require.Equal(t, someService, srv, "first interceptor must know someService")
    67  		wrapped := WrapServerStream(stream)
    68  		wrapped.WrappedContext = context.WithValue(stream.Context(), "first", 1)
    69  		return handler(srv, wrapped)
    70  	}
    71  	second := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    72  		requireContextValue(t, stream.Context(), "parent", "second interceptor must know the parent context value")
    73  		requireContextValue(t, stream.Context(), "parent", "second interceptor must know the first context value")
    74  		require.Equal(t, parentStreamInfo, info, "second interceptor must know the parentStreamInfo")
    75  		require.Equal(t, someService, srv, "second interceptor must know someService")
    76  		wrapped := WrapServerStream(stream)
    77  		wrapped.WrappedContext = context.WithValue(stream.Context(), "second", 1)
    78  		return handler(srv, wrapped)
    79  	}
    80  	handler := func(srv interface{}, stream grpc.ServerStream) error {
    81  		require.Equal(t, someService, srv, "handler must know someService")
    82  		requireContextValue(t, stream.Context(), "parent", "handler must know the parent context value")
    83  		requireContextValue(t, stream.Context(), "first", "handler must know the first context value")
    84  		requireContextValue(t, stream.Context(), "second", "handler must know the second context value")
    85  		require.NoError(t, stream.RecvMsg(recvMessage), "handler must have access to stream messages")
    86  		require.NoError(t, stream.SendMsg(sentMessage), "handler must be able to send stream messages")
    87  		return outputError
    88  	}
    89  	fakeStream := &fakeServerStream{ctx: parentContext, recvMessage: recvMessage}
    90  	chain := ChainStreamServer(first, second)
    91  	err := chain(someService, fakeStream, parentStreamInfo, handler)
    92  	require.Equal(t, outputError, err, "chain must return handler's error")
    93  	require.Equal(t, sentMessage, fakeStream.sentMessage, "handler's sent message must propagate to stream")
    94  }
    95  
    96  func TestChainUnaryClient(t *testing.T) {
    97  	ignoredMd := metadata.Pairs("foo", "bar")
    98  	parentOpts := []grpc.CallOption{grpc.Header(&ignoredMd)}
    99  	reqMessage := "request"
   100  	replyMessage := "reply"
   101  	outputError := fmt.Errorf("some error")
   102  
   103  	first := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   104  		requireContextValue(t, ctx, "parent", "first must know the parent context value")
   105  		require.Equal(t, someServiceName, method, "first must know someService")
   106  		require.Len(t, opts, 1, "first should see parent CallOptions")
   107  		wrappedCtx := context.WithValue(ctx, "first", 1)
   108  		return invoker(wrappedCtx, method, req, reply, cc, opts...)
   109  	}
   110  	second := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   111  		requireContextValue(t, ctx, "parent", "second must know the parent context value")
   112  		require.Equal(t, someServiceName, method, "second must know someService")
   113  		require.Len(t, opts, 1, "second should see parent CallOptions")
   114  		wrappedOpts := append(opts, grpc.FailFast(true))
   115  		wrappedCtx := context.WithValue(ctx, "second", 1)
   116  		return invoker(wrappedCtx, method, req, reply, cc, wrappedOpts...)
   117  	}
   118  	invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
   119  		require.Equal(t, someServiceName, method, "invoker must know someService")
   120  		requireContextValue(t, ctx, "parent", "invoker must know the parent context value")
   121  		requireContextValue(t, ctx, "first", "invoker must know the first context value")
   122  		requireContextValue(t, ctx, "second", "invoker must know the second context value")
   123  		require.Len(t, opts, 2, "invoker should see both CallOpts from second and parent")
   124  		return outputError
   125  	}
   126  	chain := ChainUnaryClient(first, second)
   127  	err := chain(parentContext, someServiceName, reqMessage, replyMessage, nil, invoker, parentOpts...)
   128  	require.Equal(t, outputError, err, "chain must return invokers's error")
   129  }
   130  
   131  func TestChainStreamClient(t *testing.T) {
   132  	ignoredMd := metadata.Pairs("foo", "bar")
   133  	parentOpts := []grpc.CallOption{grpc.Header(&ignoredMd)}
   134  	clientStream := &fakeClientStream{}
   135  	fakeStreamDesc := &grpc.StreamDesc{ClientStreams: true, ServerStreams: true, StreamName: someServiceName}
   136  
   137  	first := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   138  		requireContextValue(t, ctx, "parent", "first must know the parent context value")
   139  		require.Equal(t, someServiceName, method, "first must know someService")
   140  		require.Len(t, opts, 1, "first should see parent CallOptions")
   141  		wrappedCtx := context.WithValue(ctx, "first", 1)
   142  		return streamer(wrappedCtx, desc, cc, method, opts...)
   143  	}
   144  	second := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   145  		requireContextValue(t, ctx, "parent", "second must know the parent context value")
   146  		require.Equal(t, someServiceName, method, "second must know someService")
   147  		require.Len(t, opts, 1, "second should see parent CallOptions")
   148  		wrappedOpts := append(opts, grpc.FailFast(true))
   149  		wrappedCtx := context.WithValue(ctx, "second", 1)
   150  		return streamer(wrappedCtx, desc, cc, method, wrappedOpts...)
   151  	}
   152  	streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   153  		require.Equal(t, someServiceName, method, "streamer must know someService")
   154  		require.Equal(t, fakeStreamDesc, desc, "streamer must see the right StreamDesc")
   155  
   156  		requireContextValue(t, ctx, "parent", "streamer must know the parent context value")
   157  		requireContextValue(t, ctx, "first", "streamer must know the first context value")
   158  		requireContextValue(t, ctx, "second", "streamer must know the second context value")
   159  		require.Len(t, opts, 2, "streamer should see both CallOpts from second and parent")
   160  		return clientStream, nil
   161  	}
   162  	chain := ChainStreamClient(first, second)
   163  	someStream, err := chain(parentContext, fakeStreamDesc, nil, someServiceName, streamer, parentOpts...)
   164  	require.NoError(t, err, "chain must not return an error as nothing there reutrned it")
   165  	require.Equal(t, clientStream, someStream, "chain must return invokers's clientstream")
   166  }
   167  
   168  func requireContextValue(t *testing.T, ctx context.Context, key string, msg ...interface{}) {
   169  	val := ctx.Value(key)
   170  	require.NotNil(t, val, msg...)
   171  	require.Equal(t, someValue, val, msg...)
   172  }