gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/test/server_test.go (about)

     1  /*
     2   *
     3   * Copyright 2020 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package test
    20  
    21  import (
    22  	"context"
    23  	"io"
    24  	"testing"
    25  
    26  	grpc "gitee.com/ks-custle/core-gm/grpc"
    27  	"gitee.com/ks-custle/core-gm/grpc/codes"
    28  	"gitee.com/ks-custle/core-gm/grpc/internal/stubserver"
    29  	"gitee.com/ks-custle/core-gm/grpc/status"
    30  	testpb "gitee.com/ks-custle/core-gm/grpc/test/grpc_testing"
    31  )
    32  
    // ctxKey is the key type the interceptor tests in this file use when storing
    // values with context.WithValue and reading them back with ctx.Value.
    33  type ctxKey string
    34  
    35  func (s) TestChainUnaryServerInterceptor(t *testing.T) {
    36  	var (
    37  		firstIntKey  = ctxKey("firstIntKey")
    38  		secondIntKey = ctxKey("secondIntKey")
    39  	)
    40  
    41  	firstInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    42  		if ctx.Value(firstIntKey) != nil {
    43  			return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", firstIntKey)
    44  		}
    45  		if ctx.Value(secondIntKey) != nil {
    46  			return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", secondIntKey)
    47  		}
    48  
    49  		firstCtx := context.WithValue(ctx, firstIntKey, 0)
    50  		resp, err := handler(firstCtx, req)
    51  		if err != nil {
    52  			return nil, status.Errorf(codes.Internal, "failed to handle request at firstInt")
    53  		}
    54  
    55  		simpleResp, ok := resp.(*testpb.SimpleResponse)
    56  		if !ok {
    57  			return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at firstInt")
    58  		}
    59  		return &testpb.SimpleResponse{
    60  			Payload: &testpb.Payload{
    61  				Type: simpleResp.GetPayload().GetType(),
    62  				Body: append(simpleResp.GetPayload().GetBody(), '1'),
    63  			},
    64  		}, nil
    65  	}
    66  
    67  	secondInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    68  		if ctx.Value(firstIntKey) == nil {
    69  			return nil, status.Errorf(codes.Internal, "second interceptor should have %v in context", firstIntKey)
    70  		}
    71  		if ctx.Value(secondIntKey) != nil {
    72  			return nil, status.Errorf(codes.Internal, "second interceptor should not have %v in context", secondIntKey)
    73  		}
    74  
    75  		secondCtx := context.WithValue(ctx, secondIntKey, 1)
    76  		resp, err := handler(secondCtx, req)
    77  		if err != nil {
    78  			return nil, status.Errorf(codes.Internal, "failed to handle request at secondInt")
    79  		}
    80  
    81  		simpleResp, ok := resp.(*testpb.SimpleResponse)
    82  		if !ok {
    83  			return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at secondInt")
    84  		}
    85  		return &testpb.SimpleResponse{
    86  			Payload: &testpb.Payload{
    87  				Type: simpleResp.GetPayload().GetType(),
    88  				Body: append(simpleResp.GetPayload().GetBody(), '2'),
    89  			},
    90  		}, nil
    91  	}
    92  
    93  	lastInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    94  		if ctx.Value(firstIntKey) == nil {
    95  			return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", firstIntKey)
    96  		}
    97  		if ctx.Value(secondIntKey) == nil {
    98  			return nil, status.Errorf(codes.Internal, "last interceptor should not have %v in context", secondIntKey)
    99  		}
   100  
   101  		resp, err := handler(ctx, req)
   102  		if err != nil {
   103  			return nil, status.Errorf(codes.Internal, "failed to handle request at lastInt at lastInt")
   104  		}
   105  
   106  		simpleResp, ok := resp.(*testpb.SimpleResponse)
   107  		if !ok {
   108  			return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at lastInt")
   109  		}
   110  		return &testpb.SimpleResponse{
   111  			Payload: &testpb.Payload{
   112  				Type: simpleResp.GetPayload().GetType(),
   113  				Body: append(simpleResp.GetPayload().GetBody(), '3'),
   114  			},
   115  		}, nil
   116  	}
   117  
   118  	sopts := []grpc.ServerOption{
   119  		grpc.ChainUnaryInterceptor(firstInt, secondInt, lastInt),
   120  	}
   121  
   122  	ss := &stubserver.StubServer{
   123  		UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   124  			payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 0)
   125  			if err != nil {
   126  				return nil, status.Errorf(codes.Aborted, "failed to make payload: %v", err)
   127  			}
   128  
   129  			return &testpb.SimpleResponse{
   130  				Payload: payload,
   131  			}, nil
   132  		},
   133  	}
   134  	if err := ss.Start(sopts); err != nil {
   135  		t.Fatalf("Error starting endpoint server: %v", err)
   136  	}
   137  	defer ss.Stop()
   138  
   139  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   140  	defer cancel()
   141  	resp, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{})
   142  	if s, ok := status.FromError(err); !ok || s.Code() != codes.OK {
   143  		t.Fatalf("ss.Client.UnaryCall(ctx, _) = %v, %v; want nil, <status with Code()=OK>", resp, err)
   144  	}
   145  
   146  	respBytes := resp.Payload.GetBody()
   147  	if string(respBytes) != "321" {
   148  		t.Fatalf("invalid response: want=%s, but got=%s", "321", resp)
   149  	}
   150  }
   151  
   152  func (s) TestChainOnBaseUnaryServerInterceptor(t *testing.T) {
   153  	baseIntKey := ctxKey("baseIntKey")
   154  
   155  	baseInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
   156  		if ctx.Value(baseIntKey) != nil {
   157  			return nil, status.Errorf(codes.Internal, "base interceptor should not have %v in context", baseIntKey)
   158  		}
   159  
   160  		baseCtx := context.WithValue(ctx, baseIntKey, 1)
   161  		return handler(baseCtx, req)
   162  	}
   163  
   164  	chainInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
   165  		if ctx.Value(baseIntKey) == nil {
   166  			return nil, status.Errorf(codes.Internal, "chain interceptor should have %v in context", baseIntKey)
   167  		}
   168  
   169  		return handler(ctx, req)
   170  	}
   171  
   172  	sopts := []grpc.ServerOption{
   173  		grpc.UnaryInterceptor(baseInt),
   174  		grpc.ChainUnaryInterceptor(chainInt),
   175  	}
   176  
   177  	ss := &stubserver.StubServer{
   178  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
   179  			return &testpb.Empty{}, nil
   180  		},
   181  	}
   182  	if err := ss.Start(sopts); err != nil {
   183  		t.Fatalf("Error starting endpoint server: %v", err)
   184  	}
   185  	defer ss.Stop()
   186  
   187  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   188  	defer cancel()
   189  	resp, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
   190  	if s, ok := status.FromError(err); !ok || s.Code() != codes.OK {
   191  		t.Fatalf("ss.Client.EmptyCall(ctx, _) = %v, %v; want nil, <status with Code()=OK>", resp, err)
   192  	}
   193  }
   194  
   195  func (s) TestChainStreamServerInterceptor(t *testing.T) {
   196  	callCounts := make([]int, 4)
   197  
   198  	firstInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   199  		if callCounts[0] != 0 {
   200  			return status.Errorf(codes.Internal, "callCounts[0] should be 0, but got=%d", callCounts[0])
   201  		}
   202  		if callCounts[1] != 0 {
   203  			return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1])
   204  		}
   205  		if callCounts[2] != 0 {
   206  			return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
   207  		}
   208  		if callCounts[3] != 0 {
   209  			return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
   210  		}
   211  		callCounts[0]++
   212  		return handler(srv, stream)
   213  	}
   214  
   215  	secondInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   216  		if callCounts[0] != 1 {
   217  			return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
   218  		}
   219  		if callCounts[1] != 0 {
   220  			return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1])
   221  		}
   222  		if callCounts[2] != 0 {
   223  			return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
   224  		}
   225  		if callCounts[3] != 0 {
   226  			return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
   227  		}
   228  		callCounts[1]++
   229  		return handler(srv, stream)
   230  	}
   231  
   232  	lastInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   233  		if callCounts[0] != 1 {
   234  			return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
   235  		}
   236  		if callCounts[1] != 1 {
   237  			return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1])
   238  		}
   239  		if callCounts[2] != 0 {
   240  			return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
   241  		}
   242  		if callCounts[3] != 0 {
   243  			return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
   244  		}
   245  		callCounts[2]++
   246  		return handler(srv, stream)
   247  	}
   248  
   249  	sopts := []grpc.ServerOption{
   250  		grpc.ChainStreamInterceptor(firstInt, secondInt, lastInt),
   251  	}
   252  
   253  	ss := &stubserver.StubServer{
   254  		FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
   255  			if callCounts[0] != 1 {
   256  				return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
   257  			}
   258  			if callCounts[1] != 1 {
   259  				return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1])
   260  			}
   261  			if callCounts[2] != 1 {
   262  				return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
   263  			}
   264  			if callCounts[3] != 0 {
   265  				return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
   266  			}
   267  			callCounts[3]++
   268  			return nil
   269  		},
   270  	}
   271  	if err := ss.Start(sopts); err != nil {
   272  		t.Fatalf("Error starting endpoint server: %v", err)
   273  	}
   274  	defer ss.Stop()
   275  
   276  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   277  	defer cancel()
   278  	stream, err := ss.Client.FullDuplexCall(ctx)
   279  	if err != nil {
   280  		t.Fatalf("failed to FullDuplexCall: %v", err)
   281  	}
   282  
   283  	_, err = stream.Recv()
   284  	if err != io.EOF {
   285  		t.Fatalf("failed to recv from stream: %v", err)
   286  	}
   287  
   288  	if callCounts[3] != 1 {
   289  		t.Fatalf("callCounts[3] should be 1, but got=%d", callCounts[3])
   290  	}
   291  }